This is an automated email from the ASF dual-hosted git repository.
ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27754b93fea Add team_name to Multi-Team metrics (#68108)
27754b93fea is described below
commit 27754b93fea5995300b9df9bff45d4739a416a2c
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed Jun 10 14:54:15 2026 -0700
Add team_name to Multi-Team metrics (#68108)
* Add team_name to Multi-Team metrics
Add `team_name` to ~35 metrics (DagRun, TaskInstance, pool slots) for
multi-team Airflow deployments. When multi_team is disabled or a DAG has no
team, the tag is omitted entirely (no change to existing systems).
Changes:
- Add `team_name` to `DagRun.stats_tags` and `TaskInstance.stats_tags`
- Set `_team_name` on `DagRun`/`TI` objects in 4 scheduler injection points
- Add a per-loop `_dag_id_to_team_name` cache on `SchedulerJobRunner` (reset at the start of each scheduler loop iteration)
- Add `stats_tags` property to `RuntimeTaskInstance` (Task SDK)
- Add `team_name` to pool slot and starving-task gauge metrics via
`Pool.get_name_to_team_name_mapping`
Co-authored-by: Niko Oliveira <[email protected]>
---
.../src/airflow/jobs/scheduler_job_runner.py | 139 ++++++++++++---------
airflow-core/src/airflow/models/dagrun.py | 4 +-
airflow-core/src/airflow/models/taskinstance.py | 4 +-
airflow-core/tests/unit/jobs/test_scheduler_job.py | 67 ++++++++++
airflow-core/tests/unit/models/test_dagrun.py | 30 +++++
.../tests/unit/models/test_taskinstance.py | 33 +++++
.../src/airflow/sdk/execution_time/task_runner.py | 18 ++-
.../task_sdk/execution_time/test_task_runner.py | 49 ++++++++
8 files changed, 275 insertions(+), 69 deletions(-)
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index a860810b746..21ce3f3c582 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -335,6 +335,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self._parallelism = conf.getint("core", "parallelism")
self._multi_team = conf.getboolean("core", "multi_team")
self._max_partition_dag_runs_per_loop = MAX_PARTITION_DAG_RUNS_PER_LOOP
+ self._dag_id_to_team_name: dict[str, str | None] = {}
self.executors: list[BaseExecutor] = executors if executors else
ExecutorLoader.init_executors()
self.executor: BaseExecutor = self.executors[0]
@@ -385,9 +386,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self, dag_ids: Collection[str], session: Session
) -> dict[str, str | None]:
"""
- Batch query to resolve team names for multiple DAG IDs using the DAG >
Bundle > Team relationship chain.
+ Resolve team names for DAG IDs via the DAG > Bundle > Team
relationship.
- DAG IDs > DagModel (via dag_id) > DagBundleModel (via bundle_name) >
Team
+ Results are cached for the current scheduler loop iteration. The cache
is cleared
+ at the start of each loop so all injection points within one heartbeat
share
+ a single query, but changes are picked up on the next iteration.
:param dag_ids: Collection of DAG IDs to resolve team names for
:param session: Database session for queries
@@ -396,36 +399,35 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
if not dag_ids:
return {}
- try:
- # Query all team names for the given DAG IDs in a single query
- query_results = session.execute(
- select(DagModel.dag_id, Team.name)
- .join(DagBundleModel.teams) # Join Team to DagBundleModel via
association table
- .join(
- DagModel, DagModel.bundle_name == DagBundleModel.name
- ) # Join DagBundleModel to DagModel
- .where(DagModel.dag_id.in_(dag_ids))
- ).all()
-
- # Create mapping from results
- dag_id_to_team_name = {dag_id: team_name for dag_id, team_name in
query_results}
-
- # Ensure all requested dag_ids are in the result (with None for
those not found)
- result = {dag_id: dag_id_to_team_name.get(dag_id) for dag_id in
dag_ids}
+ missing = [dag_id for dag_id in dag_ids if dag_id not in
self._dag_id_to_team_name]
+ if missing:
+ try:
+ # Query all team names for the given DAG IDs in a single query
+ query_results = session.execute(
+ select(DagModel.dag_id, Team.name)
+ .join(DagBundleModel.teams) # Join Team to DagBundleModel
via association table
+ .join(
+ DagModel, DagModel.bundle_name == DagBundleModel.name
+ ) # Join DagBundleModel to DagModel
+ .where(DagModel.dag_id.in_(missing))
+ ).all()
+
+ # Create mapping from results
+ queried = {dag_id: team_name for dag_id, team_name in
query_results}
+
+ # Cache all results, including None for dag_ids with no team
+ for dag_id in missing:
+ self._dag_id_to_team_name[dag_id] = queried.get(dag_id)
+ self.log.debug("Cached team names for %d new dag_ids",
len(missing))
- self.log.debug(
- "Resolved team names for %d DAGs: %s",
- len([team for team in result.values() if team is not None]),
- {dag_id: team for dag_id, team in result.items()},
- )
+ except Exception:
+ # Log the error, explicitly don't fail the scheduling loop
+ self.log.exception("Failed to resolve team names for DAG IDs:
%s", missing)
+ # Return dict with all None values to ensure graceful
degradation
+ return {}
- return result
-
- except Exception:
- # Log the error, explicitly don't fail the scheduling loop
- self.log.exception("Failed to resolve team names for DAG IDs: %s",
list(dag_ids))
- # Return dict with all None values to ensure graceful degradation
- return {}
+ # Ensure all requested dag_ids are in the result (with None for those
not found)
+ return {dag_id: self._dag_id_to_team_name.get(dag_id) for dag_id in
dag_ids}
def _get_workload_team_name(self, workload: SchedulerWorkload, session:
Session) -> str | None:
"""
@@ -716,6 +718,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
len(unique_dag_ids),
list(unique_dag_ids),
)
+ for ti in task_instances_to_examine:
+ # Set team as a transient attribute; team lives on the
Bundle, not
+ # on the TI/DagRun schema, so we resolve it at scheduling
time.
+ if team := dag_id_to_team_name.get(ti.dag_id):
+ ti._team_name = team
executor_slots_available: dict[ExecutorName, int] = {}
# First get a mapping of executor names to slots they have
available
@@ -925,12 +932,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
loop_count,
)
+ starving_pool_team_mapping = (
+
Pool.get_name_to_team_name_mapping(list(pool_num_starving_tasks.keys()),
session=session)
+ if self._multi_team and pool_num_starving_tasks
+ else {}
+ )
for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
- stats.gauge(
- "pool.starving_tasks",
- num_starving_tasks,
- tags={"pool_name": pool_name},
- )
+ starving_tags: dict[str, str] = {"pool_name":
normalize_pool_name_for_stats(pool_name)}
+ if team := starving_pool_team_mapping.get(pool_name):
+ starving_tags["team_name"] = team
+ stats.gauge("pool.starving_tasks", num_starving_tasks,
tags=starving_tags)
stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
stats.gauge("scheduler.tasks.executable", len(executable_tis))
@@ -1604,6 +1615,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
.group_by(DagRun)
)
)
+ if self._multi_team and paused_runs:
+ paused_dag_ids = {dr.dag_id for dr in paused_runs}
+ paused_team_mapping =
self._get_team_names_for_dag_ids(paused_dag_ids, session)
+ for dr in paused_runs:
+ if team := paused_team_mapping.get(dr.dag_id):
+ dr._team_name = team
for dag_run in paused_runs:
dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run,
session=session)
if dag is not None:
@@ -1714,6 +1731,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
idle_count = 0
for loop_count in itertools.count(start=1):
+ # Reset per-loop team name cache so changes to bundle-team
assignments
+ # are picked up each iteration without requiring a scheduler
restart.
+ self._dag_id_to_team_name = {}
with stats.timer("scheduler.scheduler_loop_duration") as timer:
with create_session() as session:
# This will schedule for as many executors as possible.
@@ -1846,6 +1866,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
# examining, rather than making one query per DagRun
dag_runs = DagRun.get_running_dag_runs_to_examine(session=session)
+ if self._multi_team and dag_runs:
+ unique_dag_ids = {dr.dag_id for dr in dag_runs}
+ dr_team_mapping =
self._get_team_names_for_dag_ids(unique_dag_ids, session)
+ for dr in dag_runs:
+ if team := dr_team_mapping.get(dr.dag_id):
+ dr._team_name = team
+
callback_tuples = self._schedule_all_dag_runs(guard, dag_runs,
session)
# Send the callbacks after we commit to ensure the context is up to
date when it gets run
@@ -3086,33 +3113,20 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
from airflow.models.pool import Pool
pools = Pool.slots_stats(session=session)
+ pool_team_mapping = (
+ Pool.get_name_to_team_name_mapping(list(pools.keys()),
session=session)
+ if self._multi_team
+ else {}
+ )
for pool_name, slot_stats in pools.items():
- normalized_pool_name = normalize_pool_name_for_stats(pool_name)
- stats.gauge(
- "pool.open_slots",
- slot_stats["open"],
- tags={"pool_name": normalized_pool_name},
- )
- stats.gauge(
- "pool.queued_slots",
- slot_stats["queued"],
- tags={"pool_name": normalized_pool_name},
- )
- stats.gauge(
- "pool.running_slots",
- slot_stats["running"],
- tags={"pool_name": normalized_pool_name},
- )
- stats.gauge(
- "pool.deferred_slots",
- slot_stats["deferred"],
- tags={"pool_name": normalized_pool_name},
- )
- stats.gauge(
- "pool.scheduled_slots",
- slot_stats["scheduled"],
- tags={"pool_name": normalized_pool_name},
- )
+ metric_tags: dict[str, str] = {"pool_name":
normalize_pool_name_for_stats(pool_name)}
+ if team := pool_team_mapping.get(pool_name):
+ metric_tags["team_name"] = team
+ stats.gauge("pool.open_slots", slot_stats["open"],
tags=metric_tags)
+ stats.gauge("pool.queued_slots", slot_stats["queued"],
tags=metric_tags)
+ stats.gauge("pool.running_slots", slot_stats["running"],
tags=metric_tags)
+ stats.gauge("pool.deferred_slots", slot_stats["deferred"],
tags=metric_tags)
+ stats.gauge("pool.scheduled_slots", slot_stats["scheduled"],
tags=metric_tags)
@provide_session
def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION)
-> int:
@@ -3393,6 +3407,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
if self._multi_team:
unique_dag_ids = {ti.dag_id for ti in
task_instances_without_heartbeats}
dag_id_to_team_name =
self._get_team_names_for_dag_ids(unique_dag_ids, session)
+ for ti in task_instances_without_heartbeats:
+ if team := dag_id_to_team_name.get(ti.dag_id):
+ ti._team_name = team
else:
dag_id_to_team_name = {}
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index 09cb8b1206f..80e279f1499 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -501,7 +501,9 @@ class DagRun(Base, LoggingMixin):
@property
def stats_tags(self) -> dict[str, str]:
- return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
+ return prune_dict(
+ {"dag_id": self.dag_id, "run_type": self.run_type, "team_name":
getattr(self, "_team_name", None)}
+ )
def get_state(self):
return self._state
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index a31346097e8..740596f9d69 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -732,7 +732,9 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
@property
def stats_tags(self) -> dict[str, str]:
"""Returns task instance tags."""
- return prune_dict({"dag_id": self.dag_id, "task_id": self.task_id})
+ return prune_dict(
+ {"dag_id": self.dag_id, "task_id": self.task_id, "team_name":
getattr(self, "_team_name", None)}
+ )
@staticmethod
def insert_mapping(
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 6b9c6e34309..4ab1bc6415a 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -2839,6 +2839,38 @@ class TestSchedulerJob:
session.rollback()
session.close()
+ @pytest.mark.parametrize(
+ ("multi_team", "expected_tags"),
+ [
+ pytest.param("true", {"pool_name": "team_pool", "team_name":
"team_a"}, id="with_team"),
+ pytest.param("false", {"pool_name": "team_pool"},
id="without_team"),
+ ],
+ )
+ @mock.patch("airflow._shared.observability.metrics.stats._get_backend")
+ def test_emit_pool_metrics_team_name(self, mock_get_backend, multi_team,
expected_tags, session):
+ """Pool metrics include team_name only when multi_team is enabled."""
+ mock_stats = mock.MagicMock(spec=StatsLogger)
+ mock_get_backend.return_value = mock_stats
+
+ clear_db_teams()
+
+ team = Team(name="team_a")
+ session.add(team)
+ session.flush()
+
+ pool = Pool(pool="team_pool", slots=5, include_deferred=False,
team_name="team_a")
+ session.add(pool)
+ session.flush()
+
+ with conf_vars({("core", "multi_team"): multi_team}):
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job)
+ self.job_runner._emit_pool_metrics(session=session)
+
+ mock_stats.gauge.assert_any_call("pool.open_slots", mock.ANY,
tags=expected_tags)
+ mock_stats.gauge.assert_any_call("pool.queued_slots", mock.ANY,
tags=expected_tags)
+ mock_stats.gauge.assert_any_call("pool.running_slots", mock.ANY,
tags=expected_tags)
+
def test_enqueue_task_instances_with_queued_state(self, dag_maker,
session):
dag_id =
"SchedulerJobTest.test_enqueue_task_instances_with_queued_state"
task_id_1 = "dummy"
@@ -9274,6 +9306,41 @@ class TestSchedulerJob:
assert result1 == self.job_runner.executor # Default for no
explicit executor
assert result2 == mock_executors[1] # Matched by executor name
+ @conf_vars({("core", "multi_team"): "true"})
+ def test_multi_team_sets_team_name_on_task_instances(self, dag_maker,
mock_executors, session):
+ """Test that _team_name is set on TaskInstance objects during the
scheduling loop."""
+ clear_db_teams()
+ clear_db_dag_bundles()
+
+ team = Team(name="team_a")
+ session.add(team)
+ session.flush()
+
+ bundle = DagBundleModel(name="bundle_a")
+ bundle.teams.append(team)
+ session.add(bundle)
+ session.flush()
+
+ with dag_maker(dag_id="dag_a", bundle_name="bundle_a",
session=session):
+ EmptyOperator(task_id="task_a")
+
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("task_a", session=session)
+ ti.state = State.SCHEDULED
+ session.flush()
+
+ scheduler_job = Job()
+ self.job_runner = SchedulerJobRunner(job=scheduler_job)
+ self.job_runner._multi_team = True
+
+ # Simulate what _executable_task_instances_to_queued does
+ dag_id_to_team_name =
self.job_runner._get_team_names_for_dag_ids(["dag_a"], session)
+ if team_name := dag_id_to_team_name.get(ti.dag_id):
+ ti._team_name = team_name
+
+ assert ti._team_name == "team_a"
+ assert ti.stats_tags == {"dag_id": "dag_a", "task_id": "task_a",
"team_name": "team_a"}
+
@pytest.mark.need_serialized_dag
def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 68442bbfbc0..a8a831fb6fa 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -4192,3 +4192,33 @@ class TestDagRunTracing:
span = trace.get_current_span(ctx)
assert get_task_span_detail_level(span) == 2
+
+
+class TestDagRunStatsTagsTeamName:
+ def test_stats_tags_without_team_name(self, dag_maker):
+ """stats_tags should not include team_name when _team_name is not
set."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="t1")
+ dr = dag_maker.create_dagrun()
+ tags = dr.stats_tags
+ assert "team_name" not in tags
+ assert tags == {"dag_id": "test_dag", "run_type": "manual"}
+
+ def test_stats_tags_with_team_name(self, dag_maker):
+ """stats_tags should include team_name when _team_name is set."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="t1")
+ dr = dag_maker.create_dagrun()
+ dr._team_name = "my_team"
+ tags = dr.stats_tags
+ assert tags["team_name"] == "my_team"
+ assert tags == {"dag_id": "test_dag", "run_type": "manual",
"team_name": "my_team"}
+
+ def test_stats_tags_with_none_team_name(self, dag_maker):
+ """stats_tags should not include team_name when _team_name is None."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="t1")
+ dr = dag_maker.create_dagrun()
+ dr._team_name = None
+ tags = dr.stats_tags
+ assert "team_name" not in tags
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index f1a46cd52d9..b9d8c85f613 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4093,3 +4093,36 @@ def
test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker, sessi
assert "<deferred>" in result
assert "[queued]" not in result
+
+
+class TestTaskInstanceStatsTagsTeamName:
+ def test_stats_tags_without_team_name(self, dag_maker, session):
+ """stats_tags should not include team_name when _team_name is not
set."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="my_task")
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("my_task", session=session)
+ tags = ti.stats_tags
+ assert "team_name" not in tags
+ assert tags == {"dag_id": "test_dag", "task_id": "my_task"}
+
+ def test_stats_tags_with_team_name(self, dag_maker, session):
+ """stats_tags should include team_name when _team_name is set."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="my_task")
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("my_task", session=session)
+ ti._team_name = "my_team"
+ tags = ti.stats_tags
+ assert tags["team_name"] == "my_team"
+ assert tags == {"dag_id": "test_dag", "task_id": "my_task",
"team_name": "my_team"}
+
+ def test_stats_tags_with_none_team_name(self, dag_maker, session):
+ """stats_tags should not include team_name when _team_name is None."""
+ with dag_maker("test_dag"):
+ EmptyOperator(task_id="my_task")
+ dr = dag_maker.create_dagrun()
+ ti = dr.get_task_instance("my_task", session=session)
+ ti._team_name = None
+ tags = ti.stats_tags
+ assert "team_name" not in tags
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 48822c52f5d..26d1268df02 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -251,6 +251,14 @@ class RuntimeTaskInstance(TaskInstance):
sentry_integration: str = ""
+ @property
+ def stats_tags(self) -> dict[str, str]:
+ """Metric tags for this task instance, including team_name when
available."""
+ tags: dict[str, str] = {"dag_id": self.dag_id, "task_id": self.task_id}
+ if self._ti_context_from_server and
self._ti_context_from_server.dag_run.team_name:
+ tags["team_name"] = self._ti_context_from_server.dag_run.team_name
+ return tags
+
def __rich_repr__(self):
yield "id", self.id
yield "task_id", self.task_id
@@ -1434,7 +1442,7 @@ def run(
state: TaskInstanceState | None = None
error: BaseException | None = None
- stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+ stats_tags = ti.stats_tags
stats.incr("ti.start", tags=stats_tags)
try:
@@ -1613,7 +1621,7 @@ def _handle_current_task_success(
# Record operator and task instance success metrics
operator = ti.task.__class__.__name__
- stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+ stats_tags = ti.stats_tags
stats.incr("operator_successes", tags={**stats_tags, "operator_name":
operator})
stats.incr("ti_successes", tags=stats_tags)
@@ -1720,7 +1728,7 @@ def _finalize_task_failure(
# Record operator and task instance failed metrics
operator = ti.task.__class__.__name__
- stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+ stats_tags = ti.stats_tags
stats.incr("operator_failures", tags={**stats_tags, "operator_name":
operator})
stats.incr("ti_failures", tags=stats_tags)
@@ -2085,9 +2093,7 @@ def finalize(
# Record task duration metrics for all terminal states
if ti.start_date and ti.end_date:
duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
- stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
-
- stats.timing("task.duration", duration_ms, tags=stats_tags)
+ stats.timing("task.duration", duration_ms, tags=ti.stats_tags)
task = ti.task
# Pushing xcom for each operator extra links defined on the operator only.
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 5bf2434fba4..c71fbc85ce6 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5358,6 +5358,55 @@ class TestTaskInstanceMetrics:
)
backend.incr.assert_any_call("ti_failures", tags=stats_tags)
+ @pytest.mark.parametrize(
+ ("team_name", "expected_tags_extra"),
+ [
+ pytest.param("my_team", {"team_name": "my_team"}, id="with_team"),
+ pytest.param(None, {}, id="without_team"),
+ ],
+ )
+ def test_ti_start_metric_respects_team_name(
+ self, team_name, expected_tags_extra, create_runtime_ti,
mock_supervisor_comms
+ ):
+ task = PythonOperator(task_id="test", python_callable=lambda:
"success")
+ ti = create_runtime_ti(task=task)
+ if team_name:
+ ti._ti_context_from_server.dag_run.team_name = team_name
+
+ with
mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as
mock_get_backend:
+ backend = mock.MagicMock(spec=StatsLogger)
+ mock_get_backend.return_value = backend
+ run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+ expected = {"dag_id": ti.dag_id, "task_id": ti.task_id,
**expected_tags_extra}
+ backend.incr.assert_any_call("ti.start", tags=expected)
+
+ @pytest.mark.parametrize(
+ ("task_callable", "operator_metric", "ti_metric"),
+ [
+ pytest.param(lambda: "success", "operator_successes",
"ti_successes", id="success"),
+ pytest.param(lambda: 1 / 0, "operator_failures", "ti_failures",
id="failure"),
+ ],
+ )
+ def test_operator_metrics_respect_team_name(
+ self, task_callable, operator_metric, ti_metric, create_runtime_ti,
mock_supervisor_comms
+ ):
+ task = PythonOperator(task_id="test", python_callable=task_callable)
+ ti = create_runtime_ti(task=task)
+ ti._ti_context_from_server.dag_run.team_name = "team_a"
+
+ with
mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as
mock_get_backend:
+ backend = mock.MagicMock(spec=StatsLogger)
+ mock_get_backend.return_value = backend
+ run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+ stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id,
"team_name": "team_a"}
+ backend.incr.assert_any_call(
+ operator_metric,
+ tags={**stats_tags, "operator_name": "PythonOperator"},
+ )
+ backend.incr.assert_any_call(ti_metric, tags=stats_tags)
+
class TestDetailSpan:
"""Tests for the detail_span decorator / context manager."""