This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27754b93fea Add team_name to Multi-Team metrics  (#68108)
27754b93fea is described below

commit 27754b93fea5995300b9df9bff45d4739a416a2c
Author: D. Ferruzzi <[email protected]>
AuthorDate: Wed Jun 10 14:54:15 2026 -0700

    Add team_name to Multi-Team metrics  (#68108)
    
    * Add team_name to Multi-Team metrics
    
    Add `team_name` to ~35 metrics (DagRun, TaskInstance, and pool gauges) for 
multi-team Airflow deployments. When multi_team is disabled or a DAG has no 
team, the tag is omitted entirely, so existing deployments see no change.
    
    Changes:
    - Add `team_name` to `DagRun.stats_tags` and `TaskInstance.stats_tags`
    - Set `_team_name` on `DagRun`/`TI` objects in 4 scheduler injection points
    - Add per-loop `_dag_id_to_team_name` cache on `SchedulerJobRunner` (reset at the start of each scheduler loop iteration)
    - Add `stats_tags` property to `RuntimeTaskInstance` (Task SDK)
    - Add `team_name` to pool slot and starving-task gauge metrics via 
`Pool.get_name_to_team_name_mapping`
    
    Co-authored-by: Niko Oliveira <[email protected]>
---
 .../src/airflow/jobs/scheduler_job_runner.py       | 139 ++++++++++++---------
 airflow-core/src/airflow/models/dagrun.py          |   4 +-
 airflow-core/src/airflow/models/taskinstance.py    |   4 +-
 airflow-core/tests/unit/jobs/test_scheduler_job.py |  67 ++++++++++
 airflow-core/tests/unit/models/test_dagrun.py      |  30 +++++
 .../tests/unit/models/test_taskinstance.py         |  33 +++++
 .../src/airflow/sdk/execution_time/task_runner.py  |  18 ++-
 .../task_sdk/execution_time/test_task_runner.py    |  49 ++++++++
 8 files changed, 275 insertions(+), 69 deletions(-)

diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py 
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index a860810b746..21ce3f3c582 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -335,6 +335,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         self._parallelism = conf.getint("core", "parallelism")
         self._multi_team = conf.getboolean("core", "multi_team")
         self._max_partition_dag_runs_per_loop = MAX_PARTITION_DAG_RUNS_PER_LOOP
+        self._dag_id_to_team_name: dict[str, str | None] = {}
 
         self.executors: list[BaseExecutor] = executors if executors else 
ExecutorLoader.init_executors()
         self.executor: BaseExecutor = self.executors[0]
@@ -385,9 +386,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         self, dag_ids: Collection[str], session: Session
     ) -> dict[str, str | None]:
         """
-        Batch query to resolve team names for multiple DAG IDs using the DAG > 
Bundle > Team relationship chain.
+        Resolve team names for DAG IDs via the DAG > Bundle > Team 
relationship.
 
-        DAG IDs > DagModel (via dag_id) > DagBundleModel (via bundle_name) > 
Team
+        Results are cached for the current scheduler loop iteration. The cache 
is cleared
+        at the start of each loop so all injection points within one heartbeat 
share
+        a single query, but changes are picked up on the next iteration.
 
         :param dag_ids: Collection of DAG IDs to resolve team names for
         :param session: Database session for queries
@@ -396,36 +399,35 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         if not dag_ids:
             return {}
 
-        try:
-            # Query all team names for the given DAG IDs in a single query
-            query_results = session.execute(
-                select(DagModel.dag_id, Team.name)
-                .join(DagBundleModel.teams)  # Join Team to DagBundleModel via 
association table
-                .join(
-                    DagModel, DagModel.bundle_name == DagBundleModel.name
-                )  # Join DagBundleModel to DagModel
-                .where(DagModel.dag_id.in_(dag_ids))
-            ).all()
-
-            # Create mapping from results
-            dag_id_to_team_name = {dag_id: team_name for dag_id, team_name in 
query_results}
-
-            # Ensure all requested dag_ids are in the result (with None for 
those not found)
-            result = {dag_id: dag_id_to_team_name.get(dag_id) for dag_id in 
dag_ids}
+        missing = [dag_id for dag_id in dag_ids if dag_id not in 
self._dag_id_to_team_name]
+        if missing:
+            try:
+                # Query all team names for the given DAG IDs in a single query
+                query_results = session.execute(
+                    select(DagModel.dag_id, Team.name)
+                    .join(DagBundleModel.teams)  # Join Team to DagBundleModel 
via association table
+                    .join(
+                        DagModel, DagModel.bundle_name == DagBundleModel.name
+                    )  # Join DagBundleModel to DagModel
+                    .where(DagModel.dag_id.in_(missing))
+                ).all()
+
+                # Create mapping from results
+                queried = {dag_id: team_name for dag_id, team_name in 
query_results}
+
+                # Cache all results, including None for dag_ids with no team
+                for dag_id in missing:
+                    self._dag_id_to_team_name[dag_id] = queried.get(dag_id)
+                    self.log.debug("Cached team names for %d new dag_ids", 
len(missing))
 
-            self.log.debug(
-                "Resolved team names for %d DAGs: %s",
-                len([team for team in result.values() if team is not None]),
-                {dag_id: team for dag_id, team in result.items()},
-            )
+            except Exception:
+                # Log the error, explicitly don't fail the scheduling loop
+                self.log.exception("Failed to resolve team names for DAG IDs: 
%s", missing)
+                # Return dict with all None values to ensure graceful 
degradation
+                return {}
 
-            return result
-
-        except Exception:
-            # Log the error, explicitly don't fail the scheduling loop
-            self.log.exception("Failed to resolve team names for DAG IDs: %s", 
list(dag_ids))
-            # Return dict with all None values to ensure graceful degradation
-            return {}
+        # Ensure all requested dag_ids are in the result (with None for those 
not found)
+        return {dag_id: self._dag_id_to_team_name.get(dag_id) for dag_id in 
dag_ids}
 
     def _get_workload_team_name(self, workload: SchedulerWorkload, session: 
Session) -> str | None:
         """
@@ -716,6 +718,11 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     len(unique_dag_ids),
                     list(unique_dag_ids),
                 )
