This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new a6115b172f1 Allow forcing trace head-sampling per DAG run via conf (#68860)
a6115b172f1 is described below
commit a6115b172f168841e27ee9fe2b94987791dda293
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Jun 24 13:06:17 2026 -0700
Allow forcing trace head-sampling per DAG run via conf (#68860)
Deployments running a head sampler sometimes need to guarantee that a
specific run is traced (e.g. when debugging a flaky DAG) or excluded (e.g.
a noisy run), regardless of the configured sampling ratio. A reserved run
conf key, airflow/trace_sampled, provides that per-run override: true
always traces the run, false never does, and an absent key leaves the
decision to the sampler. Only an explicit bool is honored, so a malformed
value (such as the string "true" or the integer 1) is ignored: it can
neither silently change sampling nor fail run creation.
---
airflow-core/src/airflow/models/dagrun.py | 20 +++++++++
airflow-core/src/airflow/models/taskinstance.py | 7 ++-
airflow-core/tests/unit/models/test_dagrun.py | 32 ++++++++++++++
.../tests/unit/models/test_taskinstance.py | 18 ++++++++
.../observability/traces/__init__.py | 50 ++++++++++++++--------
.../tests/observability/test_traces.py | 39 +++++++++++++++++
6 files changed, 146 insertions(+), 20 deletions(-)
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index f5ae6a28d58..056151cefa1 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -65,6 +65,7 @@ from sqlalchemy.sql.functions import coalesce
from airflow._shared.observability.metrics import stats
from airflow._shared.observability.traces import (
TASK_SPAN_DETAIL_LEVEL_KEY,
+ TRACE_SAMPLED_KEY,
new_dagrun_trace_carrier,
override_ids,
)
@@ -176,6 +177,24 @@ def dagrun_trace_attributes(dr) -> dict[str, str]:
return attributes
+def trace_sampled_override(conf) -> bool | None:
+ """
+ Head-sampling override from the ``airflow/trace_sampled`` run conf key.
+
+ Returns the forced SAMPLED flag only when the conf value is an explicit
bool
+ (True = always trace this run, False = never); otherwise None, meaning no
+ override and the configured sampler decides. Non-bool values are ignored
+ rather than coerced, so a malformed value never silently flips sampling or
+ fails run creation.
+ """
+ if not conf:
+ return None
+ raw = conf.get(TRACE_SAMPLED_KEY)
+ if isinstance(raw, bool):
+ return raw
+ return None
+
+
class DagRun(Base, LoggingMixin):
"""
Invocation instance of a DAG.
@@ -409,6 +428,7 @@ class DagRun(Base, LoggingMixin):
self.context_carrier: dict[str, str] = new_dagrun_trace_carrier(
task_span_detail_level=self.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY,
None),
attributes=dagrun_trace_attributes(self), # these are for
potential use by head sampler
+ force_sampled=trace_sampled_override(self.conf),
)
if not isinstance(partition_key, str | None):
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index f905e4c8e3d..fcf4367c910 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -407,7 +407,11 @@ def clear_task_instances(
session.merge(ti)
if dag_run_state is not False and tis:
- from airflow.models.dagrun import DagRun, dagrun_trace_attributes #
Avoid circular import
+ from airflow.models.dagrun import ( # Avoid circular import
+ DagRun,
+ dagrun_trace_attributes,
+ trace_sampled_override,
+ )
run_ids_by_dag_id = defaultdict(set)
for instance in tis:
@@ -431,6 +435,7 @@ def clear_task_instances(
dr.context_carrier = new_dagrun_trace_carrier(
task_span_detail_level=dr.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY)
if dr.conf else None,
attributes=dagrun_trace_attributes(dr),
+ force_sampled=trace_sampled_override(dr.conf),
)
_recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index a2e852f5834..4ec048a7f4b 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -4410,6 +4410,38 @@ class TestDagRunTracing:
span = trace.get_current_span(ctx)
assert get_task_span_detail_level(span) == 2
+ @pytest.mark.parametrize(
+ ("conf", "expected"),
+ [
+ ({"airflow/trace_sampled": True}, True),
+ ({"airflow/trace_sampled": False}, False),
+ ({}, None),
+ (None, None),
+ ({"airflow/trace_sampled": "true"}, None),
+ ({"airflow/trace_sampled": 1}, None),
+ ({"other": True}, None),
+ ],
+ )
+ def test_trace_sampled_override(self, conf, expected):
+ """Only an explicit bool conf value is honored; anything else falls
through to the sampler."""
+ from airflow.models.dagrun import trace_sampled_override
+
+ assert trace_sampled_override(conf) is expected
+
+ @pytest.mark.parametrize("flag", [True, False])
+ def test_context_carrier_honors_trace_sampled_conf(self, dag_maker, flag):
+ """airflow/trace_sampled in conf forces the carrier's SAMPLED flag
regardless of the sampler."""
+ from opentelemetry import trace
+ from opentelemetry.trace.propagation.tracecontext import
TraceContextTextMapPropagator
+
+ with dag_maker("test_trace_sampled_conf"):
+ EmptyOperator(task_id="t1")
+ dr = dag_maker.create_dagrun(conf={"airflow/trace_sampled": flag})
+
+ ctx = TraceContextTextMapPropagator().extract(dr.context_carrier)
+ span_ctx = trace.get_current_span(ctx).get_span_context()
+ assert span_ctx.trace_flags.sampled is flag
+
class TestDagRunStatsTagsTeamName:
def test_stats_tags_without_team_name(self, dag_maker):
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 97ec0242235..9b30c71cfa9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4076,6 +4076,24 @@ def
test_clear_task_instances_preserves_detail_level(dag_maker, session):
assert get_task_span_detail_level(span) == 2
[email protected]_test
[email protected]("flag", [True, False])
+def test_clear_task_instances_honors_trace_sampled_conf(dag_maker, session,
flag):
+ """The regenerated carrier honors the airflow/trace_sampled override from
dag run conf."""
+ with dag_maker("test_clear_trace_sampled"):
+ EmptyOperator(task_id="t1")
+ dag_run = dag_maker.create_dagrun(conf={"airflow/trace_sampled": flag})
+ ti = dag_run.get_task_instance("t1", session=session)
+ ti.state = TaskInstanceState.SUCCESS
+ session.flush()
+
+ clear_task_instances([ti], session)
+
+ new_ctx = TraceContextTextMapPropagator().extract(dag_run.context_carrier)
+ span_ctx = trace.get_current_span(new_ctx).get_span_context()
+ assert span_ctx.trace_flags.sampled is flag
+
+
@pytest.mark.db_test
def test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker,
session):
"""``TaskInstance.__repr__`` must survive *any* deferred column it reads.
diff --git
a/shared/observability/src/airflow_shared/observability/traces/__init__.py
b/shared/observability/src/airflow_shared/observability/traces/__init__.py
index ec2ee6e2a0c..f0ef4352527 100644
--- a/shared/observability/src/airflow_shared/observability/traces/__init__.py
+++ b/shared/observability/src/airflow_shared/observability/traces/__init__.py
@@ -59,43 +59,55 @@ class OverrideableRandomIdGenerator(RandomIdGenerator):
TASK_SPAN_DETAIL_LEVEL_KEY = "airflow/task_span_detail_level"
DEFAULT_TASK_SPAN_DETAIL_LEVEL = 1
+TRACE_SAMPLED_KEY = "airflow/trace_sampled"
-def new_dagrun_trace_carrier(task_span_detail_level=None, attributes=None) ->
dict[str, str]:
+def new_dagrun_trace_carrier(
+ task_span_detail_level=None, attributes=None, force_sampled=None
+) -> dict[str, str]:
"""
Generate a fresh W3C traceparent carrier without creating a recordable
span.
The SAMPLED flag is set from an honest *root* sampling decision made by the
configured tracer provider's sampler (driven by ``OTEL_TRACES_SAMPLER`` /
``OTEL_TRACES_SAMPLER_ARG``), rather than being hardcoded. This makes the
- carrier the single head-sampling decision point for a DAG run: every
- downstream span (dag_run, task_run, worker) rides on this flag.
+ carrier the head-sampling decision point for a DAG run: every downstream
span
+ (dag_run, task_run, worker) rides on this flag.
``attributes`` are forwarded to the sampler as ``should_sample``
attributes so
a custom sampler can differentiate the decision by run kind (e.g. by
``airflow.dag_id`` / ``airflow.dag_run.run_type``). The built-in samplers
ignore them.
They are decision input only -- they are not persisted in the carrier.
+
+ ``force_sampled`` overrides the sampler entirely: when not None it sets the
+ SAMPLED flag directly (True = always trace this run, False = never) and the
+ sampler is not consulted. Airflow wires this from the
``airflow/trace_sampled``
+ run conf key; when None the configured sampler makes the decision.
"""
gen = RandomIdGenerator()
trace_id = gen.generate_trace_id()
- provider = trace.get_tracer_provider()
- sampler = getattr(provider, "sampler", None)
- if sampler is not None:
- result = sampler.should_sample(
- parent_context=None, # root decision
- trace_id=trace_id,
- name="dag_run",
- attributes=attributes or {},
- )
- sampled = result.decision == Decision.RECORD_AND_SAMPLE
- sampler_trace_state = result.trace_state
- else:
- # No sampler attribute means a proxy/no-op provider (otel disabled).
- # Nothing exports in that case, so the flag is irrelevant; mirror the
- # observable behavior of today when otel is off.
- sampled = False
+ if force_sampled is not None:
+ sampled = force_sampled
sampler_trace_state = None
+ else:
+ provider = trace.get_tracer_provider()
+ sampler = getattr(provider, "sampler", None)
+ if sampler is not None:
+ result = sampler.should_sample(
+ parent_context=None, # root decision
+ trace_id=trace_id,
+ name="dag_run",
+ attributes=attributes or {},
+ )
+ sampled = result.decision == Decision.RECORD_AND_SAMPLE
+ sampler_trace_state = result.trace_state
+ else:
+ # No sampler attribute means a proxy/no-op provider (otel
disabled).
+ # Nothing exports in that case, so the flag is irrelevant; mirror
the
+ # observable behavior of today when otel is off.
+ sampled = False
+ sampler_trace_state = None
# Preserve the detail-level tracestate by merging it onto whatever the
# sampler returned. TraceState is immutable, so update() returns a new one.
diff --git a/shared/observability/tests/observability/test_traces.py
b/shared/observability/tests/observability/test_traces.py
index ba09ad59fe0..e2b43da9237 100644
--- a/shared/observability/tests/observability/test_traces.py
+++ b/shared/observability/tests/observability/test_traces.py
@@ -210,6 +210,45 @@ class TestNewDagrunTraceCarrierSampling:
new_dagrun_trace_carrier()
assert captured["attributes"] == {}
+ def test_force_sampled_true_overrides_sampler(self, with_sampler):
+ """force_sampled=True samples the run even when the sampler says no."""
+ with_sampler(ALWAYS_OFF)
+ assert
_carrier_is_sampled(new_dagrun_trace_carrier(force_sampled=True)) is True
+
+ def test_force_sampled_false_overrides_sampler(self, with_sampler):
+ """force_sampled=False drops the run even when the sampler says yes."""
+ with_sampler(ParentBased(ALWAYS_ON))
+ assert
_carrier_is_sampled(new_dagrun_trace_carrier(force_sampled=False)) is False
+
+ def test_force_sampled_bypasses_sampler(self, monkeypatch):
+ """When force_sampled is set, the sampler is not consulted at all."""
+ called = False
+
+ class _RecordingSampler:
+ def should_sample(self, *args, **kwargs):
+ nonlocal called
+ called = True
+ return ALWAYS_ON.should_sample(*args, **kwargs)
+
+ class _Provider:
+ sampler = _RecordingSampler()
+
+ monkeypatch.setattr(
+ "airflow_shared.observability.traces.trace.get_tracer_provider",
+ lambda: _Provider(),
+ )
+ new_dagrun_trace_carrier(force_sampled=True)
+ assert called is False
+
+ def test_force_sampled_preserves_detail_level(self, with_sampler):
+ """Detail-level tracestate still round-trips when the decision is
forced."""
+ with_sampler(ALWAYS_OFF)
+ carrier = new_dagrun_trace_carrier(task_span_detail_level=2,
force_sampled=True)
+ ctx = TraceContextTextMapPropagator().extract(carrier)
+ span = trace.get_current_span(ctx)
+ assert get_task_span_detail_level(span) == 2
+ assert _carrier_is_sampled(carrier) is True
+
class TestGetTaskSpanDetailLevel:
def _make_span_with_trace_state(self, entries: list[tuple[str, str]]) ->
NonRecordingSpan: