This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0dc4d33b602 Introduce parent task spans and nest worker and trigger spans under them (#63839)
0dc4d33b602 is described below
commit 0dc4d33b602c9fbeeebd1fa5a6fba3c06e120f5a
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Mar 23 18:56:13 2026 -0700
Introduce parent task spans and nest worker and trigger spans under them (#63839)
This lets us tie together the worker and trigger phases of task execution.
It also lets us see the delta between task queued time and task start time.
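
For readers unfamiliar with how these spans travel between processes, the
sketch below shows the carrier round-trip the commit relies on: a span's
context is serialized into a dict (the W3C traceparent header) and later
re-extracted so a child span in another process can be parented under it.
This is a minimal, self-contained example against the public OpenTelemetry
Python API; the provider wiring and span names are illustrative, not code
from this commit.

    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    provider = TracerProvider()
    tracer = provider.get_tracer("sketch")

    # Producer side (e.g. the scheduler): start a span and serialize its
    # context into a carrier dict.
    parent_span = tracer.start_span("dag_run.example")
    carrier: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(carrier, context=trace.set_span_in_context(parent_span))
    # carrier now holds {"traceparent": "00-<trace_id>-<span_id>-01"}

    # Consumer side (e.g. a worker or the triggerer, usually another
    # process): extract the context and parent a child span under it.
    parent_ctx = TraceContextTextMapPropagator().extract(carrier)
    with tracer.start_as_current_span("task_run.example", context=parent_ctx) as child:
        assert child.get_span_context().trace_id == parent_span.get_span_context().trace_id
    parent_span.end()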
---
.../execution_api/routes/task_instances.py | 46 ++++++++
.../src/airflow/executors/workloads/task.py | 1 -
.../src/airflow/jobs/triggerer_job_runner.py | 88 +++++++++++-----
airflow-core/src/airflow/models/dagrun.py | 24 ++++-
airflow-core/src/airflow/models/taskinstance.py | 13 ++-
airflow-core/src/airflow/models/taskmap.py | 16 +++
airflow-core/tests/integration/otel/test_otel.py | 5 +-
.../versions/head/test_task_instances.py | 116 +++++++++++++++++++++
airflow-core/tests/unit/jobs/test_triggerer_job.py | 101 +++++++++++++++++-
airflow-core/tests/unit/models/test_dagrun.py | 35 ++-----
.../tests/unit/models/test_taskinstance.py | 104 ++++++++++++++++++
.../observability/traces/__init__.py | 12 +++
.../src/airflow/sdk/execution_time/task_runner.py | 4 +-
.../task_sdk/execution_time/test_task_runner.py | 87 +++++++++++-----
14 files changed, 562 insertions(+), 90 deletions(-)
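
For orientation, the resulting hierarchy for the integration-test DAG (as
asserted in the updated test_otel.py below) is roughly the following; the
trigger branch appears only when a task defers:

    dag_run.otel_test_dag
    └── task_run.task1        (parent task span, emitted by the execution API)
        ├── worker.task1      (task runner span; previously named task_run.task1)
        │   └── sub_span1
        └── trigger.task1     (triggerer span, nested via the TI's context_carrier)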
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 96d8f3a6c86..5f5073c916b 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -29,6 +29,9 @@ import attrs
import structlog
from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, Security, status
+from opentelemetry import trace
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import JsonValue
from sqlalchemy import and_, func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
@@ -37,6 +40,7 @@ from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select
from structlog.contextvars import bind_contextvars
+from airflow._shared.observability.traces import override_ids
from airflow._shared.timezones import timezone
from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
from airflow.api_fastapi.common.db.common import SessionDep
@@ -87,6 +91,7 @@ ti_id_router = VersionedAPIRouter(
log = structlog.get_logger(__name__)
+tracer = trace.get_tracer(__name__)
@ti_id_router.patch(
@@ -431,6 +436,46 @@ def ti_update_state(
)
+def _emit_task_span(ti, state):
+ # just to be safe
+ if not ti.dag_run:
+ return
+ if not isinstance(ti.dag_run.context_carrier, dict):
+ return
+ if not isinstance(ti.context_carrier, dict):
+ return
+    dr_ctx = TraceContextTextMapPropagator().extract(ti.dag_run.context_carrier)
+
+ ti_ctx = TraceContextTextMapPropagator().extract(ti.context_carrier)
+ ti_span = trace.get_current_span(context=ti_ctx)
+ span_context = ti_span.get_span_context()
+    start_time_candidates = (x for x in (ti.queued_dttm, ti.start_date, timezone.utcnow()) if x)
+ name = f"task_run.{ti.task_id}"
+ if ti.map_index >= 0:
+ name += f"[{ti.map_index}]"
+ with override_ids(span_context.trace_id, span_context.span_id):
+ span = tracer.start_span(
+ name=name,
+ start_time=int(min(start_time_candidates).timestamp() * 1e9),
+ context=dr_ctx,
+ )
+
+ span.set_attributes(
+ {
+ "airflow.dag_id": ti.dag_id,
+ "airflow.task_id": ti.task_id,
+ "airflow.dag_run.run_id": ti.run_id,
+ "airflow.task_instance.try_number": ti.try_number,
+ "airflow.task_instance.map_index": ti.map_index if
ti.map_index is not None else -1,
+ "airflow.task_instance.state": state,
+ "airflow.task_instance.id": ti.id,
+ }
+ )
+    status_code = StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR
+ span.set_status(status_code)
+ span.end()
+
+
def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
dr = ti.dag_run
@@ -479,6 +524,7 @@ def _create_ti_state_update_query_and_update_state(
ti_patch_payload.outlet_events,
session,
)
+ _emit_task_span(ti, state=updated_state)
elif isinstance(ti_patch_payload, TIDeferredStatePayload):
# Calculate timeout if it was passed
timeout = None
diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py
index a5939cf4244..4ca8c310fb5 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -86,7 +86,6 @@ class ExecuteTask(BaseDagBundleWorkload):
from airflow.utils.helpers import log_filename_template_renderer
ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
- ser_ti.context_carrier = ti.dag_run.context_carrier
if not bundle_info:
bundle_info = BundleInfo(
name=ti.dag_model.bundle_name,
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 1406283c05c..44c28a7a539 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -35,6 +35,9 @@ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, ClassVar, Literal, T
import anyio
import attrs
import structlog
+from opentelemetry import trace
+from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from pydantic import BaseModel, Field, TypeAdapter
from sqlalchemy import func, select
from structlog.contextvars import bind_contextvars as bind_log_contextvars
@@ -87,6 +90,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
if TYPE_CHECKING:
+ from opentelemetry.util._decorator import _AgnosticContextManager
from sqlalchemy.orm import Session
from structlog.typing import FilteringBoundLogger, WrappedLogger
@@ -96,6 +100,34 @@ if TYPE_CHECKING:
from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
logger = logging.getLogger(__name__)
+tracer = trace.get_tracer(__name__)
+
+
+def _make_trigger_span(
+ ti: TaskInstanceDTO | None, trigger_id: int, name: str
+) -> _AgnosticContextManager[trace.Span]:
+ parent_context = (
+        TraceContextTextMapPropagator().extract(ti.context_carrier) if ti and ti.context_carrier else None
+ )
+ attributes: dict[str, str | int] = {
+ "airflow.trigger.name": name,
+ }
+ if isinstance(ti, TaskInstanceDTO):
+ span_name = f"trigger.{ti.task_id}" if ti else f"trigger.{trigger_id}"
+ if ti.map_index >= 0:
+ span_name += f"_{ti.map_index}"
+ attributes = {
+ **attributes,
+ "airflow.dag_id": ti.dag_id,
+ "airflow.task_id": ti.task_id,
+ "airflow.dag_run.run_id": ti.run_id,
+ "airflow.task_instance.try_number": ti.try_number,
+ "airflow.task_instance.map_index": ti.map_index,
+ }
+ else:
+ span_name = f"trigger.{name}"
+    return tracer.start_as_current_span(span_name, attributes=attributes, context=parent_context)
+
__all__ = [
"TriggerRunner",
@@ -1179,30 +1211,38 @@ class TriggerRunner:
name = self.triggers[trigger_id]["name"]
self.log.info("trigger %s starting", name)
- try:
- async for event in trigger.run():
- await self.log.ainfo(
- "Trigger fired event",
name=self.triggers[trigger_id]["name"], result=event
- )
- self.triggers[trigger_id]["events"] += 1
- self.events.append((trigger_id, event))
- except asyncio.CancelledError:
-            # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
- # message about it
- if timeout := timeout_after:
-                timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
- if timeout < timezone.utcnow():
- await self.log.aerror("Trigger cancelled due to timeout")
- raise
- finally:
- # CancelledError will get injected when we're stopped - which is
- # fine, the cleanup process will understand that, but we want to
- # allow triggers a chance to cleanup, either in that case or if
- # they exit cleanly. Exception from cleanup methods are ignored.
- with suppress(Exception):
- await trigger.cleanup()
-
- await self.log.ainfo("trigger completed", name=name)
+        with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
+ try:
+ async for event in trigger.run():
+ await self.log.ainfo(
+ "Trigger fired event",
name=self.triggers[trigger_id]["name"], result=event
+ )
+ self.triggers[trigger_id]["events"] += 1
+ self.events.append((trigger_id, event))
+ span.set_status(Status(StatusCode.OK))
+ except asyncio.CancelledError as e:
+                # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
+ # message about it
+ if timeout := timeout_after:
+                    timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
+ if timeout < timezone.utcnow():
+ await self.log.aerror("Trigger cancelled due to
timeout")
+ span.set_status(Status(StatusCode.ERROR),
description=str(e))
+ raise
+ span.set_status(Status(StatusCode.OK), description=str(e))
+ raise
+ except Exception as e:
+ span.set_status(Status(StatusCode.ERROR), description=str(e))
+ raise
+ finally:
+            # CancelledError will get injected when we're stopped - which is
+            # fine, the cleanup process will understand that, but we want to
+            # allow triggers a chance to cleanup, either in that case or if
+            # they exit cleanly. Exception from cleanup methods are ignored.
+ with suppress(Exception):
+ await trigger.cleanup()
+
+ await self.log.ainfo("trigger completed", name=name)
def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
"""
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 73a5a3a875a..20ec118f33e 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1027,21 +1027,33 @@ class DagRun(Base, LoggingMixin):
return leaf_tis
def _emit_dagrun_span(self, state: DagRunState):
-        ctx = TraceContextTextMapPropagator().extract(self.context_carrier or {})
+ # just to be safe
+ if not isinstance(self.context_carrier, dict):
+ return
+
+ ctx = TraceContextTextMapPropagator().extract(self.context_carrier)
span = trace.get_current_span(context=ctx)
span_context = span.get_span_context()
with override_ids(span_context.trace_id, span_context.span_id):
- attributes = {
+ attributes: dict[str, str] = {
"airflow.dag_id": str(self.dag_id),
"airflow.dag_run.run_id": self.run_id,
}
+ if self.start_date:
+ attributes["airflow.dag_run.start_date"] = str(self.start_date)
+ if self.end_date:
+ attributes["airflow.dag_run.end_date"] = str(self.end_date)
+ if self.queued_at:
+ attributes["airflow.dag_run.queued_at"] = str(self.queued_at)
+ if self.created_at:
+ attributes["airflow.dag_run.created_at"] = str(self.created_at)
if self.logical_date:
attributes["airflow.dag_run.logical_date"] =
str(self.logical_date)
if self.partition_key:
attributes["airflow.dag_run.partition_key"] =
str(self.partition_key)
span = tracer.start_span(
name=f"dag_run.{self.dag_id}",
-            start_time=int((self.start_date or timezone.utcnow()).timestamp() * 1e9),
+            start_time=int((self.queued_at or self.start_date or timezone.utcnow()).timestamp() * 1e9),
attributes=attributes,
context=context.Context(),
)
@@ -1771,7 +1783,11 @@ class DagRun(Base, LoggingMixin):
created_counts[task.task_type] += 1
for map_index in indexes:
yield TI.insert_mapping(
-                    self.run_id, task, map_index=map_index, dag_version_id=dag_version_id
+ self.run_id,
+ task,
+ map_index=map_index,
+ dag_version_id=dag_version_id,
+ dag_run=self,
)
creator = create_ti_mapping
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 443468161f8..e212ca68504 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -32,6 +32,7 @@ from uuid import UUID
import attrs
import dill
import uuid6
+from opentelemetry import trace
from sqlalchemy import (
JSON,
Float,
@@ -67,6 +68,7 @@ from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from airflow import settings
from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier
from airflow._shared.timezones import timezone
from airflow.assets.manager import asset_manager
from airflow.configuration import conf
@@ -102,7 +104,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
TR = TaskReschedule
log = logging.getLogger(__name__)
-
+tracer = trace.get_tracer(__name__)
if TYPE_CHECKING:
from datetime import datetime
@@ -382,7 +384,7 @@ def clear_task_instances(
for instance in tis:
run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
- drs = session.scalars(
+ drs: Iterable[DagRun] = session.scalars(
select(DagRun).where(
or_(
*(
@@ -397,6 +399,7 @@ def clear_task_instances(
    # Always update clear_number and queued_at when clearing tasks, regardless of state
dr.clear_number += 1
dr.queued_at = timezone.utcnow()
+ dr.context_carrier = new_dagrun_trace_carrier()
_recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)
@@ -425,6 +428,8 @@ def clear_task_instances(
if dag_run_state == DagRunState.QUEUED:
dr.last_scheduling_decision = None
dr.start_date = None
+ for ti in tis:
+ ti.context_carrier = new_task_run_carrier(ti.dag_run.context_carrier)
session.flush()
@@ -679,7 +684,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
@staticmethod
def insert_mapping(
-        run_id: str, task: Operator, map_index: int, dag_version_id: UUID | None
+        run_id: str, task: Operator, map_index: int, *, dag_version_id: UUID | None, dag_run: DagRun
) -> dict[str, Any]:
"""
Insert mapping.
@@ -689,6 +694,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
priority_weight = task.weight_rule.get_weight(
            TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
)
+ context_carrier = new_task_run_carrier(dag_run.context_carrier)
return {
"dag_id": task.dag_id,
@@ -710,6 +716,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
"map_index": map_index,
"_task_display_property_value": task.task_display_name,
"dag_version_id": dag_version_id,
+ "context_carrier": context_carrier,
}
@reconstructor
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 18d09d6aa56..60486b8ce86 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -24,9 +24,11 @@ import enum
from collections.abc import Collection, Iterable, Sequence
from typing import TYPE_CHECKING, Any
+from opentelemetry import trace
from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select
from sqlalchemy.orm import Mapped, mapped_column
+from airflow._shared.observability.traces import new_task_run_carrier
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.models.dag_version import DagVersion
from airflow.utils.db import exists_query
@@ -38,6 +40,7 @@ if TYPE_CHECKING:
from airflow.models.taskinstance import TaskInstance
from airflow.serialization.definitions.mappedoperator import Operator
+tracer = trace.get_tracer(__name__)
class TaskMapVariant(enum.Enum):
@@ -242,6 +245,18 @@ class TaskMap(TaskInstanceDependencies):
else:
dag_version_id = None
+ if unmapped_ti:
+ dr = unmapped_ti.dag_run
+ else:
+ from airflow.models import DagRun
+
+ dr = session.scalar(
+ select(DagRun).where(
+ DagRun.dag_id == task.dag_id,
+ DagRun.run_id == run_id,
+ )
+ )
+
for index in indexes_to_map:
        # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
ti = TaskInstance(
@@ -254,6 +269,7 @@ class TaskMap(TaskInstanceDependencies):
task.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
+ ti.context_carrier = new_task_run_carrier(dr.context_carrier)
        ti.refresh_from_task(task)  # session.merge() loses task information.
all_expanded_tis.append(ti)
diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py
index 0d40156c45e..6852b47af04 100644
--- a/airflow-core/tests/integration/otel/test_otel.py
+++ b/airflow-core/tests/integration/otel/test_otel.py
@@ -508,9 +508,10 @@ class TestOtelIntegration:
nested = get_span_hierarchy()
assert nested == {
- "sub_span1": "task_run.task1",
- "task_run.task1": "dag_run.otel_test_dag",
"dag_run.otel_test_dag": None,
+ "sub_span1": "worker.task1",
+ "task_run.task1": "dag_run.otel_test_dag",
+ "worker.task1": "task_run.task1",
}
def start_scheduler(self, capture_output: bool = False):
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 7cc7aa85594..7f766ede71e 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -24,13 +24,21 @@ from uuid import UUID, uuid4
import pytest
import uuid6
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from sqlalchemy import select, update
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator
from airflow._shared.timezones import timezone
from airflow.api_fastapi.auth.tokens import JWTValidator
from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
from airflow.exceptions import AirflowSkipException
from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
@@ -3242,3 +3250,111 @@ class TestTokenTypeValidation:
payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
assert resp.status_code in [200, 204]
+
+
+class TestEmitTaskSpan:
+ """Tests for the _emit_task_span function in the execution API
task-instance route."""
+
+ @pytest.fixture(autouse=True)
+ def sdk_tracer_provider(self):
+ self.exporter = InMemorySpanExporter()
+ provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
+ provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+ test_tracer = provider.get_tracer("test")
+        with mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.tracer", test_tracer):
+ yield
+
+ def _make_carriers(self):
+ """Return a (dr_carrier, ti_carrier) pair built with a real SDK
provider."""
+ p = TracerProvider()
+ t = p.get_tracer("setup")
+ dr_span = t.start_span("dr")
+ dr_ctx = otel_trace.set_span_in_context(dr_span)
+ dr_carrier: dict = {}
+ TraceContextTextMapPropagator().inject(dr_carrier, context=dr_ctx)
+ ti_span = t.start_span("ti", context=dr_ctx)
+ ti_ctx = otel_trace.set_span_in_context(ti_span)
+ ti_carrier: dict = {}
+ TraceContextTextMapPropagator().inject(ti_carrier, context=ti_ctx)
+ return dr_carrier, ti_carrier
+
+ def _make_ti(self, task_id="my_task", map_index=-1, queued_dttm=None,
start_date=None):
+ dr_carrier, ti_carrier = self._make_carriers()
+ ti = mock.MagicMock()
+ ti.dag_id = "test_dag"
+ ti.task_id = task_id
+ ti.run_id = "test_run"
+ ti.try_number = 1
+ ti.map_index = map_index
+ ti.queued_dttm = queued_dttm
+ ti.start_date = start_date or DEFAULT_START_DATE
+ ti.dag_run.context_carrier = dr_carrier
+ ti.context_carrier = ti_carrier
+ return ti
+
+ def test_emit_task_span_success_sets_ok_status(self):
+ _emit_task_span(self._make_ti(), TaskInstanceState.SUCCESS)
+
+ spans = self.exporter.get_finished_spans()
+ assert len(spans) == 1
+ assert spans[0].status.status_code == StatusCode.OK
+
+ def test_emit_task_span_failed_sets_error_status(self):
+ _emit_task_span(self._make_ti(), TaskInstanceState.FAILED)
+
+ spans = self.exporter.get_finished_spans()
+ assert len(spans) == 1
+ assert spans[0].status.status_code == StatusCode.ERROR
+
+ def test_emit_task_span_sets_attributes(self):
+ ti = self._make_ti(task_id="my_task", map_index=2)
+ _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+ attrs = self.exporter.get_finished_spans()[0].attributes
+ assert attrs["airflow.dag_id"] == "test_dag"
+ assert attrs["airflow.task_id"] == "my_task"
+ assert attrs["airflow.dag_run.run_id"] == "test_run"
+ assert attrs["airflow.task_instance.try_number"] == 1
+ assert attrs["airflow.task_instance.map_index"] == 2
+ assert attrs["airflow.task_instance.state"] ==
TaskInstanceState.SUCCESS
+
+ def test_emit_task_span_name_unmapped(self):
+ _emit_task_span(self._make_ti(task_id="my_task", map_index=-1),
TaskInstanceState.SUCCESS)
+ assert self.exporter.get_finished_spans()[0].name == "task_run.my_task"
+
+ def test_emit_task_span_name_mapped(self):
+ _emit_task_span(self._make_ti(task_id="my_task", map_index=3),
TaskInstanceState.SUCCESS)
+ assert self.exporter.get_finished_spans()[0].name ==
"task_run.my_task[3]"
+
+ def test_emit_task_span_start_time_uses_queued_dttm(self):
+ queued_dttm = timezone.parse("2024-01-01T10:00:00Z")
+ start_date = timezone.parse("2024-01-01T10:05:00Z")
+ ti = self._make_ti(queued_dttm=queued_dttm, start_date=start_date)
+ _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(queued_dttm.timestamp() * 1e9)
+
+ def test_emit_task_span_start_time_falls_back_to_start_date(self):
+ start_date = timezone.parse("2024-01-01T10:05:00Z")
+ ti = self._make_ti(queued_dttm=None, start_date=start_date)
+ _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(start_date.timestamp() * 1e9)
+
+ def test_emit_task_span_skips_if_no_ti_carrier(self):
+ ti = mock.MagicMock()
+ ti.dag_run.context_carrier = {
+ "traceparent":
"00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
+ }
+ ti.context_carrier = None
+
+ _emit_task_span(ti, TaskInstanceState.SUCCESS)
+ assert len(self.exporter.get_finished_spans()) == 0
+
+ def test_emit_task_span_skips_if_no_dagrun_carrier(self):
+ ti = mock.MagicMock()
+ ti.dag_run.context_carrier = None
+ ti.context_carrier = {"traceparent":
"00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
+
+ _emit_task_span(ti, TaskInstanceState.SUCCESS)
+ assert len(self.exporter.get_finished_spans()) == 0
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 802a34192e3..3761189bfeb 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -24,18 +24,26 @@ import os
import selectors
import time
import typing
+import uuid
from collections.abc import AsyncIterator
from socket import socket
from typing import TYPE_CHECKING, Any
+from unittest import mock
from unittest.mock import ANY, AsyncMock, MagicMock, patch
import pendulum
import pytest
from asgiref.sync import sync_to_async
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from structlog.typing import FilteringBoundLogger
from airflow._shared.timezones import timezone
from airflow.executors import workloads
+from airflow.executors.workloads.task import TaskInstanceDTO
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import (
ToTriggerRunner,
@@ -45,6 +53,7 @@ from airflow.jobs.triggerer_job_runner import (
TriggerLoggingFactory,
TriggerRunner,
TriggerRunnerSupervisor,
+ _make_trigger_span,
messages,
)
from airflow.models import Connection, DagModel, DagRun, Trigger, Variable
@@ -318,7 +327,6 @@ def test_trigger_logger_close():
def test_trigger_logger_fd_closed_when_removed(session):
-
trigger = TimeDeltaTrigger(datetime.timedelta(seconds=0.5))
create_trigger_in_db(session, trigger)
@@ -349,11 +357,12 @@ class TestTriggerRunner:
mock_trigger = MagicMock(spec=BaseTrigger)
mock_trigger.timeout_after = None
mock_trigger.run.side_effect = asyncio.CancelledError()
+ mock_trigger.task_instance = MagicMock()
+ mock_trigger.task_instance.map_index = -1
with pytest.raises(asyncio.CancelledError):
asyncio.run(trigger_runner.run_trigger(1, mock_trigger))
- # @pytest.mark.asyncio
def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {
@@ -361,6 +370,8 @@ class TestTriggerRunner:
}
mock_trigger = MagicMock(spec=BaseTrigger)
mock_trigger.run.side_effect = asyncio.CancelledError()
+ mock_trigger.task_instance = MagicMock()
+ mock_trigger.task_instance.map_index = -1
with pytest.raises(asyncio.CancelledError):
asyncio.run(
@@ -1358,3 +1369,89 @@ class TestTriggererMessageTypes:
+ "\n".join(f" - {t}" for t in sorted(task_diff))
+ "\n\nEither handle these types in ToTriggerRunner or update
in_task_but_not_in_trigger_runner list."
)
+
+
+class TestMakeTriggerSpan:
+ """Tests for the _make_trigger_span helper in the triggerer job runner."""
+
+ @pytest.fixture(autouse=True)
+ def sdk_tracer_provider(self):
+ self.exporter = InMemorySpanExporter()
+ provider = TracerProvider()
+ provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+ test_tracer = provider.get_tracer("test")
+ with mock.patch("airflow.jobs.triggerer_job_runner.tracer",
test_tracer):
+ yield
+
+ def _make_ti_dto(self, task_id="my_task", map_index=-1,
context_carrier=None):
+ return TaskInstanceDTO(
+ id=uuid.uuid4(),
+ dag_version_id=uuid.uuid4(),
+ task_id=task_id,
+ dag_id="test_dag",
+ run_id="test_run",
+ try_number=1,
+ map_index=map_index,
+ pool_slots=1,
+ queue="default",
+ priority_weight=1,
+ context_carrier=context_carrier,
+ )
+
+ def test_make_trigger_span_name_with_task_instance(self):
+ ti = self._make_ti_dto(task_id="sensor_task", map_index=-1)
+ with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+ pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task"
+
+ def test_make_trigger_span_name_with_mapped_task(self):
+ ti = self._make_ti_dto(task_id="sensor_task", map_index=2)
+ with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+ pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task_2"
+
+ def test_make_trigger_span_name_without_task_instance(self):
+ with _make_trigger_span(ti=None, trigger_id=42, name="Some trigger
name"):
+ pass
+ assert self.exporter.get_finished_spans()[0].name == "trigger.Some
trigger name"
+
+ def test_make_trigger_span_uses_task_context_carrier(self):
+        # Build a valid ti carrier from a separate provider so we have a known parent span.
+ setup_provider = TracerProvider()
+ setup_tracer = setup_provider.get_tracer("setup")
+ parent_span = setup_tracer.start_span("ti_parent")
+ parent_ctx = otel_trace.set_span_in_context(parent_span)
+ ti_carrier: dict = {}
+ TraceContextTextMapPropagator().inject(ti_carrier, context=parent_ctx)
+ expected_parent_span_id = parent_span.get_span_context().span_id
+
+ ti = self._make_ti_dto(context_carrier=ti_carrier)
+ with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+ pass
+
+ spans = self.exporter.get_finished_spans()
+ assert len(spans) == 1
+ assert spans[0].parent is not None
+ assert spans[0].parent.span_id == expected_parent_span_id
+
+ def test_make_trigger_span_sets_attributes_with_ti(self):
+ ti = self._make_ti_dto(task_id="my_task", map_index=1)
+ with _make_trigger_span(ti=ti, trigger_id=5, name="MyTrigger"):
+ pass
+
+ attrs = self.exporter.get_finished_spans()[0].attributes
+ assert attrs["airflow.trigger.name"] == "MyTrigger"
+ assert attrs["airflow.dag_id"] == "test_dag"
+ assert attrs["airflow.task_id"] == "my_task"
+ assert attrs["airflow.dag_run.run_id"] == "test_run"
+ assert attrs["airflow.task_instance.try_number"] == 1
+ assert attrs["airflow.task_instance.map_index"] == 1
+
+ def test_make_trigger_span_sets_only_trigger_name_without_ti(self):
+ with _make_trigger_span(ti=None, trigger_id=99, name="OnlyTrigger"):
+ pass
+
+ attrs = self.exporter.get_finished_spans()[0].attributes
+ assert attrs["airflow.trigger.name"] == "OnlyTrigger"
+ assert "airflow.dag_id" not in attrs
+ assert "airflow.task_id" not in attrs
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 7ad78292a1e..b446db9f2b9 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,12 @@ from unittest.mock import ANY, call
import pendulum
import pytest
+from opentelemetry import trace as otel_trace
from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from sqlalchemy import (
func,
select,
@@ -37,6 +42,7 @@ from sqlalchemy.orm import joinedload
from airflow import settings
from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator
from airflow._shared.timezones import timezone
from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
from airflow.models.dag import DagModel, infer_automated_data_interval
@@ -3441,13 +3447,6 @@ class TestDagRunTracing:
    def test_emit_dagrun_span_uses_context_carrier_ids(self, dag_maker, session):
        """The emitted span should inherit trace_id/span_id from the context_carrier."""
- from opentelemetry.sdk.trace import TracerProvider
- from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-
-        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
-
in_mem_exporter = InMemorySpanExporter()
provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
@@ -3471,8 +3470,6 @@ class TestDagRunTracing:
# Decode the expected trace_id/span_id from the stored context_carrier
ctx = TraceContextTextMapPropagator().extract(dr.context_carrier)
- from opentelemetry import trace as otel_trace
-
stored_span = otel_trace.get_current_span(context=ctx)
stored_ctx = stored_span.get_span_context()
@@ -3482,13 +3479,6 @@ class TestDagRunTracing:
@pytest.mark.parametrize("final_state", [DagRunState.SUCCESS,
DagRunState.FAILED])
def test_emit_dagrun_span_attributes_and_status(self, dag_maker, session,
final_state):
"""The emitted span should have the correct name, attributes, and
status code."""
- from opentelemetry.sdk.trace import TracerProvider
- from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-        from opentelemetry.trace import StatusCode
-
-        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
-
in_mem_exporter = InMemorySpanExporter()
provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
@@ -3527,12 +3517,6 @@ class TestDagRunTracing:
        context_carrier was cleared/backfilled to NULL. Per OTel spec, missing context
        results in a new root span rather than a crash.
"""
- from opentelemetry.sdk.trace import TracerProvider
- from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-
-        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
-
in_mem_exporter = InMemorySpanExporter()
provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
@@ -3555,5 +3539,8 @@ class TestDagRunTracing:
# A root span should still be emitted
spans = in_mem_exporter.get_finished_spans()
- assert len(spans) == 1
- assert spans[0].name == f"dag_run.{dr.dag_id}"
+ if isinstance(carrier_value, dict):
+ assert len(spans) == 1
+ assert spans[0].name == f"dag_run.{dr.dag_id}"
+ else:
+ assert len(spans) == 0
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 82b9adc3162..bb058d1a737 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -30,11 +30,15 @@ import pendulum
import pytest
import time_machine
import uuid6
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from sqlalchemy import delete, func, select
from sqlalchemy.exc import IntegrityError
from airflow import settings
from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier
from airflow._shared.timezones import timezone
from airflow.exceptions import (
AirflowException,
@@ -3298,3 +3302,103 @@ def test_get_dagrun_loaded_but_none_returns_dagrun(dag_maker, session):
assert dr_from_ti is not None
assert dr_from_ti == dr
+
+
+class TestMakeTaskCarrier:
+ """Tests for the _make_task_carrier helper."""
+
+ @pytest.fixture(autouse=True)
+ def sdk_tracer_provider(self):
+ provider = TracerProvider()
+ real_tracer = provider.get_tracer("airflow.models.taskinstance")
+ with (
+ mock.patch("airflow.models.taskinstance.tracer", real_tracer),
+ mock.patch("airflow._shared.observability.traces.tracer",
real_tracer),
+ ):
+ yield
+
+ def test_make_task_carrier_returns_traceparent(self):
+ carrier = new_task_run_carrier(new_dagrun_trace_carrier())
+ assert isinstance(carrier, dict)
+ assert "traceparent" in carrier
+
+ def test_make_task_carrier_child_of_parent(self):
+ parent_carrier = new_dagrun_trace_carrier()
+ child_carrier = new_task_run_carrier(parent_carrier)
+
+ propagator = TraceContextTextMapPropagator()
+ parent_trace_id = (
+            otel_trace.get_current_span(context=propagator.extract(parent_carrier))
+ .get_span_context()
+ .trace_id
+ )
+ child_trace_id = (
+            otel_trace.get_current_span(context=propagator.extract(child_carrier)).get_span_context().trace_id
+ )
+ assert child_trace_id == parent_trace_id
+ assert child_trace_id != 0
+
+ def test_make_task_carrier_with_none_carrier(self):
+ carrier = new_task_run_carrier(None)
+ assert isinstance(carrier, dict)
+ assert "traceparent" in carrier
+
+
+@pytest.mark.db_test
+def test_insert_mapping_includes_context_carrier(dag_maker, session):
+ """insert_mapping should include a context_carrier with a traceparent
derived from the dag run."""
+ provider = TracerProvider()
+ real_tracer = provider.get_tracer("airflow.models.taskinstance")
+ with (
+ mock.patch("airflow.models.taskinstance.tracer", real_tracer),
+ mock.patch("airflow._shared.observability.traces.tracer", real_tracer),
+ ):
+ with dag_maker("test_insert_mapping_carrier"):
+ EmptyOperator(task_id="t1")
+ session.flush()
+
+        # Get the scheduler-side operator (has a proper PriorityWeightStrategy, not the enum weight_rule).
+ op = create_scheduler_operator(dag_maker.dag.get_task("t1"))
+
+        # Mock the DagRun to avoid inserting into the dag_run table (schema migrations may be pending).
+ dag_run = mock.MagicMock()
+ dag_run.context_carrier = new_dagrun_trace_carrier()
+
+ mapping = TaskInstance.insert_mapping(
+ run_id="test_run",
+ task=op,
+ map_index=0,
+ dag_version_id=None,
+ dag_run=dag_run,
+ )
+
+ assert "context_carrier" in mapping
+ assert mapping["context_carrier"] is not None
+ assert "traceparent" in mapping["context_carrier"]
+
+
+@pytest.mark.db_test
+def test_clear_task_instances_resets_context_carrier(dag_maker, session):
+ """clear_task_instances should assign fresh context carriers to both the
TI and its dag run."""
+ provider = TracerProvider()
+ real_tracer = provider.get_tracer("airflow.models.taskinstance")
+ with (
+ mock.patch("airflow.models.taskinstance.tracer", real_tracer),
+ mock.patch("airflow._shared.observability.traces.tracer", real_tracer),
+ ):
+ with dag_maker("test_clear_carrier"):
+ EmptyOperator(task_id="t1")
+ dag_run = dag_maker.create_dagrun()
+ ti = dag_run.get_task_instance("t1", session=session)
+ ti.state = TaskInstanceState.SUCCESS
+ # Set an explicit carrier so we can verify it changes.
+ ti.context_carrier = {"traceparent":
"00-aaaaaaaaaaaaaaaaaaaaaaaaaaaa0001-bbbbbbbbbbbbbbbb-01"}
+ session.flush()
+
+ original_ti_traceparent = ti.context_carrier["traceparent"]
+ original_dr_traceparent = dag_run.context_carrier["traceparent"]
+
+ clear_task_instances([ti], session)
+
+ assert ti.context_carrier["traceparent"] != original_ti_traceparent
+ assert dag_run.context_carrier["traceparent"] != original_dr_traceparent
diff --git a/shared/observability/src/airflow_shared/observability/traces/__init__.py b/shared/observability/src/airflow_shared/observability/traces/__init__.py
index 53163e45b97..dc3532262d1 100644
--- a/shared/observability/src/airflow_shared/observability/traces/__init__.py
+++ b/shared/observability/src/airflow_shared/observability/traces/__init__.py
@@ -34,6 +34,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapProp
if TYPE_CHECKING:
from configparser import ConfigParser
log = logging.getLogger(__name__)
+tracer = trace.get_tracer(__name__)
OVERRIDE_SPAN_ID_KEY = context.create_key("override_span_id")
OVERRIDE_TRACE_ID_KEY = context.create_key("override_trace_id")
@@ -70,6 +71,17 @@ def new_dagrun_trace_carrier() -> dict[str, str]:
return carrier
+def new_task_run_carrier(dag_run_context_carrier):
+ parent_context = (
+        TraceContextTextMapPropagator().extract(dag_run_context_carrier) if dag_run_context_carrier else None
+ )
+ span = tracer.start_span("notused", context=parent_context) #
intentionally never closed
+ new_ctx = trace.set_span_in_context(span)
+ carrier: dict[str, str] = {}
+ TraceContextTextMapPropagator().inject(carrier, context=new_ctx)
+ return carrier
+
+
@contextmanager
def override_ids(trace_id, span_id, ctx=None):
ctx = context.set_value(OVERRIDE_TRACE_ID_KEY, trace_id, context=ctx)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index aa5a5a08ad4..6e0b0766be3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -148,9 +148,9 @@ def _make_task_span(msg: StartupDetails):
        TraceContextTextMapPropagator().extract(msg.ti.context_carrier) if msg.ti.context_carrier else None
)
ti = msg.ti
- span_name = f"task_run.{ti.task_id}"
+ span_name = f"worker.{ti.task_id}"
if ti.map_index is not None and ti.map_index >= 0:
- span_name += f"_{ti.map_index}"
+ span_name += f"[{ti.map_index}]"
    with tracer.start_as_current_span(span_name, context=parent_context) as span:
span.set_attributes(
{
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 0eab6a50afc..88424bfe53e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -32,9 +32,16 @@ from unittest.mock import call, patch
import pandas as pd
import pytest
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from task_sdk import FAKE_BUNDLE
from uuid6 import uuid7
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator, new_task_run_carrier
+from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
from airflow.listeners import hookimpl
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import (
@@ -415,24 +422,30 @@ def test_main_sends_reschedule_task_when_startup_reschedules(
def test_task_span_is_child_of_dag_run_span(make_ti_context):
- """Task span must be a child of the dag run span propagated via
context_carrier."""
- from opentelemetry.sdk.trace import TracerProvider
- from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-
- # Build a real SDK provider and exporter so we can inspect finished spans.
+ """Full trace hierarchy: dag_run → task_run.my_task (API server) →
worker.my_task (task runner)."""
+ # Single provider shared by all spans so contexts are compatible.
in_mem_exporter = InMemorySpanExporter()
- provider = TracerProvider()
+ provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
- # Create a "dag run" span whose context we will propagate into the task.
+ # Step 1: create the dag run span and capture its carrier.
dag_run_tracer = provider.get_tracer("dag_run")
with dag_run_tracer.start_as_current_span("dag_run.test_dag") as
dag_run_span:
- carrier: dict[str, str] = {}
- TraceContextTextMapPropagator().inject(carrier)
+ dag_run_carrier: dict[str, str] = {}
+ TraceContextTextMapPropagator().inject(dag_run_carrier)
dag_run_span_ctx = dag_run_span.get_span_context()
+    # Step 2: derive the parent task span carrier (child of dag run), as the scheduler does.
+ ti_model_tracer = provider.get_tracer("airflow.models.taskinstance")
+ with mock.patch("airflow.models.taskinstance.tracer", ti_model_tracer):
+ ti_carrier = new_task_run_carrier(dag_run_carrier)
+
+    # Extract the parent task span context (the stable span ID stored in ti_carrier).
+ parent_task_span_ctx = otel_trace.get_current_span(
+ context=TraceContextTextMapPropagator().extract(ti_carrier)
+ ).get_span_context()
+
+ # Step 3: build StartupDetails with ti.context_carrier = ti_carrier.
what = StartupDetails(
ti=TaskInstance(
id=uuid7(),
@@ -441,7 +454,7 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context):
run_id="test_run",
try_number=1,
dag_version_id=uuid7(),
- context_carrier=carrier,
+ context_carrier=ti_carrier,
),
dag_rel_path="",
bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -450,27 +463,45 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context):
sentry_integration="",
)
- task_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner")
- with mock.patch("airflow.sdk.execution_time.task_runner.tracer",
task_tracer):
- with _make_task_span(what) as span:
- task_span_ctx = span.get_span_context()
+ # Step 4: emit the worker span (task runner side).
+    task_runner_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner")
+    with mock.patch("airflow.sdk.execution_time.task_runner.tracer", task_runner_tracer):
+ with _make_task_span(what):
+ pass
- # The task span must share the dag run's trace ID.
- assert task_span_ctx.trace_id == dag_run_span_ctx.trace_id
+    # Step 5: emit the parent task span (API server side, as happens on task completion).
+ mock_ti = mock.MagicMock()
+ mock_ti.dag_id = "test_dag"
+ mock_ti.task_id = "my_task"
+ mock_ti.run_id = "test_run"
+ mock_ti.try_number = 1
+ mock_ti.map_index = -1
+ mock_ti.queued_dttm = None
+ mock_ti.start_date = timezone.utcnow()
+ mock_ti.dag_run.context_carrier = dag_run_carrier
+ mock_ti.context_carrier = ti_carrier
+    api_tracer = provider.get_tracer("airflow.api_fastapi.execution_api.routes.task_instances")
+    with mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.tracer", api_tracer):
+ _emit_task_span(mock_ti, TaskInstanceState.SUCCESS)
- # The task span's parent must be the dag run span.
finished = in_mem_exporter.get_finished_spans()
+
+    # task_run.my_task: emitted by API server, child of dag run, span_id == parent_task_span_ctx.span_id.
task_spans = [s for s in finished if s.name == "task_run.my_task"]
assert len(task_spans) == 1
- assert task_spans[0].parent is not None
- assert task_spans[0].parent.span_id == dag_run_span_ctx.span_id
-
- # Span attributes are set correctly.
- attrs = task_spans[0].attributes
- assert attrs["airflow.dag_id"] == "test_dag"
- assert attrs["airflow.task_id"] == "my_task"
- assert attrs["airflow.dag_run.run_id"] == "test_run"
- assert attrs["airflow.task_instance.try_number"] == 1
+ task_span = task_spans[0]
+ assert task_span.parent is not None
+ assert task_span.parent.span_id == dag_run_span_ctx.span_id
+ assert task_span.context.span_id == parent_task_span_ctx.span_id
+
+ # worker.my_task: created by task runner, child of the parent task span.
+ worker_spans = [s for s in finished if s.name == "worker.my_task"]
+ assert len(worker_spans) == 1
+ assert worker_spans[0].parent is not None
+ assert worker_spans[0].parent.span_id == parent_task_span_ctx.span_id
+
+ # All spans share the same trace ID.
+    assert {s.context.trace_id for s in finished} == {dag_run_span_ctx.trace_id}
def test_task_span_no_parent_when_no_context_carrier(make_ti_context):