+                for ti in task_instances_to_examine:
+                    # Set team as a transient attribute; team lives on the 
Bundle, not
+                    # on the TI/DagRun schema, so we resolve it at scheduling 
time.
+                    if team := dag_id_to_team_name.get(ti.dag_id):
+                        ti._team_name = team
 
             executor_slots_available: dict[ExecutorName, int] = {}
             # First get a mapping of executor names to slots they have 
available
@@ -925,12 +932,16 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 loop_count,
             )
 
+        starving_pool_team_mapping = (
+            
Pool.get_name_to_team_name_mapping(list(pool_num_starving_tasks.keys()), 
session=session)
+            if self._multi_team and pool_num_starving_tasks
+            else {}
+        )
         for pool_name, num_starving_tasks in pool_num_starving_tasks.items():
-            stats.gauge(
-                "pool.starving_tasks",
-                num_starving_tasks,
-                tags={"pool_name": pool_name},
-            )
+            starving_tags: dict[str, str] = {"pool_name": 
normalize_pool_name_for_stats(pool_name)}
+            if team := starving_pool_team_mapping.get(pool_name):
+                starving_tags["team_name"] = team
+            stats.gauge("pool.starving_tasks", num_starving_tasks, 
tags=starving_tags)
 
         stats.gauge("scheduler.tasks.starving", num_starving_tasks_total)
         stats.gauge("scheduler.tasks.executable", len(executable_tis))
@@ -1604,6 +1615,12 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     .group_by(DagRun)
                 )
             )
+            if self._multi_team and paused_runs:
+                paused_dag_ids = {dr.dag_id for dr in paused_runs}
+                paused_team_mapping = 
self._get_team_names_for_dag_ids(paused_dag_ids, session)
+                for dr in paused_runs:
+                    if team := paused_team_mapping.get(dr.dag_id):
+                        dr._team_name = team
             for dag_run in paused_runs:
                 dag = self.scheduler_dag_bag.get_dag_for_run(dag_run=dag_run, 
session=session)
                 if dag is not None:
