From 986350614fb035bad02fecbeddd6aa9475b62d0a Mon Sep 17 00:00:00 2001 From: Daniel Szoke Date: Tue, 11 Feb 2025 14:10:24 +0100 Subject: [PATCH] feat(tracing): Backfill missing `sample_rand` on `PropagationContext` Whenever the `PropagationContext` continues an incoming trace (i.e. whenever the `trace_id` is set, rather than being randomly generated as for a new trace), check if the `sample_rand` is present and valid in the incoming DSC. If the `sample_rand` is missing, generate it deterministically based on the `trace_id` and backfill it into the DSC on the `PropagationContext`. When generating the backfilled `sample_rand`, we ensure the generated value is consistent with the incoming trace's sampling decision and sample rate, if both of these are present. Otherwise, we generate a new value in the range [0, 1). Future PRs will address propagating the `sample_rand` to transactions generated with `continue_trace` (allowing the `sample_rand` to be propagated on outgoing traces), and will also allow `sample_rand` to be used for making sampling decisions. Ref #3998 --- sentry_sdk/tracing_utils.py | 172 ++++++++++++++++++++++++++++++- sentry_sdk/utils.py | 17 ++- tests/test_api.py | 5 +- tests/test_propagationcontext.py | 59 +++++++++++ 4 files changed, 249 insertions(+), 4 deletions(-) diff --git a/sentry_sdk/tracing_utils.py b/sentry_sdk/tracing_utils.py index ae72b8cce9..63de492ddb 100644 --- a/sentry_sdk/tracing_utils.py +++ b/sentry_sdk/tracing_utils.py @@ -1,4 +1,5 @@ import contextlib +from decimal import Decimal import inspect import os import re @@ -6,6 +7,7 @@ from collections.abc import Mapping from datetime import timedelta from functools import wraps +from random import Random from urllib.parse import quote, unquote import uuid @@ -19,6 +21,7 @@ match_regex_list, qualname_from_function, to_string, + try_decimal, is_sentry_url, _is_external_source, _is_in_project_root, @@ -418,6 +421,9 @@ def from_incoming_data(cls, incoming_data): propagation_context = PropagationContext() propagation_context.update(sentrytrace_data) + if propagation_context is not None: + propagation_context._fill_sample_rand() + return propagation_context @property @@ -425,6 +431,7 @@ def trace_id(self): # type: () -> str """The trace id of the Sentry trace.""" if not self._trace_id: + # New trace, don't fill in sample_rand self._trace_id = uuid.uuid4().hex return self._trace_id @@ -469,6 +476,54 @@ def __repr__(self): self.dynamic_sampling_context, ) + def _fill_sample_rand(self): + # type: () -> None + """ + Ensure that there is a valid sample_rand value in the dynamic_sampling_context. + + If there is a valid sample_rand value in the dynamic_sampling_context, we keep it. + Otherwise, we generate a sample_rand value according to the following: + + - If we have a parent_sampled value and a sample_rate in the DSC, we compute + a sample_rand value randomly in the range: + - [0, sample_rate) if parent_sampled is True, + - or, in the range [sample_rate, 1) if parent_sampled is False. + + - If either parent_sampled or sample_rate is missing, we generate a random + value in the range [0, 1). + + The sample_rand is deterministically generated from the trace_id, if present. + + This function does nothing if there is no dynamic_sampling_context. + """ + if self.dynamic_sampling_context is None: + return + + sample_rand = SampleRandValue.try_from_incoming( + self.dynamic_sampling_context.get("sample_rand") + ) + if sample_rand is not None and 0 <= sample_rand.inner() < 1: + # sample_rand is present and valid, so don't overwrite it + return + + # Get the sample rate and compute the transformation that will map the random value + # to the desired range: [0, 1), [0, sample_rate), or [sample_rate, 1). + sample_rate = try_decimal(self.dynamic_sampling_context.get("sample_rate")) + lower, upper = _sample_rand_range(self.parent_sampled, sample_rate) + + try: + sample_rand = SampleRandValue.generate( + self.trace_id, interval=(lower, upper) + ) + except ValueError: + # ValueError is raised if the interval is invalid, i.e. lower >= upper. + # lower >= upper might happen if the incoming trace's sampled flag + # and sample_rate are inconsistent, e.g. sample_rate=0.0 but sampled=True. + # We cannot generate a sensible sample_rand value in this case. + return + + self.dynamic_sampling_context["sample_rand"] = str(sample_rand) + class Baggage: """ @@ -643,9 +698,109 @@ def __repr__(self): return f'' +class SampleRandValue: + """ + Lightweight wrapper around a Decimal value, with utilities for + generating a sample rand value from a trace ID, parsing incoming + sample_rand values, and for consistent serialization to a string. + + SampleRandValue instances are immutable. + """ + + DECIMAL_0 = Decimal(0) + DECIMAL_1 = Decimal(1) + + PRECISION = 6 + """We use this many decimal places for the sample_rand value. + + If this value ever needs to be changed, also update the formatting + in the __str__ method. + """ + + def __init__(self, value): + # type: (Decimal) -> None + """ + Initialize SampleRandValue from a Decimal value. This constructor + should only be called internally by the SampleRandValue class. + """ + self._value = value + + @classmethod + def try_from_incoming(cls, incoming_value): + # type: (Optional[str]) -> Optional[SampleRandValue] + """ + Attempt to parse an incoming sample_rand value from a string. + + Returns None if the incoming value is None or cannot be parsed as a Decimal. + """ + value = try_decimal(incoming_value) + if value is not None and cls.DECIMAL_0 <= value < cls.DECIMAL_1: + return cls(value) + + return None + + @classmethod + def generate( + cls, + trace_id, # type: Optional[str] + *, + interval=(DECIMAL_0, DECIMAL_1), # type: tuple[Decimal, Decimal] + ): + # type: (...) -> SampleRandValue + """Generate a sample_rand value from a trace ID. + + The generated value will be pseudorandomly chosen from the provided + interval. Specifically, given (lower, upper) = interval, the generated + value will be in the range [lower, upper). + + The pseudorandom number generator is seeded with the trace ID. + """ + lower_decimal, upper_decimal = interval + if not lower_decimal < upper_decimal: + raise ValueError("Invalid interval: lower must be less than upper") + + # Since sample_rand values have 6-digit precision, we generate the + # value as an integer in the range [lower_decimal * 10**6, upper_decimal * 10**6), + # and then scale it to the desired range. + lower_int = int(lower_decimal.scaleb(cls.PRECISION)) + upper_int = int(upper_decimal.scaleb(cls.PRECISION)) + + if lower_int == upper_int: + # Edge case: lower_decimal < upper_decimal, but due to rounding, + # lower_int == upper_int. In this case, we return + # lower_int.scaleb(-SCALE_EXPONENT) here, since calling randrange() + # with the same lower and upper bounds will raise an error. + return cls(Decimal(lower_int).scaleb(-cls.PRECISION)) + + value = Random(trace_id).randrange(lower_int, upper_int) + return cls(Decimal(value).scaleb(-cls.PRECISION)) + + def inner(self): + # type: () -> Decimal + """ + Return the inner Decimal value. + """ + return self._value + + def __str__(self): + # type: () -> str + """ + Return a string representation of the SampleRandValue. + + The string representation has 6 decimal places. + """ + # Lint E231 is a false-positive here. If we add a space after the :, + # then the formatter puts an extra space before the decimal numbers. + return f"{self._value:.6f}" # noqa: E231 + + def __repr__(self): + # type: () -> str + return f"" + + def should_propagate_trace(client, url): # type: (sentry_sdk.client.BaseClient, str) -> bool - """ + """u Returns True if url matches trace_propagation_targets configured in the given client. Otherwise, returns False. """ trace_propagation_targets = client.options["trace_propagation_targets"] @@ -748,6 +903,21 @@ def get_current_span(scope=None): return current_span +def _sample_rand_range(parent_sampled, sample_rate): + # type: (Optional[bool], Optional[Decimal]) -> tuple[Decimal, Decimal] + """ + Compute the lower (inclusive) and upper (exclusive) bounds of the range of values + that a generated sample_rand value must fall into, given the parent_sampled and + sample_rate values. + """ + if parent_sampled is None or sample_rate is None: + return Decimal(0), Decimal(1) + elif parent_sampled is True: + return Decimal(0), sample_rate + else: # parent_sampled is False + return sample_rate, Decimal(1) + + # Circular imports from sentry_sdk.tracing import ( BAGGAGE_HEADER_NAME, diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index b2a39b7af1..82155890f1 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -12,7 +12,7 @@ import time from collections import namedtuple from datetime import datetime, timezone -from decimal import Decimal +from decimal import Decimal, InvalidOperation from functools import partial, partialmethod, wraps from numbers import Real from urllib.parse import parse_qs, unquote, urlencode, urlsplit, urlunsplit @@ -1888,3 +1888,18 @@ def should_be_treated_as_error(ty, value): return False return True + + +def try_decimal(value): + # type: (Optional[str]) -> Optional[Decimal] + """Small utility which attempts to convert an Optional[str] to a Decimal. + + Returns None if the value is None or if the value cannot be parsed as a Decimal. + """ + if value is None: + return None + + try: + return Decimal(value) + except InvalidOperation: + return None diff --git a/tests/test_api.py b/tests/test_api.py index 3b2a9c8fb7..22fb82cea4 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -111,7 +111,7 @@ def test_continue_trace(sentry_init): transaction = continue_trace( { "sentry-trace": "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled), - "baggage": "sentry-trace_id=566e3688a61d4bc888951642d6f14a19", + "baggage": "sentry-trace_id=566e3688a61d4bc888951642d6f14a19,sentry-sample_rand=0.123456", }, name="some name", ) @@ -123,7 +123,8 @@ def test_continue_trace(sentry_init): assert propagation_context.parent_span_id == parent_span_id assert propagation_context.parent_sampled == parent_sampled assert propagation_context.dynamic_sampling_context == { - "trace_id": "566e3688a61d4bc888951642d6f14a19" + "trace_id": "566e3688a61d4bc888951642d6f14a19", + "sample_rand": "0.123456", } diff --git a/tests/test_propagationcontext.py b/tests/test_propagationcontext.py index 85f82913f8..6de80711a7 100644 --- a/tests/test_propagationcontext.py +++ b/tests/test_propagationcontext.py @@ -1,6 +1,16 @@ +import pytest + from sentry_sdk.tracing_utils import PropagationContext +SAMPLED_FLAG = { + None: "", + False: "-0", + True: "-1", +} +"""Maps the `sampled` value to the flag appended to the sentry-trace header.""" + + def test_empty_context(): ctx = PropagationContext() @@ -51,6 +61,7 @@ def test_lazy_uuids(): def test_property_setters(): ctx = PropagationContext() + ctx.trace_id = "X234567890abcdef1234567890abcdef" ctx.span_id = "X234567890abcdef" @@ -58,6 +69,7 @@ def test_property_setters(): assert ctx.trace_id == "X234567890abcdef1234567890abcdef" assert ctx._span_id == "X234567890abcdef" assert ctx.span_id == "X234567890abcdef" + assert ctx.dynamic_sampling_context is None def test_update(): @@ -81,3 +93,50 @@ def test_update(): assert ctx.dynamic_sampling_context is None assert not hasattr(ctx, "foo") + + +def test_existing_sample_rand_kept(): + ctx = PropagationContext( + trace_id="00000000000000000000000000000000", + dynamic_sampling_context={"sample_rand": "0.5"}, + ) + + # If sample_rand was regenerated, the value would be 0.8766381713144122 based on the trace_id + assert ctx.dynamic_sampling_context["sample_rand"] == "0.5" + + +@pytest.mark.parametrize( + ("parent_sampled", "sample_rate", "expected_sample_rand"), + ( + # Note that parent_sampled and sample_rate do not scale the + # sample_rand value, only determine the range of the value. + # Expected values are determined by parent_sampled, sample_rate, + # and the trace_id. + (None, None, "0.919221"), + (None, "0.5", "0.919221"), + (False, None, "0.919221"), + (True, None, "0.919221"), + (False, "0.0", "0.919221"), + (False, "0.01", "0.929221"), + (True, "0.01", "0.006073"), + (False, "0.1", "0.762590"), + (True, "0.1", "0.082823"), + (False, "0.5", "0.959610"), + (True, "0.5", "0.459610"), + (True, "1.0", "0.919221"), + ), +) +def test_sample_rand_filled(parent_sampled, sample_rate, expected_sample_rand): + """When continuing a trace, we want to fill in the sample_rand value if it's missing.""" + dsc = {} + if sample_rate is not None: + dsc["sample_rate"] = sample_rate + + ctx = PropagationContext().from_incoming_data( + { + "sentry-trace": f"00000000000000000000000000000000-0000000000000000{SAMPLED_FLAG[parent_sampled]}", + "baggage": f"sentry-sample_rate={sample_rate}", + } + ) + + assert ctx.dynamic_sampling_context["sample_rand"] == expected_sample_rand