This is an automated email from the ASF dual-hosted git repository.

rahulvats pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit d659d38965a5bddf852cc4d26a4dab67fd051c5f
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Mar 23 18:56:13 2026 -0700

    Introduce parent task spans and nest worker and trigger spans under them (#63839)

    This lets us tie together the worker and trigger phases of task execution.
    Also lets us see the delta between task queued time and task start time.

    (cherry picked from commit 0dc4d33b602c9fbeeebd1fa5a6fba3c06e120f5a)
---
 .../execution_api/routes/task_instances.py         |  46 ++++++++
 .../src/airflow/executors/workloads/task.py        |   1 -
 .../src/airflow/jobs/triggerer_job_runner.py       |  88 +++++++++++----
 airflow-core/src/airflow/models/dagrun.py          |  24 ++++-
 airflow-core/src/airflow/models/taskinstance.py    |  13 ++-
 airflow-core/src/airflow/models/taskmap.py         |  16 +++
 airflow-core/tests/integration/otel/test_otel.py   |   5 +-
 .../versions/head/test_task_instances.py           | 116 +++++++++++++++++++++
 airflow-core/tests/unit/jobs/test_triggerer_job.py | 101 +++++++++++++++++-
 airflow-core/tests/unit/models/test_dagrun.py      |  35 ++-----
 .../tests/unit/models/test_taskinstance.py         | 104 ++++++++++++++++++
 .../observability/traces/__init__.py               |  12 +++
 .../src/airflow/sdk/execution_time/task_runner.py  |   4 +-
 .../task_sdk/execution_time/test_task_runner.py    |  87 +++++++++++-----
 14 files changed, 562 insertions(+), 90 deletions(-)
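[Editor's aside, not part of the commit: everything below rides on W3C trace-context "carriers" -- plain dicts holding a `traceparent` header that the scheduler, API server, and workers store in the `context_carrier` columns. A minimal, self-contained sketch of the inject/extract round-trip, using only the OpenTelemetry SDK (names here are illustrative):

    # Illustrative only. Shows the carrier round-trip that the context_carrier
    # columns store: inject on one side, extract on the other side/process.
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    tracer = TracerProvider().get_tracer("sketch")

    with tracer.start_as_current_span("dag_run.example") as dag_run_span:
        carrier: dict[str, str] = {}
        TraceContextTextMapPropagator().inject(carrier)  # {'traceparent': '00-<trace>-<span>-01'}

    # Later, possibly in another process: rebuild the parent context from the dict.
    parent_ctx = TraceContextTextMapPropagator().extract(carrier)
    with tracer.start_as_current_span("task_run.example", context=parent_ctx) as child:
        # The child shares dag_run_span's trace_id and has it as its parent.
        assert child.get_span_context().trace_id == dag_run_span.get_span_context().trace_id

]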
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 96d8f3a6c86..5f5073c916b 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -29,6 +29,9 @@ import attrs
 import structlog
 from cadwyn import VersionedAPIRouter
 from fastapi import Body, HTTPException, Query, Security, status
+from opentelemetry import trace
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from pydantic import JsonValue
 from sqlalchemy import and_, func, or_, tuple_, update
 from sqlalchemy.engine import CursorResult
@@ -37,6 +40,7 @@ from sqlalchemy.orm import joinedload
 from sqlalchemy.sql import select
 from structlog.contextvars import bind_contextvars
 
+from airflow._shared.observability.traces import override_ids
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.common.dagbag import DagBagDep, get_latest_version_of_dag
 from airflow.api_fastapi.common.db.common import SessionDep
@@ -87,6 +91,7 @@ ti_id_router = VersionedAPIRouter(
 
 log = structlog.get_logger(__name__)
+tracer = trace.get_tracer(__name__)
 
 
 @ti_id_router.patch(
@@ -431,6 +436,46 @@ def ti_update_state(
     )
 
 
+def _emit_task_span(ti, state):
+    # just to be safe
+    if not ti.dag_run:
+        return
+    if not isinstance(ti.dag_run.context_carrier, dict):
+        return
+    if not isinstance(ti.context_carrier, dict):
+        return
+    dr_ctx = TraceContextTextMapPropagator().extract(ti.dag_run.context_carrier)
+
+    ti_ctx = TraceContextTextMapPropagator().extract(ti.context_carrier)
+    ti_span = trace.get_current_span(context=ti_ctx)
+    span_context = ti_span.get_span_context()
+    start_time_candidates = (x for x in (ti.queued_dttm, ti.start_date, timezone.utcnow()) if x)
+    name = f"task_run.{ti.task_id}"
+    if ti.map_index >= 0:
+        name += f"[{ti.map_index}]"
+    with override_ids(span_context.trace_id, span_context.span_id):
+        span = tracer.start_span(
+            name=name,
+            start_time=int(min(start_time_candidates).timestamp() * 1e9),
+            context=dr_ctx,
+        )
+
+    span.set_attributes(
+        {
+            "airflow.dag_id": ti.dag_id,
+            "airflow.task_id": ti.task_id,
+            "airflow.dag_run.run_id": ti.run_id,
+            "airflow.task_instance.try_number": ti.try_number,
+            "airflow.task_instance.map_index": ti.map_index if ti.map_index is not None else -1,
+            "airflow.task_instance.state": state,
+            "airflow.task_instance.id": ti.id,
+        }
+    )
+    status_code = StatusCode.OK if state == TaskInstanceState.SUCCESS else StatusCode.ERROR
+    span.set_status(status_code)
+    span.end()
+
+
 def _handle_fail_fast_for_dag(ti: TI, dag_id: str, session: SessionDep, dag_bag: DagBagDep) -> None:
     dr = ti.dag_run
 
@@ -479,6 +524,7 @@ def _create_ti_state_update_query_and_update_state(
                 ti_patch_payload.outlet_events,
                 session,
             )
+        _emit_task_span(ti, state=updated_state)
     elif isinstance(ti_patch_payload, TIDeferredStatePayload):
         # Calculate timeout if it was passed
         timeout = None
diff --git a/airflow-core/src/airflow/executors/workloads/task.py b/airflow-core/src/airflow/executors/workloads/task.py
index a5939cf4244..4ca8c310fb5 100644
--- a/airflow-core/src/airflow/executors/workloads/task.py
+++ b/airflow-core/src/airflow/executors/workloads/task.py
@@ -86,7 +86,6 @@ class ExecuteTask(BaseDagBundleWorkload):
         from airflow.utils.helpers import log_filename_template_renderer
 
         ser_ti = TaskInstanceDTO.model_validate(ti, from_attributes=True)
-        ser_ti.context_carrier = ti.dag_run.context_carrier
         if not bundle_info:
             bundle_info = BundleInfo(
                 name=ti.dag_model.bundle_name,
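[Editor's aside, not part of the commit: `_emit_task_span` above is unusual in that the span is created *after* the task finished, backdated to the queued time -- that is how the queued-to-start delta becomes visible. A minimal sketch of a backdated span with an explicit `start_time` (nanoseconds since epoch), assuming a plain SDK tracer:

    # Illustrative only. A span can be created after the fact with an explicit
    # start_time, so a completion-time handler can still cover the whole
    # queued -> start -> finish interval.
    import time
    from opentelemetry.sdk.trace import TracerProvider

    tracer = TracerProvider().get_tracer("sketch")

    queued_at_ns = int((time.time() - 30) * 1e9)  # pretend the task was queued 30s ago
    span = tracer.start_span("task_run.example", start_time=queued_at_ns)
    span.end()  # end_time defaults to "now", so the span spans the full 30s

]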
diff --git a/airflow-core/src/airflow/jobs/triggerer_job_runner.py b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
index 1406283c05c..44c28a7a539 100644
--- a/airflow-core/src/airflow/jobs/triggerer_job_runner.py
+++ b/airflow-core/src/airflow/jobs/triggerer_job_runner.py
@@ -35,6 +35,9 @@ from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, ClassVar, Literal, T
 import anyio
 import attrs
 import structlog
+from opentelemetry import trace
+from opentelemetry.trace import Status, StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from pydantic import BaseModel, Field, TypeAdapter
 from sqlalchemy import func, select
 from structlog.contextvars import bind_contextvars as bind_log_contextvars
@@ -87,6 +90,7 @@ from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import provide_session
 
 if TYPE_CHECKING:
+    from opentelemetry.util._decorator import _AgnosticContextManager
    from sqlalchemy.orm import Session
    from structlog.typing import FilteringBoundLogger, WrappedLogger
 
@@ -96,6 +100,34 @@ if TYPE_CHECKING:
     from airflow.sdk.types import RuntimeTaskInstanceProtocol as RuntimeTI
 
 logger = logging.getLogger(__name__)
+tracer = trace.get_tracer(__name__)
+
+
+def _make_trigger_span(
+    ti: TaskInstanceDTO | None, trigger_id: int, name: str
+) -> _AgnosticContextManager[trace.Span]:
+    parent_context = (
+        TraceContextTextMapPropagator().extract(ti.context_carrier) if ti and ti.context_carrier else None
+    )
+    attributes: dict[str, str | int] = {
+        "airflow.trigger.name": name,
+    }
+    if isinstance(ti, TaskInstanceDTO):
+        span_name = f"trigger.{ti.task_id}"
+        if ti.map_index >= 0:
+            span_name += f"_{ti.map_index}"
+        attributes = {
+            **attributes,
+            "airflow.dag_id": ti.dag_id,
+            "airflow.task_id": ti.task_id,
+            "airflow.dag_run.run_id": ti.run_id,
+            "airflow.task_instance.try_number": ti.try_number,
+            "airflow.task_instance.map_index": ti.map_index,
+        }
+    else:
+        span_name = f"trigger.{name}"
+    return tracer.start_as_current_span(span_name, attributes=attributes, context=parent_context)
+
 
 __all__ = [
     "TriggerRunner",
@@ -1179,30 +1211,38 @@ class TriggerRunner:
         name = self.triggers[trigger_id]["name"]
         self.log.info("trigger %s starting", name)
 
-        try:
-            async for event in trigger.run():
-                await self.log.ainfo(
-                    "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
-                )
-                self.triggers[trigger_id]["events"] += 1
-                self.events.append((trigger_id, event))
-        except asyncio.CancelledError:
-            # We get cancelled by the scheduler changing the task state. But if we do lets give a nice error
-            # message about it
-            if timeout := timeout_after:
-                timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
-                if timeout < timezone.utcnow():
-                    await self.log.aerror("Trigger cancelled due to timeout")
-            raise
-        finally:
-            # CancelledError will get injected when we're stopped - which is
-            # fine, the cleanup process will understand that, but we want to
-            # allow triggers a chance to cleanup, either in that case or if
-            # they exit cleanly. Exception from cleanup methods are ignored.
-            with suppress(Exception):
-                await trigger.cleanup()
-
-            await self.log.ainfo("trigger completed", name=name)
+        with _make_trigger_span(ti=trigger.task_instance, trigger_id=trigger_id, name=name) as span:
+            try:
+                async for event in trigger.run():
+                    await self.log.ainfo(
+                        "Trigger fired event", name=self.triggers[trigger_id]["name"], result=event
+                    )
+                    self.triggers[trigger_id]["events"] += 1
+                    self.events.append((trigger_id, event))
+                span.set_status(Status(StatusCode.OK))
+            except asyncio.CancelledError as e:
+                # We get cancelled by the scheduler changing the task state. But if we do, let's
+                # give a nice error message about it
+                if timeout := timeout_after:
+                    timeout = timeout.replace(tzinfo=timezone.utc) if not timeout.tzinfo else timeout
+                    if timeout < timezone.utcnow():
+                        await self.log.aerror("Trigger cancelled due to timeout")
+                        span.set_status(Status(StatusCode.ERROR, description=str(e)))
+                        raise
+                span.set_status(Status(StatusCode.OK))
+                raise
+            except Exception as e:
+                span.set_status(Status(StatusCode.ERROR, description=str(e)))
+                raise
+            finally:
+                # CancelledError will get injected when we're stopped - which is
+                # fine, the cleanup process will understand that, but we want to
+                # allow triggers a chance to cleanup, either in that case or if
+                # they exit cleanly. Exceptions from cleanup methods are ignored.
+                with suppress(Exception):
+                    await trigger.cleanup()
+
+                await self.log.ainfo("trigger completed", name=name)
 
     def get_trigger_by_classpath(self, classpath: str) -> type[BaseTrigger]:
         """
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 73a5a3a875a..20ec118f33e 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1027,21 +1027,33 @@ class DagRun(Base, LoggingMixin):
         return leaf_tis
 
     def _emit_dagrun_span(self, state: DagRunState):
-        ctx = TraceContextTextMapPropagator().extract(self.context_carrier or {})
+        # just to be safe
+        if not isinstance(self.context_carrier, dict):
+            return
+
+        ctx = TraceContextTextMapPropagator().extract(self.context_carrier)
         span = trace.get_current_span(context=ctx)
         span_context = span.get_span_context()
         with override_ids(span_context.trace_id, span_context.span_id):
-            attributes = {
+            attributes: dict[str, str] = {
                 "airflow.dag_id": str(self.dag_id),
                 "airflow.dag_run.run_id": self.run_id,
             }
+            if self.start_date:
+                attributes["airflow.dag_run.start_date"] = str(self.start_date)
+            if self.end_date:
+                attributes["airflow.dag_run.end_date"] = str(self.end_date)
+            if self.queued_at:
+                attributes["airflow.dag_run.queued_at"] = str(self.queued_at)
+            if self.created_at:
+                attributes["airflow.dag_run.created_at"] = str(self.created_at)
             if self.logical_date:
                 attributes["airflow.dag_run.logical_date"] = str(self.logical_date)
             if self.partition_key:
                 attributes["airflow.dag_run.partition_key"] = str(self.partition_key)
             span = tracer.start_span(
                 name=f"dag_run.{self.dag_id}",
-                start_time=int((self.start_date or timezone.utcnow()).timestamp() * 1e9),
+                start_time=int((self.queued_at or self.start_date or timezone.utcnow()).timestamp() * 1e9),
                 attributes=attributes,
                 context=context.Context(),
             )
@@ -1771,7 +1783,11 @@ class DagRun(Base, LoggingMixin):
                 created_counts[task.task_type] += 1
             for map_index in indexes:
                 yield TI.insert_mapping(
-                    self.run_id, task, map_index=map_index, dag_version_id=dag_version_id
+                    self.run_id,
+                    task,
+                    map_index=map_index,
+                    dag_version_id=dag_version_id,
+                    dag_run=self,
                 )
 
         creator = create_ti_mapping
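[Editor's aside, not part of the commit: `_emit_dagrun_span` above emits a span whose ids must match what was stored in the carrier earlier, via `override_ids` and the `OverrideableRandomIdGenerator` used in the tests below. A crude stand-in sketch showing the underlying SDK mechanism -- a custom `IdGenerator` decides the ids of newly minted spans (the actual Airflow class may differ):

    # Illustrative only. An SDK IdGenerator controls the ids of new spans, so a
    # provider built with one can mint a root span with a pre-recorded span id.
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.id_generator import IdGenerator, RandomIdGenerator

    class PinnedIdGenerator(IdGenerator):
        def __init__(self, trace_id: int, span_id: int):
            self._trace_id, self._span_id = trace_id, span_id
            self._fallback = RandomIdGenerator()

        def generate_trace_id(self) -> int:
            return self._trace_id or self._fallback.generate_trace_id()

        def generate_span_id(self) -> int:
            return self._span_id or self._fallback.generate_span_id()

    provider = TracerProvider(id_generator=PinnedIdGenerator(trace_id=0xABC123, span_id=0xDEF456))
    span = provider.get_tracer("sketch").start_span("pinned")
    assert span.get_span_context().span_id == 0xDEF456

]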
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 443468161f8..e212ca68504 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -32,6 +32,7 @@ from uuid import UUID
 import attrs
 import dill
 import uuid6
+from opentelemetry import trace
 from sqlalchemy import (
     JSON,
     Float,
@@ -67,6 +68,7 @@ from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
 from airflow import settings
 from airflow._shared.observability.metrics.dual_stats_manager import DualStatsManager
 from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier
 from airflow._shared.timezones import timezone
 from airflow.assets.manager import asset_manager
 from airflow.configuration import conf
@@ -102,7 +104,7 @@ from airflow.utils.state import DagRunState, State, TaskInstanceState
 TR = TaskReschedule
 
 log = logging.getLogger(__name__)
-
+tracer = trace.get_tracer(__name__)
 
 if TYPE_CHECKING:
     from datetime import datetime
@@ -382,7 +384,7 @@ def clear_task_instances(
     for instance in tis:
         run_ids_by_dag_id[instance.dag_id].add(instance.run_id)
 
-    drs = session.scalars(
+    drs: Iterable[DagRun] = session.scalars(
        select(DagRun).where(
            or_(
                *(
@@ -397,6 +399,7 @@ def clear_task_instances(
             # Always update clear_number and queued_at when clearing tasks, regardless of state
             dr.clear_number += 1
             dr.queued_at = timezone.utcnow()
+            dr.context_carrier = new_dagrun_trace_carrier()
 
             _recalculate_dagrun_queued_at_deadlines(dr, dr.queued_at, session)
 
@@ -425,6 +428,8 @@ def clear_task_instances(
             if dag_run_state == DagRunState.QUEUED:
                 dr.last_scheduling_decision = None
                 dr.start_date = None
+    for ti in tis:
+        ti.context_carrier = new_task_run_carrier(ti.dag_run.context_carrier)
 
     session.flush()
 
@@ -679,7 +684,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
 
     @staticmethod
     def insert_mapping(
-        run_id: str, task: Operator, map_index: int, dag_version_id: UUID | None
+        run_id: str, task: Operator, map_index: int, *, dag_version_id: UUID | None, dag_run: DagRun
     ) -> dict[str, Any]:
         """
         Insert mapping.
@@ -689,6 +694,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         priority_weight = task.weight_rule.get_weight(
             TaskInstance(task=task, run_id=run_id, map_index=map_index, dag_version_id=dag_version_id)
         )
+        context_carrier = new_task_run_carrier(dag_run.context_carrier)
 
         return {
             "dag_id": task.dag_id,
@@ -710,6 +716,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
             "map_index": map_index,
             "_task_display_property_value": task.task_display_name,
             "dag_version_id": dag_version_id,
+            "context_carrier": context_carrier,
         }
 
     @reconstructor
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 18d09d6aa56..60486b8ce86 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -24,9 +24,11 @@ import enum
 from collections.abc import Collection, Iterable, Sequence
 from typing import TYPE_CHECKING, Any
 
+from opentelemetry import trace
 from sqlalchemy import CheckConstraint, ForeignKeyConstraint, Integer, String, func, or_, select
 from sqlalchemy.orm import Mapped, mapped_column
 
+from airflow._shared.observability.traces import new_task_run_carrier
 from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
 from airflow.models.dag_version import DagVersion
 from airflow.utils.db import exists_query
@@ -38,6 +40,7 @@ if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
     from airflow.serialization.definitions.mappedoperator import Operator
 
+tracer = trace.get_tracer(__name__)
 
 
 class TaskMapVariant(enum.Enum):
@@ -242,6 +245,18 @@ class TaskMap(TaskInstanceDependencies):
         else:
             dag_version_id = None
 
+        if unmapped_ti:
+            dr = unmapped_ti.dag_run
+        else:
+            from airflow.models import DagRun
+
+            dr = session.scalar(
+                select(DagRun).where(
+                    DagRun.dag_id == task.dag_id,
+                    DagRun.run_id == run_id,
+                )
+            )
+
         for index in indexes_to_map:
             # TODO: Make more efficient with bulk_insert_mappings/bulk_save_mappings.
             ti = TaskInstance(
@@ -254,6 +269,7 @@ class TaskMap(TaskInstanceDependencies):
             task.log.debug("Expanding TIs upserted %s", ti)
             task_instance_mutation_hook(ti)
             ti = session.merge(ti)
+            ti.context_carrier = new_task_run_carrier(dr.context_carrier)
             ti.refresh_from_task(task)  # session.merge() loses task information.
             all_expanded_tis.append(ti)
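[Editor's aside, not part of the commit: both hunks above pre-allocate a per-task carrier derived from the dag run's carrier. Conceptually, this is what `new_task_run_carrier` (defined in the shared observability package later in this diff) amounts to -- a hedged sketch, using only stock OTel APIs:

    # Illustrative only -- roughly what new_task_run_carrier does: extract the
    # parent context from the dag run's carrier, start a span under it purely to
    # allocate ids, and inject that span's context into a fresh carrier.
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    tracer = TracerProvider().get_tracer("sketch")
    propagator = TraceContextTextMapPropagator()

    def child_carrier(parent_carrier: dict[str, str] | None) -> dict[str, str]:
        parent_ctx = propagator.extract(parent_carrier) if parent_carrier else None
        span = tracer.start_span("placeholder", context=parent_ctx)  # allocates ids only
        carrier: dict[str, str] = {}
        propagator.inject(carrier, context=trace.set_span_in_context(span))
        return carrier

]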
diff --git a/airflow-core/tests/integration/otel/test_otel.py b/airflow-core/tests/integration/otel/test_otel.py
index 0d40156c45e..6852b47af04 100644
--- a/airflow-core/tests/integration/otel/test_otel.py
+++ b/airflow-core/tests/integration/otel/test_otel.py
@@ -508,9 +508,10 @@ class TestOtelIntegration:
         nested = get_span_hierarchy()
 
         assert nested == {
-            "sub_span1": "task_run.task1",
-            "task_run.task1": "dag_run.otel_test_dag",
             "dag_run.otel_test_dag": None,
+            "sub_span1": "worker.task1",
+            "task_run.task1": "dag_run.otel_test_dag",
+            "worker.task1": "task_run.task1",
         }
 
     def start_scheduler(self, capture_output: bool = False):
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 7cc7aa85594..7f766ede71e 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -24,13 +24,21 @@ from uuid import UUID, uuid4
 
 import pytest
 import uuid6
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from sqlalchemy import select, update
 from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import Session
 
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.auth.tokens import JWTValidator
 from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
 from airflow.exceptions import AirflowSkipException
 from airflow.models import RenderedTaskInstanceFields, TaskReschedule, Trigger
 from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, AssetModel
@@ -3242,3 +3250,111 @@ class TestTokenTypeValidation:
         payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
         resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload)
         assert resp.status_code in [200, 204]
+
+
+class TestEmitTaskSpan:
+    """Tests for the _emit_task_span function in the execution API task-instance route."""
+
+    @pytest.fixture(autouse=True)
+    def sdk_tracer_provider(self):
+        self.exporter = InMemorySpanExporter()
+        provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
+        provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+        test_tracer = provider.get_tracer("test")
+        with mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.tracer", test_tracer):
+            yield
+
+    def _make_carriers(self):
+        """Return a (dr_carrier, ti_carrier) pair built with a real SDK provider."""
+        p = TracerProvider()
+        t = p.get_tracer("setup")
+        dr_span = t.start_span("dr")
+        dr_ctx = otel_trace.set_span_in_context(dr_span)
+        dr_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(dr_carrier, context=dr_ctx)
+        ti_span = t.start_span("ti", context=dr_ctx)
+        ti_ctx = otel_trace.set_span_in_context(ti_span)
+        ti_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(ti_carrier, context=ti_ctx)
+        return dr_carrier, ti_carrier
+    def _make_ti(self, task_id="my_task", map_index=-1, queued_dttm=None, start_date=None):
+        dr_carrier, ti_carrier = self._make_carriers()
+        ti = mock.MagicMock()
+        ti.dag_id = "test_dag"
+        ti.task_id = task_id
+        ti.run_id = "test_run"
+        ti.try_number = 1
+        ti.map_index = map_index
+        ti.queued_dttm = queued_dttm
+        ti.start_date = start_date or DEFAULT_START_DATE
+        ti.dag_run.context_carrier = dr_carrier
+        ti.context_carrier = ti_carrier
+        return ti
+
+    def test_emit_task_span_success_sets_ok_status(self):
+        _emit_task_span(self._make_ti(), TaskInstanceState.SUCCESS)
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].status.status_code == StatusCode.OK
+
+    def test_emit_task_span_failed_sets_error_status(self):
+        _emit_task_span(self._make_ti(), TaskInstanceState.FAILED)
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].status.status_code == StatusCode.ERROR
+
+    def test_emit_task_span_sets_attributes(self):
+        ti = self._make_ti(task_id="my_task", map_index=2)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.dag_id"] == "test_dag"
+        assert attrs["airflow.task_id"] == "my_task"
+        assert attrs["airflow.dag_run.run_id"] == "test_run"
+        assert attrs["airflow.task_instance.try_number"] == 1
+        assert attrs["airflow.task_instance.map_index"] == 2
+        assert attrs["airflow.task_instance.state"] == TaskInstanceState.SUCCESS
+
+    def test_emit_task_span_name_unmapped(self):
+        _emit_task_span(self._make_ti(task_id="my_task", map_index=-1), TaskInstanceState.SUCCESS)
+        assert self.exporter.get_finished_spans()[0].name == "task_run.my_task"
+
+    def test_emit_task_span_name_mapped(self):
+        _emit_task_span(self._make_ti(task_id="my_task", map_index=3), TaskInstanceState.SUCCESS)
+        assert self.exporter.get_finished_spans()[0].name == "task_run.my_task[3]"
+
+    def test_emit_task_span_start_time_uses_queued_dttm(self):
+        queued_dttm = timezone.parse("2024-01-01T10:00:00Z")
+        start_date = timezone.parse("2024-01-01T10:05:00Z")
+        ti = self._make_ti(queued_dttm=queued_dttm, start_date=start_date)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(queued_dttm.timestamp() * 1e9)
+
+    def test_emit_task_span_start_time_falls_back_to_start_date(self):
+        start_date = timezone.parse("2024-01-01T10:05:00Z")
+        ti = self._make_ti(queued_dttm=None, start_date=start_date)
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+
+        assert self.exporter.get_finished_spans()[0].start_time == int(start_date.timestamp() * 1e9)
+
+    def test_emit_task_span_skips_if_no_ti_carrier(self):
+        ti = mock.MagicMock()
+        ti.dag_run.context_carrier = {
+            "traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"
+        }
+        ti.context_carrier = None
+
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+        assert len(self.exporter.get_finished_spans()) == 0
+
+    def test_emit_task_span_skips_if_no_dagrun_carrier(self):
+        ti = mock.MagicMock()
+        ti.dag_run.context_carrier = None
+        ti.context_carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
+
+        _emit_task_span(ti, TaskInstanceState.SUCCESS)
+        assert len(self.exporter.get_finished_spans()) == 0
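[Editor's aside, not part of the commit: the tests above (and the trigger/dagrun tests below) all use the same harness pattern -- a real SDK provider wired to `InMemorySpanExporter` so finished spans can be asserted on synchronously. A minimal standalone sketch:

    # Illustrative only. SimpleSpanProcessor hands each span to the exporter as
    # soon as it ends, so assertions can run immediately after the span closes.
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

    exporter = InMemorySpanExporter()
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(exporter))

    with provider.get_tracer("test").start_as_current_span("unit_of_work"):
        pass

    (span,) = exporter.get_finished_spans()
    assert span.name == "unit_of_work"

]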
diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py
index 802a34192e3..3761189bfeb 100644
--- a/airflow-core/tests/unit/jobs/test_triggerer_job.py
+++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py
@@ -24,18 +24,26 @@ import os
 import selectors
 import time
 import typing
+import uuid
 from collections.abc import AsyncIterator
 from socket import socket
 from typing import TYPE_CHECKING, Any
+from unittest import mock
 from unittest.mock import ANY, AsyncMock, MagicMock, patch
 
 import pendulum
 import pytest
 from asgiref.sync import sync_to_async
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from structlog.typing import FilteringBoundLogger
 
 from airflow._shared.timezones import timezone
 from airflow.executors import workloads
+from airflow.executors.workloads.task import TaskInstanceDTO
 from airflow.jobs.job import Job
 from airflow.jobs.triggerer_job_runner import (
     ToTriggerRunner,
@@ -45,6 +53,7 @@ from airflow.jobs.triggerer_job_runner import (
     TriggerLoggingFactory,
     TriggerRunner,
     TriggerRunnerSupervisor,
+    _make_trigger_span,
     messages,
 )
 from airflow.models import Connection, DagModel, DagRun, Trigger, Variable
@@ -318,7 +327,6 @@ def test_trigger_logger_close():
 
 
 def test_trigger_logger_fd_closed_when_removed(session):
-
     trigger = TimeDeltaTrigger(datetime.timedelta(seconds=0.5))
     create_trigger_in_db(session, trigger)
 
@@ -349,11 +357,12 @@ class TestTriggerRunner:
         mock_trigger = MagicMock(spec=BaseTrigger)
         mock_trigger.timeout_after = None
         mock_trigger.run.side_effect = asyncio.CancelledError()
+        mock_trigger.task_instance = MagicMock()
+        mock_trigger.task_instance.map_index = -1
 
         with pytest.raises(asyncio.CancelledError):
             asyncio.run(trigger_runner.run_trigger(1, mock_trigger))
 
-    # @pytest.mark.asyncio
     def test_run_inline_trigger_timeout(self, session, cap_structlog) -> None:
         trigger_runner = TriggerRunner()
         trigger_runner.triggers = {
@@ -361,6 +370,8 @@ class TestTriggerRunner:
         }
         mock_trigger = MagicMock(spec=BaseTrigger)
         mock_trigger.run.side_effect = asyncio.CancelledError()
+        mock_trigger.task_instance = MagicMock()
+        mock_trigger.task_instance.map_index = -1
 
         with pytest.raises(asyncio.CancelledError):
             asyncio.run(
@@ -1358,3 +1369,89 @@ class TestTriggererMessageTypes:
             + "\n".join(f" - {t}" for t in sorted(task_diff))
             + "\n\nEither handle these types in ToTriggerRunner or update in_task_but_not_in_trigger_runner list."
         )
+
+
+class TestMakeTriggerSpan:
+    """Tests for the _make_trigger_span helper in the triggerer job runner."""
+
+    @pytest.fixture(autouse=True)
+    def sdk_tracer_provider(self):
+        self.exporter = InMemorySpanExporter()
+        provider = TracerProvider()
+        provider.add_span_processor(SimpleSpanProcessor(self.exporter))
+        test_tracer = provider.get_tracer("test")
+        with mock.patch("airflow.jobs.triggerer_job_runner.tracer", test_tracer):
+            yield
+
+    def _make_ti_dto(self, task_id="my_task", map_index=-1, context_carrier=None):
+        return TaskInstanceDTO(
+            id=uuid.uuid4(),
+            dag_version_id=uuid.uuid4(),
+            task_id=task_id,
+            dag_id="test_dag",
+            run_id="test_run",
+            try_number=1,
+            map_index=map_index,
+            pool_slots=1,
+            queue="default",
+            priority_weight=1,
+            context_carrier=context_carrier,
+        )
+
+    def test_make_trigger_span_name_with_task_instance(self):
+        ti = self._make_ti_dto(task_id="sensor_task", map_index=-1)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task"
+
+    def test_make_trigger_span_name_with_mapped_task(self):
+        ti = self._make_ti_dto(task_id="sensor_task", map_index=2)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.sensor_task_2"
+
+    def test_make_trigger_span_name_without_task_instance(self):
+        with _make_trigger_span(ti=None, trigger_id=42, name="Some trigger name"):
+            pass
+        assert self.exporter.get_finished_spans()[0].name == "trigger.Some trigger name"
+
+    def test_make_trigger_span_uses_task_context_carrier(self):
+        # Build a valid ti carrier from a separate provider so we have a known parent span.
+        setup_provider = TracerProvider()
+        setup_tracer = setup_provider.get_tracer("setup")
+        parent_span = setup_tracer.start_span("ti_parent")
+        parent_ctx = otel_trace.set_span_in_context(parent_span)
+        ti_carrier: dict = {}
+        TraceContextTextMapPropagator().inject(ti_carrier, context=parent_ctx)
+        expected_parent_span_id = parent_span.get_span_context().span_id
+
+        ti = self._make_ti_dto(context_carrier=ti_carrier)
+        with _make_trigger_span(ti=ti, trigger_id=1, name="MySensor"):
+            pass
+
+        spans = self.exporter.get_finished_spans()
+        assert len(spans) == 1
+        assert spans[0].parent is not None
+        assert spans[0].parent.span_id == expected_parent_span_id
+
+    def test_make_trigger_span_sets_attributes_with_ti(self):
+        ti = self._make_ti_dto(task_id="my_task", map_index=1)
+        with _make_trigger_span(ti=ti, trigger_id=5, name="MyTrigger"):
+            pass
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.trigger.name"] == "MyTrigger"
+        assert attrs["airflow.dag_id"] == "test_dag"
+        assert attrs["airflow.task_id"] == "my_task"
+        assert attrs["airflow.dag_run.run_id"] == "test_run"
+        assert attrs["airflow.task_instance.try_number"] == 1
+        assert attrs["airflow.task_instance.map_index"] == 1
+
+    def test_make_trigger_span_sets_only_trigger_name_without_ti(self):
+        with _make_trigger_span(ti=None, trigger_id=99, name="OnlyTrigger"):
+            pass
+
+        attrs = self.exporter.get_finished_spans()[0].attributes
+        assert attrs["airflow.trigger.name"] == "OnlyTrigger"
+        assert "airflow.dag_id" not in attrs
+        assert "airflow.task_id" not in attrs
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index 7ad78292a1e..b446db9f2b9 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,12 @@ from unittest.mock import ANY, call
 
 import pendulum
 import pytest
+from opentelemetry import trace as otel_trace
 from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace import StatusCode
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from sqlalchemy import (
     func,
     select,
@@ -37,6 +42,7 @@ from sqlalchemy.orm import joinedload
 
 from airflow import settings
 from airflow._shared.observability.metrics.stats import Stats
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator
 from airflow._shared.timezones import timezone
 from airflow.callbacks.callback_requests import DagCallbackRequest, DagRunContext
 from airflow.models.dag import DagModel, infer_automated_data_interval
@@ -3441,13 +3447,6 @@ class TestDagRunTracing:
 
     def test_emit_dagrun_span_uses_context_carrier_ids(self, dag_maker, session):
         """The emitted span should inherit trace_id/span_id from the context_carrier."""
-        from opentelemetry.sdk.trace import TracerProvider
-        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-        from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-
-        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
-
         in_mem_exporter = InMemorySpanExporter()
         provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
         provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
@@ -3471,8 +3470,6 @@ class TestDagRunTracing:
 
         # Decode the expected trace_id/span_id from the stored context_carrier
         ctx = TraceContextTextMapPropagator().extract(dr.context_carrier)
-        from opentelemetry import trace as otel_trace
-
         stored_span = otel_trace.get_current_span(context=ctx)
         stored_ctx = stored_span.get_span_context()
 
@@ -3482,13 +3479,6 @@ class TestDagRunTracing:
     @pytest.mark.parametrize("final_state", [DagRunState.SUCCESS, DagRunState.FAILED])
     def test_emit_dagrun_span_attributes_and_status(self, dag_maker, session, final_state):
         """The emitted span should have the correct name, attributes, and status code."""
-        from opentelemetry.sdk.trace import TracerProvider
-        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-        from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-        from opentelemetry.trace import StatusCode
-
-        from airflow._shared.observability.traces import OverrideableRandomIdGenerator
-
         in_mem_exporter = InMemorySpanExporter()
         provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
         provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
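[Editor's aside, not part of the commit: the "decode the expected trace_id/span_id" step above works because `extract()` yields a context whose "current span" is a non-recording span carrying the propagated ids. A minimal sketch using only the OTel API package:

    # Illustrative only. Recover the ids stored in a traceparent carrier.
    from opentelemetry import trace
    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

    carrier = {"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}
    ctx = TraceContextTextMapPropagator().extract(carrier)
    span_ctx = trace.get_current_span(context=ctx).get_span_context()
    assert f"{span_ctx.trace_id:032x}" == "4bf92f3577b34da6a3ce929d0e0e4736"
    assert f"{span_ctx.span_id:016x}" == "00f067aa0ba902b7"

]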
""" - from opentelemetry.sdk.trace import TracerProvider - from opentelemetry.sdk.trace.export import SimpleSpanProcessor - from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - - from airflow._shared.observability.traces import OverrideableRandomIdGenerator - in_mem_exporter = InMemorySpanExporter() provider = TracerProvider(id_generator=OverrideableRandomIdGenerator()) provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter)) @@ -3555,5 +3539,8 @@ class TestDagRunTracing: # A root span should still be emitted spans = in_mem_exporter.get_finished_spans() - assert len(spans) == 1 - assert spans[0].name == f"dag_run.{dr.dag_id}" + if isinstance(carrier_value, dict): + assert len(spans) == 1 + assert spans[0].name == f"dag_run.{dr.dag_id}" + else: + assert len(spans) == 0 diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 82b9adc3162..bb058d1a737 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -30,11 +30,15 @@ import pendulum import pytest import time_machine import uuid6 +from opentelemetry import trace as otel_trace +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from sqlalchemy import delete, func, select from sqlalchemy.exc import IntegrityError from airflow import settings from airflow._shared.observability.metrics.stats import Stats +from airflow._shared.observability.traces import new_dagrun_trace_carrier, new_task_run_carrier from airflow._shared.timezones import timezone from airflow.exceptions import ( AirflowException, @@ -3298,3 +3302,103 @@ def test_get_dagrun_loaded_but_none_returns_dagrun(dag_maker, session): assert dr_from_ti is not None assert dr_from_ti == dr + + +class TestMakeTaskCarrier: + """Tests for the _make_task_carrier helper.""" + + @pytest.fixture(autouse=True) + def sdk_tracer_provider(self): + provider = TracerProvider() + real_tracer = provider.get_tracer("airflow.models.taskinstance") + with ( + mock.patch("airflow.models.taskinstance.tracer", real_tracer), + mock.patch("airflow._shared.observability.traces.tracer", real_tracer), + ): + yield + + def test_make_task_carrier_returns_traceparent(self): + carrier = new_task_run_carrier(new_dagrun_trace_carrier()) + assert isinstance(carrier, dict) + assert "traceparent" in carrier + + def test_make_task_carrier_child_of_parent(self): + parent_carrier = new_dagrun_trace_carrier() + child_carrier = new_task_run_carrier(parent_carrier) + + propagator = TraceContextTextMapPropagator() + parent_trace_id = ( + otel_trace.get_current_span(context=propagator.extract(parent_carrier)) + .get_span_context() + .trace_id + ) + child_trace_id = ( + otel_trace.get_current_span(context=propagator.extract(child_carrier)).get_span_context().trace_id + ) + assert child_trace_id == parent_trace_id + assert child_trace_id != 0 + + def test_make_task_carrier_with_none_carrier(self): + carrier = new_task_run_carrier(None) + assert isinstance(carrier, dict) + assert "traceparent" in carrier + + [email protected]_test +def test_insert_mapping_includes_context_carrier(dag_maker, session): + """insert_mapping should include a context_carrier with a traceparent derived from the dag run.""" + provider = TracerProvider() + real_tracer = provider.get_tracer("airflow.models.taskinstance") + with ( + mock.patch("airflow.models.taskinstance.tracer", real_tracer), 
+ mock.patch("airflow._shared.observability.traces.tracer", real_tracer), + ): + with dag_maker("test_insert_mapping_carrier"): + EmptyOperator(task_id="t1") + session.flush() + + # Get the scheduler-side operator (has a proper PriorityWeightStrategy, not the enum weight_rule). + op = create_scheduler_operator(dag_maker.dag.get_task("t1")) + + # Mock the DagRun to avoid inserting into the dag_run table (schema migrations may be pending). + dag_run = mock.MagicMock() + dag_run.context_carrier = new_dagrun_trace_carrier() + + mapping = TaskInstance.insert_mapping( + run_id="test_run", + task=op, + map_index=0, + dag_version_id=None, + dag_run=dag_run, + ) + + assert "context_carrier" in mapping + assert mapping["context_carrier"] is not None + assert "traceparent" in mapping["context_carrier"] + + [email protected]_test +def test_clear_task_instances_resets_context_carrier(dag_maker, session): + """clear_task_instances should assign fresh context carriers to both the TI and its dag run.""" + provider = TracerProvider() + real_tracer = provider.get_tracer("airflow.models.taskinstance") + with ( + mock.patch("airflow.models.taskinstance.tracer", real_tracer), + mock.patch("airflow._shared.observability.traces.tracer", real_tracer), + ): + with dag_maker("test_clear_carrier"): + EmptyOperator(task_id="t1") + dag_run = dag_maker.create_dagrun() + ti = dag_run.get_task_instance("t1", session=session) + ti.state = TaskInstanceState.SUCCESS + # Set an explicit carrier so we can verify it changes. + ti.context_carrier = {"traceparent": "00-aaaaaaaaaaaaaaaaaaaaaaaaaaaa0001-bbbbbbbbbbbbbbbb-01"} + session.flush() + + original_ti_traceparent = ti.context_carrier["traceparent"] + original_dr_traceparent = dag_run.context_carrier["traceparent"] + + clear_task_instances([ti], session) + + assert ti.context_carrier["traceparent"] != original_ti_traceparent + assert dag_run.context_carrier["traceparent"] != original_dr_traceparent diff --git a/shared/observability/src/airflow_shared/observability/traces/__init__.py b/shared/observability/src/airflow_shared/observability/traces/__init__.py index 53163e45b97..dc3532262d1 100644 --- a/shared/observability/src/airflow_shared/observability/traces/__init__.py +++ b/shared/observability/src/airflow_shared/observability/traces/__init__.py @@ -34,6 +34,7 @@ from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapProp if TYPE_CHECKING: from configparser import ConfigParser log = logging.getLogger(__name__) +tracer = trace.get_tracer(__name__) OVERRIDE_SPAN_ID_KEY = context.create_key("override_span_id") OVERRIDE_TRACE_ID_KEY = context.create_key("override_trace_id") @@ -70,6 +71,17 @@ def new_dagrun_trace_carrier() -> dict[str, str]: return carrier +def new_task_run_carrier(dag_run_context_carrier): + parent_context = ( + TraceContextTextMapPropagator().extract(dag_run_context_carrier) if dag_run_context_carrier else None + ) + span = tracer.start_span("notused", context=parent_context) # intentionally never closed + new_ctx = trace.set_span_in_context(span) + carrier: dict[str, str] = {} + TraceContextTextMapPropagator().inject(carrier, context=new_ctx) + return carrier + + @contextmanager def override_ids(trace_id, span_id, ctx=None): ctx = context.set_value(OVERRIDE_TRACE_ID_KEY, trace_id, context=ctx) diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index aa5a5a08ad4..6e0b0766be3 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ 
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index aa5a5a08ad4..6e0b0766be3 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -148,9 +148,9 @@ def _make_task_span(msg: StartupDetails):
         TraceContextTextMapPropagator().extract(msg.ti.context_carrier) if msg.ti.context_carrier else None
     )
     ti = msg.ti
-    span_name = f"task_run.{ti.task_id}"
+    span_name = f"worker.{ti.task_id}"
     if ti.map_index is not None and ti.map_index >= 0:
-        span_name += f"_{ti.map_index}"
+        span_name += f"[{ti.map_index}]"
     with tracer.start_as_current_span(span_name, context=parent_context) as span:
         span.set_attributes(
             {
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 0eab6a50afc..88424bfe53e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -32,9 +32,16 @@ from unittest.mock import call, patch
 
 import pandas as pd
 import pytest
+from opentelemetry import trace as otel_trace
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import SimpleSpanProcessor
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
+from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
 from task_sdk import FAKE_BUNDLE
 from uuid6 import uuid7
 
+from airflow._shared.observability.traces import OverrideableRandomIdGenerator, new_task_run_carrier
+from airflow.api_fastapi.execution_api.routes.task_instances import _emit_task_span
 from airflow.listeners import hookimpl
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import (
@@ -415,24 +422,30 @@ def test_main_sends_reschedule_task_when_startup_reschedules(
 
 
 def test_task_span_is_child_of_dag_run_span(make_ti_context):
-    """Task span must be a child of the dag run span propagated via context_carrier."""
-    from opentelemetry.sdk.trace import TracerProvider
-    from opentelemetry.sdk.trace.export import SimpleSpanProcessor
-    from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
-    from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
-
-    # Build a real SDK provider and exporter so we can inspect finished spans.
+    """Full trace hierarchy: dag_run → task_run.my_task (API server) → worker.my_task (task runner)."""
+    # Single provider shared by all spans so contexts are compatible.
     in_mem_exporter = InMemorySpanExporter()
-    provider = TracerProvider()
+    provider = TracerProvider(id_generator=OverrideableRandomIdGenerator())
     provider.add_span_processor(SimpleSpanProcessor(in_mem_exporter))
 
-    # Create a "dag run" span whose context we will propagate into the task.
+    # Step 1: create the dag run span and capture its carrier.
     dag_run_tracer = provider.get_tracer("dag_run")
     with dag_run_tracer.start_as_current_span("dag_run.test_dag") as dag_run_span:
-        carrier: dict[str, str] = {}
-        TraceContextTextMapPropagator().inject(carrier)
+        dag_run_carrier: dict[str, str] = {}
+        TraceContextTextMapPropagator().inject(dag_run_carrier)
         dag_run_span_ctx = dag_run_span.get_span_context()
 
+    # Step 2: derive the parent task span carrier (child of dag run), as the scheduler does.
+    ti_model_tracer = provider.get_tracer("airflow.models.taskinstance")
+    with mock.patch("airflow.models.taskinstance.tracer", ti_model_tracer):
+        ti_carrier = new_task_run_carrier(dag_run_carrier)
+
+    # Extract the parent task span context (the stable span ID stored in ti_carrier).
+    parent_task_span_ctx = otel_trace.get_current_span(
+        context=TraceContextTextMapPropagator().extract(ti_carrier)
+    ).get_span_context()
+
+    # Step 3: build StartupDetails with ti.context_carrier = ti_carrier.
     what = StartupDetails(
         ti=TaskInstance(
             id=uuid7(),
@@ -441,7 +454,7 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context):
             run_id="test_run",
             try_number=1,
             dag_version_id=uuid7(),
-            context_carrier=carrier,
+            context_carrier=ti_carrier,
         ),
         dag_rel_path="",
         bundle_info=BundleInfo(name="my-bundle", version=None),
@@ -450,27 +463,45 @@ def test_task_span_is_child_of_dag_run_span(make_ti_context):
         sentry_integration="",
     )
 
-    task_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner")
-    with mock.patch("airflow.sdk.execution_time.task_runner.tracer", task_tracer):
-        with _make_task_span(what) as span:
-            task_span_ctx = span.get_span_context()
+    # Step 4: emit the worker span (task runner side).
+    task_runner_tracer = provider.get_tracer("airflow.sdk.execution_time.task_runner")
+    with mock.patch("airflow.sdk.execution_time.task_runner.tracer", task_runner_tracer):
+        with _make_task_span(what):
+            pass
 
-    # The task span must share the dag run's trace ID.
-    assert task_span_ctx.trace_id == dag_run_span_ctx.trace_id
+    # Step 5: emit the parent task span (API server side, as happens on task completion).
+    mock_ti = mock.MagicMock()
+    mock_ti.dag_id = "test_dag"
+    mock_ti.task_id = "my_task"
+    mock_ti.run_id = "test_run"
+    mock_ti.try_number = 1
+    mock_ti.map_index = -1
+    mock_ti.queued_dttm = None
+    mock_ti.start_date = timezone.utcnow()
+    mock_ti.dag_run.context_carrier = dag_run_carrier
+    mock_ti.context_carrier = ti_carrier
+    api_tracer = provider.get_tracer("airflow.api_fastapi.execution_api.routes.task_instances")
+    with mock.patch("airflow.api_fastapi.execution_api.routes.task_instances.tracer", api_tracer):
+        _emit_task_span(mock_ti, TaskInstanceState.SUCCESS)
 
-    # The task span's parent must be the dag run span.
     finished = in_mem_exporter.get_finished_spans()
+
+    # task_run.my_task: emitted by API server, child of dag run, span_id == parent_task_span_ctx.span_id.
     task_spans = [s for s in finished if s.name == "task_run.my_task"]
     assert len(task_spans) == 1
-    assert task_spans[0].parent is not None
-    assert task_spans[0].parent.span_id == dag_run_span_ctx.span_id
-
-    # Span attributes are set correctly.
-    attrs = task_spans[0].attributes
-    assert attrs["airflow.dag_id"] == "test_dag"
-    assert attrs["airflow.task_id"] == "my_task"
-    assert attrs["airflow.dag_run.run_id"] == "test_run"
-    assert attrs["airflow.task_instance.try_number"] == 1
+    task_span = task_spans[0]
+    assert task_span.parent is not None
+    assert task_span.parent.span_id == dag_run_span_ctx.span_id
+    assert task_span.context.span_id == parent_task_span_ctx.span_id
+
+    # worker.my_task: created by task runner, child of the parent task span.
+    worker_spans = [s for s in finished if s.name == "worker.my_task"]
+    assert len(worker_spans) == 1
+    assert worker_spans[0].parent is not None
+    assert worker_spans[0].parent.span_id == parent_task_span_ctx.span_id
+
+    # All spans share the same trace ID.
+    assert {s.context.trace_id for s in finished} == {dag_run_span_ctx.trace_id}
 
 
 def test_task_span_no_parent_when_no_context_carrier(make_ti_context):