@@ -1714,6 +1731,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         idle_count = 0
 
         for loop_count in itertools.count(start=1):
+            # Reset per-loop team name cache so changes to bundle-team 
assignments
+            # are picked up each iteration without requiring a scheduler 
restart.
+            self._dag_id_to_team_name = {}
             with stats.timer("scheduler.scheduler_loop_duration") as timer:
                 with create_session() as session:
                     # This will schedule for as many executors as possible.
@@ -1846,6 +1866,13 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             # examining, rather than making one query per DagRun
             dag_runs = DagRun.get_running_dag_runs_to_examine(session=session)
 
+            if self._multi_team and dag_runs:
+                unique_dag_ids = {dr.dag_id for dr in dag_runs}
+                dr_team_mapping = 
self._get_team_names_for_dag_ids(unique_dag_ids, session)
+                for dr in dag_runs:
+                    if team := dr_team_mapping.get(dr.dag_id):
+                        dr._team_name = team
+
             callback_tuples = self._schedule_all_dag_runs(guard, dag_runs, 
session)
 
         # Send the callbacks after we commit to ensure the context is up to 
date when it gets run
@@ -3086,33 +3113,20 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         from airflow.models.pool import Pool
 
         pools = Pool.slots_stats(session=session)
+        pool_team_mapping = (
+            Pool.get_name_to_team_name_mapping(list(pools.keys()), 
session=session)
+            if self._multi_team
+            else {}
+        )
         for pool_name, slot_stats in pools.items():
-            normalized_pool_name = normalize_pool_name_for_stats(pool_name)
-            stats.gauge(
-                "pool.open_slots",
-                slot_stats["open"],
-                tags={"pool_name": normalized_pool_name},
-            )
-            stats.gauge(
-                "pool.queued_slots",
-                slot_stats["queued"],
-                tags={"pool_name": normalized_pool_name},
-            )
-            stats.gauge(
-                "pool.running_slots",
-                slot_stats["running"],
-                tags={"pool_name": normalized_pool_name},
-            )
-            stats.gauge(
-                "pool.deferred_slots",
-                slot_stats["deferred"],
-                tags={"pool_name": normalized_pool_name},
-            )
-            stats.gauge(
-                "pool.scheduled_slots",
-                slot_stats["scheduled"],
-                tags={"pool_name": normalized_pool_name},
-            )
+            metric_tags: dict[str, str] = {"pool_name": 
normalize_pool_name_for_stats(pool_name)}
+            if team := pool_team_mapping.get(pool_name):
+                metric_tags["team_name"] = team
+            stats.gauge("pool.open_slots", slot_stats["open"], 
tags=metric_tags)
+            stats.gauge("pool.queued_slots", slot_stats["queued"], 
tags=metric_tags)
+            stats.gauge("pool.running_slots", slot_stats["running"], 
tags=metric_tags)
+            stats.gauge("pool.deferred_slots", slot_stats["deferred"], 
tags=metric_tags)
+            stats.gauge("pool.scheduled_slots", slot_stats["scheduled"], 
tags=metric_tags)
 
     @provide_session
     def adopt_or_reset_orphaned_tasks(self, *, session: Session = NEW_SESSION) 
-> int:
@@ -3393,6 +3407,9 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         if self._multi_team:
             unique_dag_ids = {ti.dag_id for ti in 
task_instances_without_heartbeats}
             dag_id_to_team_name = 
self._get_team_names_for_dag_ids(unique_dag_ids, session)
+            for ti in task_instances_without_heartbeats:
+                if team := dag_id_to_team_name.get(ti.dag_id):
+                    ti._team_name = team
         else:
             dag_id_to_team_name = {}
 
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 09cb8b1206f..80e279f1499 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -501,7 +501,9 @@ class DagRun(Base, LoggingMixin):
 
     @property
     def stats_tags(self) -> dict[str, str]:
