Skip to content

Commit

Permalink
feat(tracing): Backfill missing sample_rand on PropagationContext
Browse files Browse the repository at this point in the history
Whenever the `PropagationContext` continues an incoming trace (i.e. whenever the `trace_id` is set, rather than being randomly generated as for a new trace), check if the `sample_rand` is present and valid in the incoming DSC. If the `sample_rand` is missing or invalid, generate it deterministically based on the `trace_id` and backfill it into the DSC on the `PropagationContext`.

When generating the backfilled `sample_rand`, we ensure the generated value is consistent with the incoming trace's sampling decision and sample rate, if both of these are present. Otherwise, we generate a new value in the range [0, 1).

Future PRs will address propagating the `sample_rand` to transactions generated with `continue_trace` (allowing the `sample_rand` to be propagated on outgoing traces), and will also allow `sample_rand` to be used for making sampling decisions.

Ref #3998
  • Loading branch information
szokeasaurusrex committed Feb 25, 2025
1 parent 189e4a9 commit 2b3f4f7
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 2 deletions.
105 changes: 105 additions & 0 deletions sentry_sdk/tracing_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import contextlib
from decimal import ROUND_DOWN, Decimal
import inspect
import os
import re
import sys
from collections.abc import Mapping
from datetime import timedelta
from functools import wraps
from random import Random
from urllib.parse import quote, unquote
import uuid

Expand All @@ -19,6 +21,7 @@
match_regex_list,
qualname_from_function,
to_string,
try_convert,
is_sentry_url,
_is_external_source,
_is_in_project_root,
Expand All @@ -45,6 +48,7 @@
"[ \t]*$" # whitespace
)


# This is a normal base64 regex, modified to reflect the fact that we strip the
# trailing = or == off
base64_stripped = (
Expand Down Expand Up @@ -418,13 +422,17 @@ def from_incoming_data(cls, incoming_data):
propagation_context = PropagationContext()
propagation_context.update(sentrytrace_data)

if propagation_context is not None:
propagation_context._fill_sample_rand()

return propagation_context

@property
def trace_id(self):
    # type: () -> str
    """Return the trace id of the Sentry trace, creating one lazily if unset."""
    if self._trace_id:
        return self._trace_id

    # No trace id yet means this is a brand-new trace rather than a continued
    # one, so sample_rand is deliberately not filled in here.
    self._trace_id = uuid.uuid4().hex
    return self._trace_id
Expand Down Expand Up @@ -469,6 +477,60 @@ def __repr__(self):
self.dynamic_sampling_context,
)

def _fill_sample_rand(self):
    # type: () -> None
    """
    Ensure that there is a valid sample_rand value in the dynamic_sampling_context.

    If there is a valid sample_rand value in the dynamic_sampling_context, we keep it.
    A value is valid when it parses as a finite decimal number in the range [0, 1).
    Otherwise, we generate a sample_rand value according to the following:

      - If we have a parent_sampled value and a sample_rate in the DSC, we compute
        a sample_rand value randomly in the range:
          - [0, sample_rate) if parent_sampled is True,
          - or, in the range [sample_rate, 1) if parent_sampled is False.

      - If either parent_sampled or sample_rate is missing, we generate a random
        value in the range [0, 1).

    The sample_rand is deterministically generated from the trace_id, if present.

    This function does nothing if there is no dynamic_sampling_context. It also
    leaves the dynamic_sampling_context untouched (and logs a debug message) when
    no sensible range can be derived from parent_sampled and sample_rate.
    """
    if self.dynamic_sampling_context is None:
        return

    sample_rand = try_convert(
        Decimal, self.dynamic_sampling_context.get("sample_rand")
    )
    # "nan"/"snan" convert to a Decimal successfully, but ordering comparisons
    # against a NaN Decimal raise decimal.InvalidOperation. Check is_finite()
    # first so that a malformed incoming value is regenerated instead of crashing.
    if (
        sample_rand is not None
        and sample_rand.is_finite()
        and 0 <= sample_rand < 1
    ):
        # sample_rand is present and valid, so don't overwrite it
        return

    # Get the sample rate and compute the transformation that will map the random value
    # to the desired range: [0, 1), [0, sample_rate), or [sample_rate, 1).
    sample_rate = try_convert(
        float, self.dynamic_sampling_context.get("sample_rate")
    )
    lower, upper = _sample_rand_range(self.parent_sampled, sample_rate)

    try:
        sample_rand = _generate_sample_rand(self.trace_id, interval=(lower, upper))
    except ValueError:
        # ValueError is raised if the interval is invalid, i.e. lower >= upper.
        # lower >= upper might happen if the incoming trace's sampled flag
        # and sample_rate are inconsistent, e.g. sample_rate=0.0 but sampled=True.
        # We cannot generate a sensible sample_rand value in this case.
        logger.debug(
            f"Could not backfill sample_rand, since parent_sampled={self.parent_sampled} "
            f"and sample_rate={sample_rate}."
        )
        return

    # Stored as a string with exactly six decimal places.
    self.dynamic_sampling_context["sample_rand"] = (
        f"{sample_rand:.6f}"  # noqa: E231
    )


