This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a6115b172f1 Allow forcing trace head-sampling per DAG run via conf (#68860)
a6115b172f1 is described below

commit a6115b172f168841e27ee9fe2b94987791dda293
Author: Daniel Standish <[email protected]>
AuthorDate: Wed Jun 24 13:06:17 2026 -0700

    Allow forcing trace head-sampling per DAG run via conf (#68860)
    
    Deployments running a head sampler sometimes need to guarantee a specific run
    is traced (debugging a flaky DAG) or excluded (a noisy run), regardless of the
    configured sampling ratio. A reserved run conf key, airflow/trace_sampled,
    provides that per-run override: true always traces the run, false never does,
    and absent leaves the decision to the sampler. Only an explicit bool is honored
    so a malformed value can neither silently change sampling nor fail run creation.
---
 airflow-core/src/airflow/models/dagrun.py          | 20 +++++++++
 airflow-core/src/airflow/models/taskinstance.py    |  7 ++-
 airflow-core/tests/unit/models/test_dagrun.py      | 32 ++++++++++++++
 .../tests/unit/models/test_taskinstance.py         | 18 ++++++++
 .../observability/traces/__init__.py               | 50 ++++++++++++++--------
 .../tests/observability/test_traces.py             | 39 +++++++++++++++++
 6 files changed, 146 insertions(+), 20 deletions(-)

diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index f5ae6a28d58..056151cefa1 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -65,6 +65,7 @@ from sqlalchemy.sql.functions import coalesce
 from airflow._shared.observability.metrics import stats
 from airflow._shared.observability.traces import (
     TASK_SPAN_DETAIL_LEVEL_KEY,
+    TRACE_SAMPLED_KEY,
     new_dagrun_trace_carrier,
     override_ids,
 )
@@ -176,6 +177,24 @@ def dagrun_trace_attributes(dr) -> dict[str, str]:
     return attributes
 
 
+def trace_sampled_override(conf) -> bool | None:
+    """
+    Head-sampling override from the ``airflow/trace_sampled`` run conf key.
+
+    Returns the forced SAMPLED flag only when the conf value is an explicit 
bool
+    (True = always trace this run, False = never); otherwise None, meaning no
+    override and the configured sampler decides. Non-bool values are ignored
+    rather than coerced, so a malformed value never silently flips sampling or
+    fails run creation.
+    """
+    if not conf:
+        return None
+    raw = conf.get(TRACE_SAMPLED_KEY)
+    if isinstance(raw, bool):
+        return raw
+    return None
+
+
 class DagRun(Base, LoggingMixin):
     """
     Invocation instance of a DAG.
@@ -409,6 +428,7 @@ class DagRun(Base, LoggingMixin):
         self.context_carrier: dict[str, str] = new_dagrun_trace_carrier(
             task_span_detail_level=self.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY, 
None),
             attributes=dagrun_trace_attributes(self),  # these are for 
potential use by head sampler
+            force_sampled=trace_sampled_override(self.conf),
         )
 
         if not isinstance(partition_key, str | None):
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index f905e4c8e3d..fcf4367c910 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -407,7 +407,11 @@ def clear_task_instances(
             session.merge(ti)
 
     if dag_run_state is not False and tis:
-        from airflow.models.dagrun import DagRun, dagrun_trace_attributes  # 
Avoid circular import
+        from airflow.models.dagrun import (  # Avoid circular import
+            DagRun,
+            dagrun_trace_attributes,
+            trace_sampled_override,
+        )
 
         run_ids_by_dag_id = defaultdict(set)
         for instance in tis:
@@ -431,6 +435,7 @@ def clear_task_instances(
             dr.context_carrier = new_dagrun_trace_carrier(
                 task_span_detail_level=dr.conf.get(TASK_SPAN_DETAIL_LEVEL_KEY) 
if dr.conf else None,
                 attributes=dagrun_trace_attributes(dr),
+                force_sampled=trace_sampled_override(dr.conf),
             )
 
             _recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index a2e852f5834..4ec048a7f4b 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -4410,6 +4410,38 @@ class TestDagRunTracing:
         span = trace.get_current_span(ctx)
         assert get_task_span_detail_level(span) == 2
 
+    @pytest.mark.parametrize(
+        ("conf", "expected"),
+        [
+            ({"airflow/trace_sampled": True}, True),
+            ({"airflow/trace_sampled": False}, False),
+            ({}, None),
+            (None, None),
+            ({"airflow/trace_sampled": "true"}, None),
+            ({"airflow/trace_sampled": 1}, None),
+            ({"other": True}, None),
+        ],
+    )
+    def test_trace_sampled_override(self, conf, expected):
+        """Only an explicit bool conf value is honored; anything else falls 
through to the sampler."""
+        from airflow.models.dagrun import trace_sampled_override
+
+        assert trace_sampled_override(conf) is expected
+
+    @pytest.mark.parametrize("flag", [True, False])
+    def test_context_carrier_honors_trace_sampled_conf(self, dag_maker, flag):
+        """airflow/trace_sampled in conf forces the carrier's SAMPLED flag 
regardless of the sampler."""
+        from opentelemetry import trace
+        from opentelemetry.trace.propagation.tracecontext import 
TraceContextTextMapPropagator
+
+        with dag_maker("test_trace_sampled_conf"):
+            EmptyOperator(task_id="t1")
+        dr = dag_maker.create_dagrun(conf={"airflow/trace_sampled": flag})
+
+        ctx = TraceContextTextMapPropagator().extract(dr.context_carrier)
+        span_ctx = trace.get_current_span(ctx).get_span_context()
+        assert span_ctx.trace_flags.sampled is flag
+
 
 class TestDagRunStatsTagsTeamName:
     def test_stats_tags_without_team_name(self, dag_maker):
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py 
b/airflow-core/tests/unit/models/test_taskinstance.py
index 97ec0242235..9b30c71cfa9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4076,6 +4076,24 @@ def 
test_clear_task_instances_preserves_detail_level(dag_maker, session):
     assert get_task_span_detail_level(span) == 2
 
 
[email protected]_test
[email protected]("flag", [True, False])
+def test_clear_task_instances_honors_trace_sampled_conf(dag_maker, session, 
flag):
+    """The regenerated carrier honors the airflow/trace_sampled override from 
dag run conf."""
+    with dag_maker("test_clear_trace_sampled"):
+        EmptyOperator(task_id="t1")
+    dag_run = dag_maker.create_dagrun(conf={"airflow/trace_sampled": flag})
+    ti = dag_run.get_task_instance("t1", session=session)
+    ti.state = TaskInstanceState.SUCCESS
+    session.flush()
+
+    clear_task_instances([ti], session)
+
+    new_ctx = TraceContextTextMapPropagator().extract(dag_run.context_carrier)
+    span_ctx = trace.get_current_span(new_ctx).get_span_context()
+    assert span_ctx.trace_flags.sampled is flag
+
+
 @pytest.mark.db_test
 def test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker, 
session):
     """``TaskInstance.__repr__`` must survive *any* deferred column it reads.
diff --git 
a/shared/observability/src/airflow_shared/observability/traces/__init__.py 
b/shared/observability/src/airflow_shared/observability/traces/__init__.py
index ec2ee6e2a0c..f0ef4352527 100644
--- a/shared/observability/src/airflow_shared/observability/traces/__init__.py
+++ b/shared/observability/src/airflow_shared/observability/traces/__init__.py
@@ -59,43 +59,55 @@ class OverrideableRandomIdGenerator(RandomIdGenerator):
 
 TASK_SPAN_DETAIL_LEVEL_KEY = "airflow/task_span_detail_level"
 DEFAULT_TASK_SPAN_DETAIL_LEVEL = 1
+TRACE_SAMPLED_KEY = "airflow/trace_sampled"
 
 
-def new_dagrun_trace_carrier(task_span_detail_level=None, attributes=None) -> 
dict[str, str]:
+def new_dagrun_trace_carrier(
+    task_span_detail_level=None, attributes=None, force_sampled=None
+) -> dict[str, str]:
     """
     Generate a fresh W3C traceparent carrier without creating a recordable 
span.
 
     The SAMPLED flag is set from an honest *root* sampling decision made by the
     configured tracer provider's sampler (driven by ``OTEL_TRACES_SAMPLER`` /
     ``OTEL_TRACES_SAMPLER_ARG``), rather than being hardcoded. This makes the
-    carrier the single head-sampling decision point for a DAG run: every
-    downstream span (dag_run, task_run, worker) rides on this flag.
+    carrier the head-sampling decision point for a DAG run: every downstream 
span
+    (dag_run, task_run, worker) rides on this flag.
 
     ``attributes`` are forwarded to the sampler as ``should_sample`` 
attributes so
     a custom sampler can differentiate the decision by run kind (e.g. by
     ``airflow.dag_id`` / ``airflow.dag_run.run_type``). The built-in samplers 
ignore them.
     They are decision input only -- they are not persisted in the carrier.
+
+    ``force_sampled`` overrides the sampler entirely: when not None it sets the
+    SAMPLED flag directly (True = always trace this run, False = never) and the
+    sampler is not consulted. Airflow wires this from the 
``airflow/trace_sampled``
+    run conf key; when None the configured sampler makes the decision.
     """
     gen = RandomIdGenerator()
     trace_id = gen.generate_trace_id()
 
-    provider = trace.get_tracer_provider()
-    sampler = getattr(provider, "sampler", None)
-    if sampler is not None:
-        result = sampler.should_sample(
-            parent_context=None,  # root decision
-            trace_id=trace_id,
-            name="dag_run",
-            attributes=attributes or {},
-        )
-        sampled = result.decision == Decision.RECORD_AND_SAMPLE
-        sampler_trace_state = result.trace_state
-    else:
-        # No sampler attribute means a proxy/no-op provider (otel disabled).
-        # Nothing exports in that case, so the flag is irrelevant; mirror the
-        # observable behavior of today when otel is off.
-        sampled = False
+    if force_sampled is not None:
+        sampled = force_sampled
         sampler_trace_state = None
+    else:
+        provider = trace.get_tracer_provider()
+        sampler = getattr(provider, "sampler", None)
+        if sampler is not None:
+            result = sampler.should_sample(
+                parent_context=None,  # root decision
+                trace_id=trace_id,
+                name="dag_run",
+                attributes=attributes or {},
+            )
+            sampled = result.decision == Decision.RECORD_AND_SAMPLE
+            sampler_trace_state = result.trace_state
+        else:
+            # No sampler attribute means a proxy/no-op provider (otel 
disabled).
+            # Nothing exports in that case, so the flag is irrelevant; mirror 
the
+            # observable behavior of today when otel is off.
+            sampled = False
+            sampler_trace_state = None
 
     # Preserve the detail-level tracestate by merging it onto whatever the
     # sampler returned. TraceState is immutable, so update() returns a new one.
diff --git a/shared/observability/tests/observability/test_traces.py 
b/shared/observability/tests/observability/test_traces.py
index ba09ad59fe0..e2b43da9237 100644
--- a/shared/observability/tests/observability/test_traces.py
+++ b/shared/observability/tests/observability/test_traces.py
@@ -210,6 +210,45 @@ class TestNewDagrunTraceCarrierSampling:
         new_dagrun_trace_carrier()
         assert captured["attributes"] == {}
 
+    def test_force_sampled_true_overrides_sampler(self, with_sampler):
+        """force_sampled=True samples the run even when the sampler says no."""
+        with_sampler(ALWAYS_OFF)
+        assert 
_carrier_is_sampled(new_dagrun_trace_carrier(force_sampled=True)) is True
+
+    def test_force_sampled_false_overrides_sampler(self, with_sampler):
+        """force_sampled=False drops the run even when the sampler says yes."""
+        with_sampler(ParentBased(ALWAYS_ON))
+        assert 
_carrier_is_sampled(new_dagrun_trace_carrier(force_sampled=False)) is False
+
+    def test_force_sampled_bypasses_sampler(self, monkeypatch):
+        """When force_sampled is set, the sampler is not consulted at all."""
+        called = False
+
+        class _RecordingSampler:
+            def should_sample(self, *args, **kwargs):
+                nonlocal called
+                called = True
+                return ALWAYS_ON.should_sample(*args, **kwargs)
+
+        class _Provider:
+            sampler = _RecordingSampler()
+
+        monkeypatch.setattr(
+            "airflow_shared.observability.traces.trace.get_tracer_provider",
+            lambda: _Provider(),
+        )
+        new_dagrun_trace_carrier(force_sampled=True)
+        assert called is False
+
+    def test_force_sampled_preserves_detail_level(self, with_sampler):
+        """Detail-level tracestate still round-trips when the decision is 
forced."""
+        with_sampler(ALWAYS_OFF)
+        carrier = new_dagrun_trace_carrier(task_span_detail_level=2, 
force_sampled=True)
+        ctx = TraceContextTextMapPropagator().extract(carrier)
+        span = trace.get_current_span(ctx)
+        assert get_task_span_detail_level(span) == 2
+        assert _carrier_is_sampled(carrier) is True
+
 
 class TestGetTaskSpanDetailLevel:
     def _make_span_with_trace_state(self, entries: list[tuple[str, str]]) -> 
NonRecordingSpan:

Reply via email to