-        return prune_dict({"dag_id": self.dag_id, "run_type": self.run_type})
+        return prune_dict(
+            {"dag_id": self.dag_id, "run_type": self.run_type, "team_name": 
getattr(self, "_team_name", None)}
+        )
 
     def get_state(self):
         return self._state
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index a31346097e8..740596f9d69 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -732,7 +732,9 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
     @property
     def stats_tags(self) -> dict[str, str]:
         """Returns task instance tags."""
-        return prune_dict({"dag_id": self.dag_id, "task_id": self.task_id})
+        return prune_dict(
+            {"dag_id": self.dag_id, "task_id": self.task_id, "team_name": 
getattr(self, "_team_name", None)}
+        )
 
     @staticmethod
     def insert_mapping(
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 6b9c6e34309..4ab1bc6415a 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -2839,6 +2839,38 @@ class TestSchedulerJob:
         session.rollback()
         session.close()
 
+    @pytest.mark.parametrize(
+        ("multi_team", "expected_tags"),
+        [
+            pytest.param("true", {"pool_name": "team_pool", "team_name": 
"team_a"}, id="with_team"),
+            pytest.param("false", {"pool_name": "team_pool"}, 
id="without_team"),
+        ],
+    )
+    @mock.patch("airflow._shared.observability.metrics.stats._get_backend")
+    def test_emit_pool_metrics_team_name(self, mock_get_backend, multi_team, 
expected_tags, session):
+        """Pool metrics include team_name only when multi_team is enabled."""
+        mock_stats = mock.MagicMock(spec=StatsLogger)
+        mock_get_backend.return_value = mock_stats
+
+        clear_db_teams()
+
+        team = Team(name="team_a")
+        session.add(team)
+        session.flush()
+
+        pool = Pool(pool="team_pool", slots=5, include_deferred=False, 
team_name="team_a")
+        session.add(pool)
+        session.flush()
+
+        with conf_vars({("core", "multi_team"): multi_team}):
+            scheduler_job = Job()
+            self.job_runner = SchedulerJobRunner(job=scheduler_job)
+            self.job_runner._emit_pool_metrics(session=session)
+
+        mock_stats.gauge.assert_any_call("pool.open_slots", mock.ANY, 
tags=expected_tags)
+        mock_stats.gauge.assert_any_call("pool.queued_slots", mock.ANY, 
tags=expected_tags)
+        mock_stats.gauge.assert_any_call("pool.running_slots", mock.ANY, 
tags=expected_tags)
+
     def test_enqueue_task_instances_with_queued_state(self, dag_maker, 
session):
         dag_id = 
"SchedulerJobTest.test_enqueue_task_instances_with_queued_state"
         task_id_1 = "dummy"
@@ -9274,6 +9306,41 @@ class TestSchedulerJob:
             assert result1 == self.job_runner.executor  # Default for no 
explicit executor
             assert result2 == mock_executors[1]  # Matched by executor name
 
+    @conf_vars({("core", "multi_team"): "true"})
+    def test_multi_team_sets_team_name_on_task_instances(self, dag_maker, 
mock_executors, session):
+        """Test that _team_name is set on TaskInstance objects during the 
scheduling loop."""
+        clear_db_teams()
+        clear_db_dag_bundles()
+
+        team = Team(name="team_a")
+        session.add(team)
+        session.flush()
+
+        bundle = DagBundleModel(name="bundle_a")
+        bundle.teams.append(team)
+        session.add(bundle)
+        session.flush()
+
+        with dag_maker(dag_id="dag_a", bundle_name="bundle_a", 
session=session):
+            EmptyOperator(task_id="task_a")
+
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("task_a", session=session)
+        ti.state = State.SCHEDULED
+        session.flush()
+
+        scheduler_job = Job()
+        self.job_runner = SchedulerJobRunner(job=scheduler_job)
+        self.job_runner._multi_team = True
+
+        # Simulate what _executable_task_instances_to_queued does
+        dag_id_to_team_name = 
self.job_runner._get_team_names_for_dag_ids(["dag_a"], session)
+        if team_name := dag_id_to_team_name.get(ti.dag_id):
+            ti._team_name = team_name
+
+        assert ti._team_name == "team_a"
+        assert ti.stats_tags == {"dag_id": "dag_a", "task_id": "task_a", 
"team_name": "team_a"}
+
 
 @pytest.mark.need_serialized_dag
 def test_schedule_dag_run_with_upstream_skip(dag_maker, session):
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 68442bbfbc0..a8a831fb6fa 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -4192,3 +4192,33 @@ class TestDagRunTracing:
 
         span = trace.get_current_span(ctx)
         assert get_task_span_detail_level(span) == 2
+
+
+class TestDagRunStatsTagsTeamName:
+    def test_stats_tags_without_team_name(self, dag_maker):
+        """stats_tags should not include team_name when _team_name is not 
set."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="t1")
+        dr = dag_maker.create_dagrun()
+        tags = dr.stats_tags
+        assert "team_name" not in tags
+        assert tags == {"dag_id": "test_dag", "run_type": "manual"}
+
+    def test_stats_tags_with_team_name(self, dag_maker):
+        """stats_tags should include team_name when _team_name is set."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="t1")
+        dr = dag_maker.create_dagrun()
+        dr._team_name = "my_team"
+        tags = dr.stats_tags
+        assert tags["team_name"] == "my_team"
+        assert tags == {"dag_id": "test_dag", "run_type": "manual", 
"team_name": "my_team"}
+
+    def test_stats_tags_with_none_team_name(self, dag_maker):
+        """stats_tags should not include team_name when _team_name is None."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="t1")
+        dr = dag_maker.create_dagrun()
+        dr._team_name = None
+        tags = dr.stats_tags
+        assert "team_name" not in tags
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py 
b/airflow-core/tests/unit/models/test_taskinstance.py
index f1a46cd52d9..b9d8c85f613 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -4093,3 +4093,36 @@ def 
test_task_instance_repr_does_not_raise_for_deferred_columns(dag_maker, sessi
 
     assert "<deferred>" in result
     assert "[queued]" not in result
+
+
+class TestTaskInstanceStatsTagsTeamName:
+    def test_stats_tags_without_team_name(self, dag_maker, session):
+        """stats_tags should not include team_name when _team_name is not 
set."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="my_task")
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("my_task", session=session)
+        tags = ti.stats_tags
+        assert "team_name" not in tags
+        assert tags == {"dag_id": "test_dag", "task_id": "my_task"}
+
+    def test_stats_tags_with_team_name(self, dag_maker, session):
+        """stats_tags should include team_name when _team_name is set."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="my_task")
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("my_task", session=session)
+        ti._team_name = "my_team"
+        tags = ti.stats_tags
+        assert tags["team_name"] == "my_team"
+        assert tags == {"dag_id": "test_dag", "task_id": "my_task", 
"team_name": "my_team"}
+
+    def test_stats_tags_with_none_team_name(self, dag_maker, session):
+        """stats_tags should not include team_name when _team_name is None."""
+        with dag_maker("test_dag"):
+            EmptyOperator(task_id="my_task")
+        dr = dag_maker.create_dagrun()
+        ti = dr.get_task_instance("my_task", session=session)
+        ti._team_name = None
+        tags = ti.stats_tags
+        assert "team_name" not in tags
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 48822c52f5d..26d1268df02 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -251,6 +251,14 @@ class RuntimeTaskInstance(TaskInstance):
 
     sentry_integration: str = ""
 
+    @property
+    def stats_tags(self) -> dict[str, str]:
+        """Metric tags for this task instance, including team_name when 
available."""
+        tags: dict[str, str] = {"dag_id": self.dag_id, "task_id": self.task_id}
+        if self._ti_context_from_server and 
self._ti_context_from_server.dag_run.team_name:
+            tags["team_name"] = self._ti_context_from_server.dag_run.team_name
+        return tags
+
     def __rich_repr__(self):
         yield "id", self.id
         yield "task_id", self.task_id
@@ -1434,7 +1442,7 @@ def run(
     state: TaskInstanceState | None = None
     error: BaseException | None = None
 
-    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+    stats_tags = ti.stats_tags
     stats.incr("ti.start", tags=stats_tags)
 
     try:
@@ -1613,7 +1621,7 @@ def _handle_current_task_success(
 
     # Record operator and task instance success metrics
     operator = ti.task.__class__.__name__
-    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+    stats_tags = ti.stats_tags
 
     stats.incr("operator_successes", tags={**stats_tags, "operator_name": 
operator})
     stats.incr("ti_successes", tags=stats_tags)
@@ -1720,7 +1728,7 @@ def _finalize_task_failure(
 
     # Record operator and task instance failed metrics
     operator = ti.task.__class__.__name__
-    stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
+    stats_tags = ti.stats_tags
 
     stats.incr("operator_failures", tags={**stats_tags, "operator_name": 
operator})
     stats.incr("ti_failures", tags=stats_tags)
@@ -2085,9 +2093,7 @@ def finalize(
     # Record task duration metrics for all terminal states
     if ti.start_date and ti.end_date:
         duration_ms = (ti.end_date - ti.start_date).total_seconds() * 1000
-        stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id}
-
-        stats.timing("task.duration", duration_ms, tags=stats_tags)
+        stats.timing("task.duration", duration_ms, tags=ti.stats_tags)
 
     task = ti.task
     # Pushing xcom for each operator extra links defined on the operator only.
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 5bf2434fba4..c71fbc85ce6 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5358,6 +5358,55 @@ class TestTaskInstanceMetrics:
             )
             backend.incr.assert_any_call("ti_failures", tags=stats_tags)
 
+    @pytest.mark.parametrize(
+        ("team_name", "expected_tags_extra"),
+        [
+            pytest.param("my_team", {"team_name": "my_team"}, id="with_team"),
+            pytest.param(None, {}, id="without_team"),
+        ],
+    )
+    def test_ti_start_metric_respects_team_name(
+        self, team_name, expected_tags_extra, create_runtime_ti, 
mock_supervisor_comms
+    ):
+        task = PythonOperator(task_id="test", python_callable=lambda: 
"success")
+        ti = create_runtime_ti(task=task)
+        if team_name:
+            ti._ti_context_from_server.dag_run.team_name = team_name
+
+        with 
mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as 
mock_get_backend:
+            backend = mock.MagicMock(spec=StatsLogger)
+            mock_get_backend.return_value = backend
+            run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+            expected = {"dag_id": ti.dag_id, "task_id": ti.task_id, 
**expected_tags_extra}
+            backend.incr.assert_any_call("ti.start", tags=expected)
+
+    @pytest.mark.parametrize(
+        ("task_callable", "operator_metric", "ti_metric"),
+        [
+            pytest.param(lambda: "success", "operator_successes", 
"ti_successes", id="success"),
+            pytest.param(lambda: 1 / 0, "operator_failures", "ti_failures", 
id="failure"),
+        ],
+    )
+    def test_operator_metrics_respect_team_name(
+        self, task_callable, operator_metric, ti_metric, create_runtime_ti, 
mock_supervisor_comms
+    ):
+        task = PythonOperator(task_id="test", python_callable=task_callable)
+        ti = create_runtime_ti(task=task)
+        ti._ti_context_from_server.dag_run.team_name = "team_a"
+
+        with 
mock.patch("airflow.sdk._shared.observability.metrics.stats._get_backend") as 
mock_get_backend:
+            backend = mock.MagicMock(spec=StatsLogger)
+            mock_get_backend.return_value = backend
+            run(ti, context=ti.get_template_context(), log=mock.MagicMock())
+
+            stats_tags = {"dag_id": ti.dag_id, "task_id": ti.task_id, 
"team_name": "team_a"}
+            backend.incr.assert_any_call(
+                operator_metric,
+                tags={**stats_tags, "operator_name": "PythonOperator"},
+            )
+            backend.incr.assert_any_call(ti_metric, tags=stats_tags)
+
 
 class TestDetailSpan:
     """Tests for the detail_span decorator / context manager."""

Reply via email to