This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6f71c973952 Populate trigger team_name at creation time for multi-team 
support (#67605)
6f71c973952 is described below

commit 6f71c973952c01a0a900d547f988833f71ab395b
Author: Ramit Kataria <[email protected]>
AuthorDate: Thu Jun 4 07:07:46 2026 -0700

    Populate trigger team_name at creation time for multi-team support (#67605)
    
    Populate `trigger.team_name` at all 4 trigger creation paths so that
    team-scoped triggerers can filter to their own team's triggers:
    - **`TaskInstance.defer_task()`** — resolves via 
`DagModel.get_team_name(dag_id)`
    - **Execution API `PATCH .../state` (deferred)** — resolves via 
`get_team_name_for_ti(ti_id)`
    - **`TriggererCallback.queue()`** — resolves via 
`DagBundleModel.get_team_name(bundle_name)`
    - **Asset watcher triggers (`bulk_write_to_db`)** — resolves via 
`DagBundleModel.get_team_name(bundle_name)`
    
    All paths are gated on `core.multi_team`; when disabled, no DB query
    is executed and `team_name` remains NULL.
    
    To support the above changes:
    - Add `DagBundleModel.get_team_name()` following the existing pattern
    used by Pool, Connection, Variable, and DagModel.
    - Add `session` parameter to `Callback.queue()` so triggerer callbacks
    can resolve team.
---
 .../execution_api/routes/task_instances.py         | 10 +++-
 .../src/airflow/dag_processing/collection.py       | 10 +++-
 airflow-core/src/airflow/models/callback.py        | 13 ++++-
 airflow-core/src/airflow/models/dagbundle.py       | 17 +++++-
 airflow-core/src/airflow/models/deadline.py        |  4 +-
 airflow-core/src/airflow/models/taskinstance.py    |  7 +++
 .../src/airflow/serialization/definitions/dag.py   |  8 ++-
 .../versions/head/test_task_instances.py           | 63 ++++++++++++++++++++
 .../tests/unit/dag_processing/test_collection.py   | 47 +++++++++++++++
 airflow-core/tests/unit/models/test_callback.py    | 45 +++++++++++++--
 airflow-core/tests/unit/models/test_dagbundle.py   | 60 +++++++++++++++++++
 airflow-core/tests/unit/models/test_deadline.py    | 16 ++++++
 .../tests/unit/models/test_taskinstance.py         | 67 ++++++++++++++++++++++
 13 files changed, 351 insertions(+), 16 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 2afd96806c4..e1062be1314 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -68,7 +68,12 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
 )
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 from airflow.api_fastapi.execution_api.deps import DepContainer
-from airflow.api_fastapi.execution_api.security import CurrentTIToken, 
ExecutionAPIRoute, require_auth
+from airflow.api_fastapi.execution_api.security import (
+    CurrentTIToken,
+    ExecutionAPIRoute,
+    get_team_name_for_ti,
+    require_auth,
+)
 from airflow.configuration import conf
 from airflow.exceptions import TaskNotFound
 from airflow.models.asset import AssetActive
@@ -288,8 +293,6 @@ def ti_run(
             or 0
         )
 
-        from airflow.api_fastapi.execution_api.security import 
get_team_name_for_ti
-
         dr.team_name = get_team_name_for_ti(task_instance_id, session)
 
         context = TIRunContext(
@@ -613,6 +616,7 @@ def _create_ti_state_update_query_and_update_state(
             classpath=ti_patch_payload.classpath,
             kwargs={},
             queue=ti_patch_payload.queue,
+            team_name=get_team_name_for_ti(task_instance_id, session),
         )
         trigger_row.encrypted_kwargs = trigger_kwargs
         session.add(trigger_row)
diff --git a/airflow-core/src/airflow/dag_processing/collection.py 
b/airflow-core/src/airflow/dag_processing/collection.py
index 7cdf490d938..174e0872d20 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -1041,7 +1041,11 @@ class AssetModelOperation(NamedTuple):
             )
 
     def add_asset_trigger_references(
-        self, assets: dict[tuple[str, str], AssetModel], *, session: Session
+        self,
+        assets: dict[tuple[str, str], AssetModel],
+        *,
+        team_name: str | None = None,
+        session: Session,
     ) -> None:
         from airflow.serialization.encoders import encode_trigger
 
@@ -1113,7 +1117,9 @@ class AssetModelOperation(NamedTuple):
                 trigger
                 for trigger in [
                     Trigger(
-                        classpath=triggers[trigger_hash]["classpath"], 
kwargs=triggers[trigger_hash]["kwargs"]
+                        classpath=triggers[trigger_hash]["classpath"],
+                        kwargs=triggers[trigger_hash]["kwargs"],
+                        team_name=team_name,
                     )
                     for trigger_hash in all_trigger_hashes
                     if trigger_hash not in orm_triggers
diff --git a/airflow-core/src/airflow/models/callback.py 
b/airflow-core/src/airflow/models/callback.py
index 454057083fb..78247be9ec0 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -33,10 +33,12 @@ from sqlalchemy.orm import Mapped, mapped_column, 
relationship
 from airflow._shared.module_loading import accepts_context as _accepts_context 
 # noqa: F401
 from airflow._shared.observability.metrics import stats
 from airflow._shared.timezones import timezone
+from airflow.configuration import conf
 from airflow.executors.workloads import BaseWorkload
 from airflow.executors.workloads.callback import CallbackFetchMethod
 from airflow.models import Base
 from airflow.models.base import StringID
+from airflow.models.dagbundle import DagBundleModel
 from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
 from airflow.utils.state import CallbackState
 
@@ -151,7 +153,7 @@ class Callback(Base, BaseWorkload):
         if prefix:
             self.data["prefix"] = prefix
 
-    def queue(self):
+    def queue(self, *, session: Session) -> None:
         self.state = CallbackState.QUEUED
 
     def get_metric_info(self, status: CallbackState, result: Any) -> dict:
@@ -225,17 +227,22 @@ class TriggererCallback(Callback):
     def __repr__(self):
         return f"{self.data['path']}({self.data['kwargs'] or ''}) on a 
triggerer"
 
-    def queue(self):
+    def queue(self, *, session: Session) -> None:
         from airflow.models.trigger import Trigger
         from airflow.triggers.callback import CallbackTrigger
 
+        team_name: str | None = None
+        if self.bundle_name and conf.getboolean("core", "multi_team"):
+            team_name = DagBundleModel.get_team_name(self.bundle_name, 
session=session)
+
         self.trigger = Trigger.from_object(
             CallbackTrigger(
                 callback_path=self.data["path"],
                 callback_kwargs=self.data["kwargs"],
             )
         )
-        super().queue()
+        self.trigger.team_name = team_name
+        super().queue(session=session)
 
     def handle_event(self, event: TriggerEvent, session: Session):
         from airflow.triggers.callback import PAYLOAD_BODY_KEY, 
PAYLOAD_STATUS_KEY
diff --git a/airflow-core/src/airflow/models/dagbundle.py 
b/airflow-core/src/airflow/models/dagbundle.py
index 1ad43d0dcc3..abd723a4e67 100644
--- a/airflow-core/src/airflow/models/dagbundle.py
+++ b/airflow-core/src/airflow/models/dagbundle.py
@@ -17,16 +17,21 @@
 from __future__ import annotations
 
 from datetime import datetime
+from typing import TYPE_CHECKING
 
 import sqlalchemy as sa
-from sqlalchemy import Boolean, String, Text
+from sqlalchemy import Boolean, String, Text, select
 from sqlalchemy.orm import Mapped, mapped_column, relationship
 
 from airflow.models.base import Base, StringID
-from airflow.models.team import dag_bundle_team_association_table
+from airflow.models.team import Team, dag_bundle_team_association_table
 from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
 
 class DagBundleModel(Base, LoggingMixin):
     """
@@ -108,3 +113,11 @@ class DagBundleModel(Base, LoggingMixin):
         except (KeyError, ValueError) as e:
             self.log.warning("Failed to render URL template for bundle %s: 
%s", self.name, e)
             return None
+
+    @staticmethod
+    @provide_session
+    def get_team_name(bundle_name: str, *, session: Session = NEW_SESSION) -> 
str | None:
+        """Return the team name for a bundle, or None if not mapped to a 
team."""
+        return session.scalar(
+            
select(Team.name).join(DagBundleModel.teams).where(DagBundleModel.name == 
bundle_name)
+        )
diff --git a/airflow-core/src/airflow/models/deadline.py 
b/airflow-core/src/airflow/models/deadline.py
index 6fda3597504..d82b9e7b805 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -129,6 +129,7 @@ class Deadline(Base):
         dagrun_id: int,
         deadline_alert_id: UUID | None,
         dag_id: str | None = None,
+        bundle_name: str | None = None,
     ):
         super().__init__()
         self.deadline_time = deadline_time
@@ -137,6 +138,7 @@ class Deadline(Base):
         self.callback = Callback.create_from_sdk_def(
             callback_def=callback, prefix=CALLBACK_METRICS_PREFIX, 
dag_id=dag_id
         )
+        self.callback.bundle_name = bundle_name
         self.deadline_alert_id = deadline_alert_id
 
     def __repr__(self):
@@ -240,7 +242,7 @@ class Deadline(Base):
                 "context": get_simple_context()
             }
 
-            self.callback.queue()
+            self.callback.queue(session=session)
             session.add(self.callback)
             session.flush()
 
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index 3469cf4acb8..e7783fd4e1f 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1744,9 +1744,16 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
             else:
                 self.trigger_timeout = None
 
+            team_name: str | None = None
+            if conf.getboolean("core", "multi_team"):
+                from airflow.models.dag import DagModel
+
+                team_name = DagModel.get_team_name(self.dag_id, 
session=session)
+
             trigger_row = Trigger(
                 classpath=start_trigger_args.trigger_cls,
                 kwargs=trigger_kwargs,
+                team_name=team_name,
             )
 
             # First, make the trigger entry
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py 
b/airflow-core/src/airflow/serialization/definitions/dag.py
index 092a3e589ce..c27675801c5 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -36,6 +36,7 @@ from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import AirflowException, NodeNotFound, TaskNotFound
 from airflow.models.dag import DagModel
 from airflow.models.dag_version import DagVersion
+from airflow.models.dagbundle import DagBundleModel
 from airflow.models.dagrun import DagRun
 from airflow.models.deadline import Deadline
 from airflow.models.deadline_alert import DeadlineAlert as DeadlineAlertModel
@@ -222,7 +223,11 @@ class SerializedDAG:
         asset_op.activate_assets_if_possible(orm_assets.values(), 
session=session)
         session.flush()  # Activation is needed when we add trigger references.
 
-        asset_op.add_asset_trigger_references(orm_assets, session=session)
+        team_name: str | None = None
+        if airflow_conf.getboolean("core", "multi_team"):
+            team_name = DagBundleModel.get_team_name(bundle_name, 
session=session)
+
+        asset_op.add_asset_trigger_references(orm_assets, team_name=team_name, 
session=session)
         dag_op.update_dag_asset_expression(orm_dags=orm_dags, 
orm_assets=orm_assets)
         session.flush()
 
@@ -677,6 +682,7 @@ class SerializedDAG:
                             dagrun_id=orm_dagrun.id,
                             deadline_alert_id=deadline_alert.id,
                             dag_id=orm_dagrun.dag_id,
+                            bundle_name=orm_dagrun.dag_model.bundle_name,
                         )
                     )
                     stats.incr("deadline_alerts.deadline_created", 
tags={"dag_id": self.dag_id})
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 3022bbfea06..4d21bb406f9 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -1529,6 +1529,69 @@ class TestTIUpdateState:
             else:
                 assert t[0].queue is None
 
+    @staticmethod
+    def _defer_ti_in_team_bundle(client, session, create_task_instance):
+        """Map the TI's Dag to a bundle/team, then defer it via the Execution 
API."""
+        from airflow.models.dagbundle import DagBundleModel
+        from airflow.models.team import Team
+
+        ti = create_task_instance(
+            task_id="test_ti_deferred_team",
+            state=State.RUNNING,
+            session=session,
+        )
+
+        bundle_name = "bundle_deferred_team_test"
+        team_name = "team_deferred_test"
+        bundle = session.get(DagBundleModel, bundle_name) or 
DagBundleModel(name=bundle_name)
+        team = session.get(Team, team_name) or Team(name=team_name)
+        if team not in bundle.teams:
+            bundle.teams.append(team)
+        session.add(bundle)
+        session.flush()
+        session.execute(update(DagModel).where(DagModel.dag_id == 
ti.dag_id).values(bundle_name=bundle_name))
+        session.commit()
+
+        payload = {
+            "state": "deferred",
+            "trigger_kwargs": {"key": "value"},
+            "classpath": "my-classpath",
+            "next_method": "execute_callback",
+        }
+        response = client.patch(f"/execution/task-instances/{ti.id}/state", 
json=payload)
+        assert response.status_code == 204
+        return team_name
+
+    @conf_vars({("core", "multi_team"): "True"})
+    def test_ti_update_state_to_deferred_populates_trigger_team_name(
+        self, client, session, create_task_instance, time_machine
+    ):
+        """Trigger created on deferral gets team_name from the TI's bundle."""
+        from airflow.models.trigger import Trigger
+
+        time_machine.move_to(timezone.datetime(2024, 11, 22), tick=False)
+
+        team_name = self._defer_ti_in_team_bundle(client, session, 
create_task_instance)
+
+        session.expire_all()
+        trigger = session.scalars(select(Trigger)).one()
+        assert trigger.team_name == team_name
+
+    @conf_vars({("core", "multi_team"): "False"})
+    def 
test_ti_update_state_to_deferred_skips_trigger_team_name_when_multi_team_disabled(
+        self, client, session, create_task_instance, time_machine
+    ):
+        """When multi_team is disabled, the trigger team_name stays NULL."""
+        from airflow.models.trigger import Trigger
+
+        time_machine.move_to(timezone.datetime(2024, 11, 22), tick=False)
+
+        self._defer_ti_in_team_bundle(client, session, create_task_instance)
+
+        session.expire_all()
+        trigger = session.scalars(select(Trigger)).one()
+        assert trigger.team_name is None
+
     def test_ti_update_state_to_reschedule(self, client, session, 
create_task_instance, time_machine):
         """
         Test that tests if the transition to reschedule state is handled 
correctly.
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py 
b/airflow-core/tests/unit/dag_processing/test_collection.py
index bc7a1490e16..ab2a950a9af 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -263,6 +263,53 @@ class TestAssetModelOperation:
         asset_model = session.scalars(select(AssetModel)).one()
         assert len(asset_model.triggers) == expected_num_triggers
 
+    @pytest.mark.usefixtures("testing_dag_bundle")
+    @pytest.mark.parametrize(
+        ("use_team", "expected"),
+        [
+            pytest.param(True, "testing", id="with-team"),
+            pytest.param(False, None, id="no-team"),
+        ],
+    )
+    def test_add_asset_trigger_references_populates_team_name(
+        self, dag_maker, session, testing_team, use_team, expected
+    ):
+        asset = Asset(
+            "test_trigger_team_asset",
+            watchers=[AssetWatcher(name="watcher", 
trigger=FileDeleteTrigger(mock.Mock()))],
+        )
+
+        with dag_maker(dag_id="test_trigger_team_dag", schedule=[asset]) as 
dag:
+            EmptyOperator(task_id="mytask")
+
+        # Use raw DagModelOperation (not dag_maker's bulk_write_to_db) to 
control team_name
+        dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+        orm_dags = DagModelOperation(dags, "testing", 
None).add_dags(session=session)
+        orm_dags[dag.dag_id].is_stale = False
+        orm_dags[dag.dag_id].is_paused = False
+        session.flush()
+
+        asset_op = AssetModelOperation.collect(dags)
+        orm_assets = asset_op.sync_assets(session=session)
+        session.flush()
+        asset_op.add_dag_asset_references(orm_dags, orm_assets, 
session=session)
+        asset_op.activate_assets_if_possible(orm_assets.values(), 
session=session)
+        session.flush()
+
+        # Clear any triggers created by dag_maker's bulk_write_to_db
+        session.execute(delete(Trigger))
+        for asset_model in orm_assets.values():
+            asset_model.watchers = []
+        session.flush()
+
+        team_name = testing_team.name if use_team else None
+        asset_op.add_asset_trigger_references(orm_assets, team_name=team_name, 
session=session)
+        session.flush()
+
+        triggers = session.scalars(select(Trigger)).all()
+        assert len(triggers) == 1
+        assert triggers[0].team_name == expected
+
     @pytest.mark.usefixtures("testing_dag_bundle")
     def test_add_asset_trigger_references_hash_consistency(self, dag_maker, 
session):
         """Trigger hash from the DAG-parsed path must equal the hash computed
diff --git a/airflow-core/tests/unit/models/test_callback.py 
b/airflow-core/tests/unit/models/test_callback.py
index c31d1c99b49..b5296979ed4 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -37,6 +37,7 @@ from airflow.triggers.base import TriggerEvent
 from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY
 from airflow.utils.session import create_session
 
+from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import clear_db_callbacks
 
 pytestmark = [pytest.mark.db_test]
@@ -165,12 +166,48 @@ class TestTriggererCallback:
         assert callback.state == CallbackState.SCHEDULED
         assert callback.trigger is None
 
-        callback.queue()
+        callback.queue(session=session)
         assert isinstance(callback.trigger, Trigger)
         assert callback.trigger.kwargs["callback_path"] == 
TEST_ASYNC_CALLBACK.path
         assert callback.trigger.kwargs["callback_kwargs"] == 
TEST_ASYNC_CALLBACK.kwargs
         assert callback.state == CallbackState.QUEUED
 
+    @staticmethod
+    def _queue_callback(session, *, has_bundle, has_team):
+        from airflow.models.dagbundle import DagBundleModel
+        from airflow.models.team import Team
+
+        bundle = session.get(DagBundleModel, "testing")
+        bundle.teams = [session.get(Team, "testing")] if has_team else []
+        session.flush()
+
+        callback = TriggererCallback(TEST_ASYNC_CALLBACK)
+        callback.bundle_name = "testing" if has_bundle else None
+        callback.queue(session=session)
+        return callback
+
+    @conf_vars({("core", "multi_team"): "True"})
+    @pytest.mark.parametrize(
+        ("has_bundle", "has_team", "expected_team_name"),
+        [
+            pytest.param(True, True, "testing", id="bundle-mapped-to-team"),
+            pytest.param(True, False, None, id="bundle-without-team"),
+            pytest.param(False, False, None, id="no-bundle"),
+        ],
+    )
+    def test_queue_populates_trigger_team_name(
+        self, session, testing_dag_bundle, testing_team, has_bundle, has_team, 
expected_team_name
+    ):
+        callback = self._queue_callback(session, has_bundle=has_bundle, 
has_team=has_team)
+        assert callback.trigger.team_name == expected_team_name
+
+    @conf_vars({("core", "multi_team"): "False"})
+    def test_queue_skips_trigger_team_name_when_multi_team_disabled(
+        self, session, testing_dag_bundle, testing_team
+    ):
+        callback = self._queue_callback(session, has_bundle=True, 
has_team=True)
+        assert callback.trigger.team_name is None
+
     @pytest.mark.parametrize(
         ("event", "terminal_state"),
         [
@@ -199,7 +236,7 @@ class TestTriggererCallback:
     )
     def test_handle_event(self, session, event, terminal_state):
         callback = TriggererCallback(TEST_ASYNC_CALLBACK)
-        callback.queue()
+        callback.queue(session=session)
         callback.handle_event(event, session)
 
         status = event.payload[PAYLOAD_STATUS_KEY]
@@ -230,11 +267,11 @@ class TestExecutorCallback:
         assert retrieved.created_at is not None
         assert retrieved.trigger_id is None
 
-    def test_queue(self):
+    def test_queue(self, session):
         callback = ExecutorCallback(TEST_SYNC_CALLBACK, 
fetch_method=CallbackFetchMethod.DAG_ATTRIBUTE)
         assert callback.state == CallbackState.SCHEDULED
 
-        callback.queue()
+        callback.queue(session=session)
         assert callback.state == CallbackState.QUEUED
 
     def test_session_get_requires_uuid_not_str(self, session):
diff --git a/airflow-core/tests/unit/models/test_dagbundle.py 
b/airflow-core/tests/unit/models/test_dagbundle.py
new file mode 100644
index 00000000000..84223e34a9b
--- /dev/null
+++ b/airflow-core/tests/unit/models/test_dagbundle.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.models.dagbundle import DagBundleModel
+
+from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_teams
+
+if TYPE_CHECKING:
+    from sqlalchemy.orm import Session
+
+    from airflow.models.team import Team
+
+pytestmark = pytest.mark.db_test
+
+
+class TestDagBundleModel:
+    def setup_method(self):
+        clear_db_dag_bundles()
+        clear_db_teams()
+
+    def teardown_method(self):
+        clear_db_dag_bundles()
+        clear_db_teams()
+
+    def test_get_team_name(self, testing_team: Team, session: Session):
+        bundle = DagBundleModel(name="test_bundle")
+        bundle.teams.append(testing_team)
+        session.add(bundle)
+        session.flush()
+
+        assert DagBundleModel.get_team_name("test_bundle", session=session) == 
"testing"
+
+    def test_get_team_name_no_team(self, session: Session):
+        bundle = DagBundleModel(name="test_bundle")
+        session.add(bundle)
+        session.flush()
+
+        assert DagBundleModel.get_team_name("test_bundle", session=session) is 
None
+
+    def test_get_team_name_unknown_bundle(self, session: Session):
+        assert DagBundleModel.get_team_name("does_not_exist", session=session) 
is None
diff --git a/airflow-core/tests/unit/models/test_deadline.py 
b/airflow-core/tests/unit/models/test_deadline.py
index 94c6977ae0c..c14ea44375e 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -210,6 +210,22 @@ class TestDeadline:
             assert f"needed by {DEFAULT_DATE}" in repr_str
             assert TEST_CALLBACK_PATH in repr_str
 
+    @pytest.mark.db_test
+    def test_bundle_name_propagated_to_callback(self, dagrun, session):
+        """The bundle name is forwarded to the callback so the triggerer can 
resolve its team."""
+        deadline = Deadline(
+            deadline_time=DEFAULT_DATE,
+            callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+            dagrun_id=dagrun.id,
+            dag_id=dagrun.dag_id,
+            deadline_alert_id=None,
+            bundle_name="my_bundle",
+        )
+        session.add(deadline)
+        session.flush()
+
+        assert deadline.callback.bundle_name == "my_bundle"
+
     @pytest.mark.db_test
     def test_handle_miss(self, dagrun, session):
         deadline_orm = Deadline(
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py 
b/airflow-core/tests/unit/models/test_taskinstance.py
index 99cdd38f4f3..6e7d0c569f9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -111,6 +111,7 @@ from airflow.utils.state import DagRunState, State, 
TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 from tests_common.test_utils import db
+from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import clear_db_runs
 from tests_common.test_utils.mock_operators import MockOperator
 from tests_common.test_utils.taskinstance import (
@@ -2954,6 +2955,72 @@ def test_defer_task_try_number_increment_on_state(
     assert ti.try_number == expected_try_number, msg
 
 
+def _defer_ti_in_testing_bundle(create_task_instance, session, *, with_team):
+    """Create a deferrable TI whose Dag belongs to the ``testing`` bundle, 
then defer it."""
+    from sqlalchemy import update
+
+    from airflow.models.dag import DagModel
+    from airflow.models.dagbundle import DagBundleModel
+    from airflow.models.team import Team
+    from airflow.triggers.base import StartTriggerArgs
+
+    bundle = session.get(DagBundleModel, "testing")
+    bundle.teams = [session.get(Team, "testing")] if with_team else []
+    session.flush()
+
+    ti = create_task_instance(
+        dag_id="test_defer_team",
+        task_id="op",
+        start_from_trigger=True,
+        start_trigger_args=StartTriggerArgs(
+            trigger_cls="airflow.triggers.testing.SuccessTrigger",
+            next_method="execute_complete",
+            trigger_kwargs={"moment": "2024-01-01"},
+        ),
+        session=session,
+    )
+    session.execute(
+        update(DagModel).where(DagModel.dag_id == 
"test_defer_team").values(bundle_name="testing")
+    )
+    session.flush()
+
+    ti.defer_task(session=session)
+    return ti
+
+
+@pytest.mark.db_test
+@conf_vars({("core", "multi_team"): "True"})
+@pytest.mark.parametrize(
+    ("with_team", "expected_team_name"),
+    [
+        pytest.param(True, "testing", id="bundle-mapped-to-team"),
+        pytest.param(False, None, id="bundle-without-team"),
+    ],
+)
+def test_defer_task_populates_trigger_team_name(
+    create_task_instance, session, testing_dag_bundle, testing_team, 
with_team, expected_team_name
+):
+    from airflow.models.trigger import Trigger
+
+    _defer_ti_in_testing_bundle(create_task_instance, session, 
with_team=with_team)
+
+    trigger = session.scalars(select(Trigger)).one()
+    assert trigger.team_name == expected_team_name
+
+
+@pytest.mark.db_test
+@conf_vars({("core", "multi_team"): "False"})
+def test_defer_task_skips_trigger_team_name_when_multi_team_disabled(
+    create_task_instance, session, testing_dag_bundle, testing_team
+):
+    from airflow.models.trigger import Trigger
+
+    _defer_ti_in_testing_bundle(create_task_instance, session, with_team=True)
+
+    trigger = session.scalars(select(Trigger)).one()
+    assert trigger.team_name is None
+
+
 class TestTaskInstanceRelationships:
     @pytest.mark.parametrize(
         "attr",

Reply via email to