class Baggage:
"""
Expand Down Expand Up @@ -748,6 +810,49 @@ def get_current_span(scope=None):
return current_span


def _generate_sample_rand(
trace_id, # type: Optional[str]
*,
interval=(0.0, 1.0), # type: tuple[float, float]
):
# type: (...) -> Decimal
"""Generate a sample_rand value from a trace ID.
The generated value will be pseudorandomly chosen from the provided
interval. Specifically, given (lower, upper) = interval, the generated
value will be in the range [lower, upper). The value has 6-digit precision,
so when printing with .6f, the value will never be rounded up.
The pseudorandom number generator is seeded with the trace ID.
"""
lower, upper = interval
if not lower < upper: # using `if lower >= upper` would handle NaNs incorrectly
raise ValueError("Invalid interval: lower must be less than upper")

rng = Random(trace_id)
sample_rand = upper
while sample_rand >= upper:
sample_rand = rng.uniform(lower, upper)

# Round down to exactly six decimal-digit precision.
return Decimal(sample_rand).quantize(Decimal("0.000001"), rounding=ROUND_DOWN)


def _sample_rand_range(parent_sampled, sample_rate):
# type: (Optional[bool], Optional[float]) -> tuple[float, float]
"""
Compute the lower (inclusive) and upper (exclusive) bounds of the range of values
that a generated sample_rand value must fall into, given the parent_sampled and
sample_rate values.
"""
if parent_sampled is None or sample_rate is None:
return 0.0, 1.0
elif parent_sampled is True:
return 0.0, sample_rate
else: # parent_sampled is False
return sample_rate, 1.0


# Circular imports
from sentry_sdk.tracing import (
BAGGAGE_HEADER_NAME,
Expand Down
17 changes: 17 additions & 0 deletions sentry_sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1888,3 +1888,20 @@ def should_be_treated_as_error(ty, value):
return False

return True


if TYPE_CHECKING:
    # Generic type variable referenced by the type comment of try_convert below;
    # it is only defined while type checking.
    T = TypeVar("T")


def try_convert(convert_func, value):
    # type: (Callable[[Any], T], Any) -> Optional[T]
    """
    Convert a value of unknown type by calling ``convert_func`` on it.

    Returns whatever ``convert_func`` returns, or None if ``convert_func``
    raises any exception.
    """
    try:
        converted = convert_func(value)
    except Exception:
        # Any failure in the converter is reported to the caller as None.
        return None

    return converted
5 changes: 3 additions & 2 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_continue_trace(sentry_init):
transaction = continue_trace(
{
"sentry-trace": "{}-{}-{}".format(trace_id, parent_span_id, parent_sampled),
"baggage": "sentry-trace_id=566e3688a61d4bc888951642d6f14a19",
"baggage": "sentry-trace_id=566e3688a61d4bc888951642d6f14a19,sentry-sample_rand=0.123456",
},
name="some name",
)
Expand All @@ -123,7 +123,8 @@ def test_continue_trace(sentry_init):
assert propagation_context.parent_span_id == parent_span_id
assert propagation_context.parent_sampled == parent_sampled
assert propagation_context.dynamic_sampling_context == {
"trace_id": "566e3688a61d4bc888951642d6f14a19"
"trace_id": "566e3688a61d4bc888951642d6f14a19",
"sample_rand": "0.123456",
}


Expand Down
99 changes: 99 additions & 0 deletions tests/test_propagationcontext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
from unittest import mock
from unittest.mock import Mock

import pytest

from sentry_sdk.tracing_utils import PropagationContext


SAMPLED_FLAG = {
    None: "",  # no sampling decision: nothing is appended after the span id
    False: "-0",  # explicitly not sampled
    True: "-1",  # explicitly sampled
}
"""Maps the `sampled` value to the flag appended to the sentry-trace header."""


def test_empty_context():
ctx = PropagationContext()

Expand Down Expand Up @@ -51,13 +64,15 @@ def test_lazy_uuids():

def test_property_setters():
    """Setting trace_id/span_id stores the values and leaves the DSC unset."""
    new_trace_id = "X234567890abcdef1234567890abcdef"
    new_span_id = "X234567890abcdef"

    ctx = PropagationContext()
    ctx.trace_id = new_trace_id
    ctx.span_id = new_span_id

    # Check both the backing attribute and the public property for each id.
    assert ctx._trace_id == new_trace_id
    assert ctx.trace_id == new_trace_id
    assert ctx._span_id == new_span_id
    assert ctx.span_id == new_span_id
    assert ctx.dynamic_sampling_context is None


def test_update():
Expand All @@ -81,3 +96,87 @@ def test_update():
assert ctx.dynamic_sampling_context is None

assert not hasattr(ctx, "foo")


def test_existing_sample_rand_kept():
    """A valid sample_rand that is already present must never be overwritten."""
    ctx = PropagationContext(
        trace_id="00000000000000000000000000000000",
        dynamic_sampling_context={"sample_rand": "0.5"},
    )

    # If sample_rand was regenerated, the value would be 0.919221 based on the trace_id
    assert ctx.dynamic_sampling_context["sample_rand"] == "0.5"

    # Backfilling is triggered when continuing an incoming trace, so also go
    # through from_incoming_data with a baggage header that already carries a
    # valid sample_rand; this is the path that exercises the keep-existing branch.
    incoming_ctx = PropagationContext().from_incoming_data(
        {
            "sentry-trace": "00000000000000000000000000000000-0000000000000000",
            "baggage": "sentry-sample_rand=0.5",
        }
    )

    assert incoming_ctx.dynamic_sampling_context["sample_rand"] == "0.5"


@pytest.mark.parametrize(
    ("parent_sampled", "sample_rate", "expected_interval"),
    (
        # Note that parent_sampled and sample_rate do not scale the
        # sample_rand value, only determine the range of the value.
        # Expected values are determined by parent_sampled, sample_rate,
        # and the trace_id.
        (None, None, (0.0, 1.0)),
        (None, "0.5", (0.0, 1.0)),
        (False, None, (0.0, 1.0)),
        (True, None, (0.0, 1.0)),
        (False, "0.0", (0.0, 1.0)),
        (False, "0.01", (0.01, 1.0)),
        (True, "0.01", (0.0, 0.01)),
        (False, "0.1", (0.1, 1.0)),
        (True, "0.1", (0.0, 0.1)),
        (False, "0.5", (0.5, 1.0)),
        (True, "0.5", (0.0, 0.5)),
        (True, "1.0", (0.0, 1.0)),
    ),
)
def test_sample_rand_filled(parent_sampled, sample_rate, expected_interval):
    """A missing sample_rand is backfilled from the expected interval when continuing a trace."""
    sample_rate_str = ""
    if sample_rate is not None:
        sample_rate_str = f",sentry-sample_rate={sample_rate}"  # noqa: E231

    # The fake RNG always answers with the lower bound of the interval, which
    # keeps the expected sample_rand trivial to state.
    mock_uniform = mock.Mock(return_value=expected_interval[0])

    def mock_random_class(seed):
        assert seed == "00000000000000000000000000000000", "seed should be the trace_id"
        return Mock(uniform=mock_uniform)

    incoming_headers = {
        "sentry-trace": f"00000000000000000000000000000000-0000000000000000{SAMPLED_FLAG[parent_sampled]}",
        # Placeholder is needed, since we only add sample_rand if sentry items are present in baggage
        "baggage": f"sentry-placeholder=asdf{sample_rate_str}",
    }

    with mock.patch("sentry_sdk.tracing_utils.Random", mock_random_class):
        ctx = PropagationContext().from_incoming_data(incoming_headers)

    expected_sample_rand = f"{expected_interval[0]:.6f}"  # noqa: E231
    assert ctx.dynamic_sampling_context["sample_rand"] == expected_sample_rand

    # uniform() must have been asked exactly once, for exactly this interval.
    assert mock_uniform.call_count == 1
    assert mock_uniform.call_args[0] == expected_interval


def test_sample_rand_rounds_down():
    """A generated value just below 1 is truncated to six digits, not rounded up."""
    # Mock value that should round down to 0.999_999
    mock_uniform = mock.Mock(return_value=0.999_999_9)

    def mock_random_class(_):
        return Mock(uniform=mock_uniform)

    incoming_headers = {
        "sentry-trace": "00000000000000000000000000000000-0000000000000000",
        "baggage": "sentry-placeholder=asdf",
    }

    with mock.patch("sentry_sdk.tracing_utils.Random", mock_random_class):
        ctx = PropagationContext().from_incoming_data(incoming_headers)

    assert ctx.dynamic_sampling_context["sample_rand"] == "0.999999"

0 comments on commit 2b3f4f7

Please sign in to comment.