This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6f71c973952 Populate trigger team_name at creation time for multi-team support (#67605)
6f71c973952 is described below
commit 6f71c973952c01a0a900d547f988833f71ab395b
Author: Ramit Kataria <[email protected]>
AuthorDate: Thu Jun 4 07:07:46 2026 -0700
Populate trigger team_name at creation time for multi-team support (#67605)
Populate `trigger.team_name` at all four trigger creation paths so that
team-scoped triggerers can filter to their own team's triggers:
- **`TaskInstance.defer_task()`** — resolves via `DagModel.get_team_name(dag_id)`
- **Execution API `PATCH .../state` (deferred)** — resolves via `get_team_name_for_ti(ti_id)`
- **`TriggererCallback.queue()`** — resolves via `DagBundleModel.get_team_name(bundle_name)`
- **Asset watcher triggers (`bulk_write_to_db`)** — resolves via `DagBundleModel.get_team_name(bundle_name)`
All paths are gated on `core.multi_team`; when disabled, no DB query
is executed and `team_name` remains NULL.
To support the above changes:
- Add `DagBundleModel.get_team_name()` following the existing pattern
used by Pool, Connection, Variable, and DagModel.
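- Add a `session` parameter to `Callback.queue()` so triggerer callbacks
can resolve their team.
- Add a `bundle_name` parameter to `Deadline` and propagate it to the
deadline's callback, so `TriggererCallback.queue()` has a bundle to resolve
the team from.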
---
.../execution_api/routes/task_instances.py | 10 +++-
.../src/airflow/dag_processing/collection.py | 10 +++-
airflow-core/src/airflow/models/callback.py | 13 ++++-
airflow-core/src/airflow/models/dagbundle.py | 17 +++++-
airflow-core/src/airflow/models/deadline.py | 4 +-
airflow-core/src/airflow/models/taskinstance.py | 7 +++
.../src/airflow/serialization/definitions/dag.py | 8 ++-
.../versions/head/test_task_instances.py | 63 ++++++++++++++++++++
.../tests/unit/dag_processing/test_collection.py | 47 +++++++++++++++
airflow-core/tests/unit/models/test_callback.py | 45 +++++++++++++--
airflow-core/tests/unit/models/test_dagbundle.py | 60 +++++++++++++++++++
airflow-core/tests/unit/models/test_deadline.py | 16 ++++++
.../tests/unit/models/test_taskinstance.py | 67 ++++++++++++++++++++++
13 files changed, 351 insertions(+), 16 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 2afd96806c4..e1062be1314 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -68,7 +68,12 @@ from
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
)
from airflow.api_fastapi.execution_api.datamodels.token import TIToken
from airflow.api_fastapi.execution_api.deps import DepContainer
-from airflow.api_fastapi.execution_api.security import CurrentTIToken,
ExecutionAPIRoute, require_auth
+from airflow.api_fastapi.execution_api.security import (
+ CurrentTIToken,
+ ExecutionAPIRoute,
+ get_team_name_for_ti,
+ require_auth,
+)
from airflow.configuration import conf
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
@@ -288,8 +293,6 @@ def ti_run(
or 0
)
- from airflow.api_fastapi.execution_api.security import
get_team_name_for_ti
-
dr.team_name = get_team_name_for_ti(task_instance_id, session)
context = TIRunContext(
@@ -613,6 +616,7 @@ def _create_ti_state_update_query_and_update_state(
classpath=ti_patch_payload.classpath,
kwargs={},
queue=ti_patch_payload.queue,
+ team_name=get_team_name_for_ti(task_instance_id, session),
)
trigger_row.encrypted_kwargs = trigger_kwargs
session.add(trigger_row)
diff --git a/airflow-core/src/airflow/dag_processing/collection.py
b/airflow-core/src/airflow/dag_processing/collection.py
index 7cdf490d938..174e0872d20 100644
--- a/airflow-core/src/airflow/dag_processing/collection.py
+++ b/airflow-core/src/airflow/dag_processing/collection.py
@@ -1041,7 +1041,11 @@ class AssetModelOperation(NamedTuple):
)
def add_asset_trigger_references(
- self, assets: dict[tuple[str, str], AssetModel], *, session: Session
+ self,
+ assets: dict[tuple[str, str], AssetModel],
+ *,
+ team_name: str | None = None,
+ session: Session,
) -> None:
from airflow.serialization.encoders import encode_trigger
@@ -1113,7 +1117,9 @@ class AssetModelOperation(NamedTuple):
trigger
for trigger in [
Trigger(
- classpath=triggers[trigger_hash]["classpath"],
kwargs=triggers[trigger_hash]["kwargs"]
+ classpath=triggers[trigger_hash]["classpath"],
+ kwargs=triggers[trigger_hash]["kwargs"],
+ team_name=team_name,
)
for trigger_hash in all_trigger_hashes
if trigger_hash not in orm_triggers
diff --git a/airflow-core/src/airflow/models/callback.py
b/airflow-core/src/airflow/models/callback.py
index 454057083fb..78247be9ec0 100644
--- a/airflow-core/src/airflow/models/callback.py
+++ b/airflow-core/src/airflow/models/callback.py
@@ -33,10 +33,12 @@ from sqlalchemy.orm import Mapped, mapped_column,
relationship
from airflow._shared.module_loading import accepts_context as _accepts_context
# noqa: F401
from airflow._shared.observability.metrics import stats
from airflow._shared.timezones import timezone
+from airflow.configuration import conf
from airflow.executors.workloads import BaseWorkload
from airflow.executors.workloads.callback import CallbackFetchMethod
from airflow.models import Base
from airflow.models.base import StringID
+from airflow.models.dagbundle import DagBundleModel
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime
from airflow.utils.state import CallbackState
@@ -151,7 +153,7 @@ class Callback(Base, BaseWorkload):
if prefix:
self.data["prefix"] = prefix
- def queue(self):
+ def queue(self, *, session: Session) -> None:
self.state = CallbackState.QUEUED
def get_metric_info(self, status: CallbackState, result: Any) -> dict:
@@ -225,17 +227,22 @@ class TriggererCallback(Callback):
def __repr__(self):
return f"{self.data['path']}({self.data['kwargs'] or ''}) on a
triggerer"
- def queue(self):
+ def queue(self, *, session: Session) -> None:
from airflow.models.trigger import Trigger
from airflow.triggers.callback import CallbackTrigger
+ team_name: str | None = None
+ if self.bundle_name and conf.getboolean("core", "multi_team"):
+ team_name = DagBundleModel.get_team_name(self.bundle_name,
session=session)
+
self.trigger = Trigger.from_object(
CallbackTrigger(
callback_path=self.data["path"],
callback_kwargs=self.data["kwargs"],
)
)
- super().queue()
+ self.trigger.team_name = team_name
+ super().queue(session=session)
def handle_event(self, event: TriggerEvent, session: Session):
from airflow.triggers.callback import PAYLOAD_BODY_KEY,
PAYLOAD_STATUS_KEY
diff --git a/airflow-core/src/airflow/models/dagbundle.py
b/airflow-core/src/airflow/models/dagbundle.py
index 1ad43d0dcc3..abd723a4e67 100644
--- a/airflow-core/src/airflow/models/dagbundle.py
+++ b/airflow-core/src/airflow/models/dagbundle.py
@@ -17,16 +17,21 @@
from __future__ import annotations
from datetime import datetime
+from typing import TYPE_CHECKING
import sqlalchemy as sa
-from sqlalchemy import Boolean, String, Text
+from sqlalchemy import Boolean, String, Text, select
from sqlalchemy.orm import Mapped, mapped_column, relationship
from airflow.models.base import Base, StringID
-from airflow.models.team import dag_bundle_team_association_table
+from airflow.models.team import Team, dag_bundle_team_association_table
from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
class DagBundleModel(Base, LoggingMixin):
"""
@@ -108,3 +113,11 @@ class DagBundleModel(Base, LoggingMixin):
except (KeyError, ValueError) as e:
self.log.warning("Failed to render URL template for bundle %s:
%s", self.name, e)
return None
+
+ @staticmethod
+ @provide_session
+ def get_team_name(bundle_name: str, *, session: Session = NEW_SESSION) ->
str | None:
+ """Return the team name for a bundle, or None if not mapped to a
team."""
+ return session.scalar(
+
select(Team.name).join(DagBundleModel.teams).where(DagBundleModel.name ==
bundle_name)
+ )
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index 6fda3597504..d82b9e7b805 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -129,6 +129,7 @@ class Deadline(Base):
dagrun_id: int,
deadline_alert_id: UUID | None,
dag_id: str | None = None,
+ bundle_name: str | None = None,
):
super().__init__()
self.deadline_time = deadline_time
@@ -137,6 +138,7 @@ class Deadline(Base):
self.callback = Callback.create_from_sdk_def(
callback_def=callback, prefix=CALLBACK_METRICS_PREFIX,
dag_id=dag_id
)
+ self.callback.bundle_name = bundle_name
self.deadline_alert_id = deadline_alert_id
def __repr__(self):
@@ -240,7 +242,7 @@ class Deadline(Base):
"context": get_simple_context()
}
- self.callback.queue()
+ self.callback.queue(session=session)
session.add(self.callback)
session.flush()
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 3469cf4acb8..e7783fd4e1f 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1744,9 +1744,16 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
else:
self.trigger_timeout = None
+ team_name: str | None = None
+ if conf.getboolean("core", "multi_team"):
+ from airflow.models.dag import DagModel
+
+ team_name = DagModel.get_team_name(self.dag_id,
session=session)
+
trigger_row = Trigger(
classpath=start_trigger_args.trigger_cls,
kwargs=trigger_kwargs,
+ team_name=team_name,
)
# First, make the trigger entry
diff --git a/airflow-core/src/airflow/serialization/definitions/dag.py
b/airflow-core/src/airflow/serialization/definitions/dag.py
index 092a3e589ce..c27675801c5 100644
--- a/airflow-core/src/airflow/serialization/definitions/dag.py
+++ b/airflow-core/src/airflow/serialization/definitions/dag.py
@@ -36,6 +36,7 @@ from airflow.configuration import conf as airflow_conf
from airflow.exceptions import AirflowException, NodeNotFound, TaskNotFound
from airflow.models.dag import DagModel
from airflow.models.dag_version import DagVersion
+from airflow.models.dagbundle import DagBundleModel
from airflow.models.dagrun import DagRun
from airflow.models.deadline import Deadline
from airflow.models.deadline_alert import DeadlineAlert as DeadlineAlertModel
@@ -222,7 +223,11 @@ class SerializedDAG:
asset_op.activate_assets_if_possible(orm_assets.values(),
session=session)
session.flush() # Activation is needed when we add trigger references.
- asset_op.add_asset_trigger_references(orm_assets, session=session)
+ team_name: str | None = None
+ if airflow_conf.getboolean("core", "multi_team"):
+ team_name = DagBundleModel.get_team_name(bundle_name,
session=session)
+
+ asset_op.add_asset_trigger_references(orm_assets, team_name=team_name,
session=session)
dag_op.update_dag_asset_expression(orm_dags=orm_dags,
orm_assets=orm_assets)
session.flush()
@@ -677,6 +682,7 @@ class SerializedDAG:
dagrun_id=orm_dagrun.id,
deadline_alert_id=deadline_alert.id,
dag_id=orm_dagrun.dag_id,
+ bundle_name=orm_dagrun.dag_model.bundle_name,
)
)
stats.incr("deadline_alerts.deadline_created",
tags={"dag_id": self.dag_id})
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 3022bbfea06..4d21bb406f9 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -1529,6 +1529,69 @@ class TestTIUpdateState:
else:
assert t[0].queue is None
+ @staticmethod
+ def _defer_ti_in_team_bundle(client, session, create_task_instance):
+ """Map the TI's Dag to a bundle/team, then defer it via the Execution
API."""
+ from airflow.models.dagbundle import DagBundleModel
+ from airflow.models.team import Team
+
+ ti = create_task_instance(
+ task_id="test_ti_deferred_team",
+ state=State.RUNNING,
+ session=session,
+ )
+
+ bundle_name = "bundle_deferred_team_test"
+ team_name = "team_deferred_test"
+ bundle = session.get(DagBundleModel, bundle_name) or
DagBundleModel(name=bundle_name)
+ team = session.get(Team, team_name) or Team(name=team_name)
+ if team not in bundle.teams:
+ bundle.teams.append(team)
+ session.add(bundle)
+ session.flush()
+ session.execute(update(DagModel).where(DagModel.dag_id ==
ti.dag_id).values(bundle_name=bundle_name))
+ session.commit()
+
+ payload = {
+ "state": "deferred",
+ "trigger_kwargs": {"key": "value"},
+ "classpath": "my-classpath",
+ "next_method": "execute_callback",
+ }
+ response = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
+ assert response.status_code == 204
+ return team_name
+
+ @conf_vars({("core", "multi_team"): "True"})
+ def test_ti_update_state_to_deferred_populates_trigger_team_name(
+ self, client, session, create_task_instance, time_machine
+ ):
+ """Trigger created on deferral gets team_name from the TI's bundle."""
+ from airflow.models.trigger import Trigger
+
+ time_machine.move_to(timezone.datetime(2024, 11, 22), tick=False)
+
+ team_name = self._defer_ti_in_team_bundle(client, session,
create_task_instance)
+
+ session.expire_all()
+ trigger = session.scalars(select(Trigger)).one()
+ assert trigger.team_name == team_name
+
+ @conf_vars({("core", "multi_team"): "False"})
+ def
test_ti_update_state_to_deferred_skips_trigger_team_name_when_multi_team_disabled(
+ self, client, session, create_task_instance, time_machine
+ ):
+ """When multi_team is disabled, the trigger team_name stays NULL."""
+ from airflow.models.trigger import Trigger
+
+ time_machine.move_to(timezone.datetime(2024, 11, 22), tick=False)
+
+ self._defer_ti_in_team_bundle(client, session, create_task_instance)
+
+ session.expire_all()
+ trigger = session.scalars(select(Trigger)).one()
+ assert trigger.team_name is None
+
def test_ti_update_state_to_reschedule(self, client, session,
create_task_instance, time_machine):
"""
Test that tests if the transition to reschedule state is handled
correctly.
diff --git a/airflow-core/tests/unit/dag_processing/test_collection.py
b/airflow-core/tests/unit/dag_processing/test_collection.py
index bc7a1490e16..ab2a950a9af 100644
--- a/airflow-core/tests/unit/dag_processing/test_collection.py
+++ b/airflow-core/tests/unit/dag_processing/test_collection.py
@@ -263,6 +263,53 @@ class TestAssetModelOperation:
asset_model = session.scalars(select(AssetModel)).one()
assert len(asset_model.triggers) == expected_num_triggers
+ @pytest.mark.usefixtures("testing_dag_bundle")
+ @pytest.mark.parametrize(
+ ("use_team", "expected"),
+ [
+ pytest.param(True, "testing", id="with-team"),
+ pytest.param(False, None, id="no-team"),
+ ],
+ )
+ def test_add_asset_trigger_references_populates_team_name(
+ self, dag_maker, session, testing_team, use_team, expected
+ ):
+ asset = Asset(
+ "test_trigger_team_asset",
+ watchers=[AssetWatcher(name="watcher",
trigger=FileDeleteTrigger(mock.Mock()))],
+ )
+
+ with dag_maker(dag_id="test_trigger_team_dag", schedule=[asset]) as
dag:
+ EmptyOperator(task_id="mytask")
+
+ # Use raw DagModelOperation (not dag_maker's bulk_write_to_db) to
control team_name
+ dags = {dag.dag_id: LazyDeserializedDAG.from_dag(dag)}
+ orm_dags = DagModelOperation(dags, "testing",
None).add_dags(session=session)
+ orm_dags[dag.dag_id].is_stale = False
+ orm_dags[dag.dag_id].is_paused = False
+ session.flush()
+
+ asset_op = AssetModelOperation.collect(dags)
+ orm_assets = asset_op.sync_assets(session=session)
+ session.flush()
+ asset_op.add_dag_asset_references(orm_dags, orm_assets,
session=session)
+ asset_op.activate_assets_if_possible(orm_assets.values(),
session=session)
+ session.flush()
+
+ # Clear any triggers created by dag_maker's bulk_write_to_db
+ session.execute(delete(Trigger))
+ for asset_model in orm_assets.values():
+ asset_model.watchers = []
+ session.flush()
+
+ team_name = testing_team.name if use_team else None
+ asset_op.add_asset_trigger_references(orm_assets, team_name=team_name,
session=session)
+ session.flush()
+
+ triggers = session.scalars(select(Trigger)).all()
+ assert len(triggers) == 1
+ assert triggers[0].team_name == expected
+
@pytest.mark.usefixtures("testing_dag_bundle")
def test_add_asset_trigger_references_hash_consistency(self, dag_maker,
session):
"""Trigger hash from the DAG-parsed path must equal the hash computed
diff --git a/airflow-core/tests/unit/models/test_callback.py
b/airflow-core/tests/unit/models/test_callback.py
index c31d1c99b49..b5296979ed4 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -37,6 +37,7 @@ from airflow.triggers.base import TriggerEvent
from airflow.triggers.callback import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY
from airflow.utils.session import create_session
+from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_callbacks
pytestmark = [pytest.mark.db_test]
@@ -165,12 +166,48 @@ class TestTriggererCallback:
assert callback.state == CallbackState.SCHEDULED
assert callback.trigger is None
- callback.queue()
+ callback.queue(session=session)
assert isinstance(callback.trigger, Trigger)
assert callback.trigger.kwargs["callback_path"] ==
TEST_ASYNC_CALLBACK.path
assert callback.trigger.kwargs["callback_kwargs"] ==
TEST_ASYNC_CALLBACK.kwargs
assert callback.state == CallbackState.QUEUED
+ @staticmethod
+ def _queue_callback(session, *, has_bundle, has_team):
+ from airflow.models.dagbundle import DagBundleModel
+ from airflow.models.team import Team
+
+ bundle = session.get(DagBundleModel, "testing")
+ bundle.teams = [session.get(Team, "testing")] if has_team else []
+ session.flush()
+
+ callback = TriggererCallback(TEST_ASYNC_CALLBACK)
+ callback.bundle_name = "testing" if has_bundle else None
+ callback.queue(session=session)
+ return callback
+
+ @conf_vars({("core", "multi_team"): "True"})
+ @pytest.mark.parametrize(
+ ("has_bundle", "has_team", "expected_team_name"),
+ [
+ pytest.param(True, True, "testing", id="bundle-mapped-to-team"),
+ pytest.param(True, False, None, id="bundle-without-team"),
+ pytest.param(False, False, None, id="no-bundle"),
+ ],
+ )
+ def test_queue_populates_trigger_team_name(
+ self, session, testing_dag_bundle, testing_team, has_bundle, has_team,
expected_team_name
+ ):
+ callback = self._queue_callback(session, has_bundle=has_bundle,
has_team=has_team)
+ assert callback.trigger.team_name == expected_team_name
+
+ @conf_vars({("core", "multi_team"): "False"})
+ def test_queue_skips_trigger_team_name_when_multi_team_disabled(
+ self, session, testing_dag_bundle, testing_team
+ ):
+ callback = self._queue_callback(session, has_bundle=True,
has_team=True)
+ assert callback.trigger.team_name is None
+
@pytest.mark.parametrize(
("event", "terminal_state"),
[
@@ -199,7 +236,7 @@ class TestTriggererCallback:
)
def test_handle_event(self, session, event, terminal_state):
callback = TriggererCallback(TEST_ASYNC_CALLBACK)
- callback.queue()
+ callback.queue(session=session)
callback.handle_event(event, session)
status = event.payload[PAYLOAD_STATUS_KEY]
@@ -230,11 +267,11 @@ class TestExecutorCallback:
assert retrieved.created_at is not None
assert retrieved.trigger_id is None
- def test_queue(self):
+ def test_queue(self, session):
callback = ExecutorCallback(TEST_SYNC_CALLBACK,
fetch_method=CallbackFetchMethod.DAG_ATTRIBUTE)
assert callback.state == CallbackState.SCHEDULED
- callback.queue()
+ callback.queue(session=session)
assert callback.state == CallbackState.QUEUED
def test_session_get_requires_uuid_not_str(self, session):
diff --git a/airflow-core/tests/unit/models/test_dagbundle.py
b/airflow-core/tests/unit/models/test_dagbundle.py
new file mode 100644
index 00000000000..84223e34a9b
--- /dev/null
+++ b/airflow-core/tests/unit/models/test_dagbundle.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import pytest
+
+from airflow.models.dagbundle import DagBundleModel
+
+from tests_common.test_utils.db import clear_db_dag_bundles, clear_db_teams
+
+if TYPE_CHECKING:
+ from sqlalchemy.orm import Session
+
+ from airflow.models.team import Team
+
+pytestmark = pytest.mark.db_test
+
+
+class TestDagBundleModel:
+ def setup_method(self):
+ clear_db_dag_bundles()
+ clear_db_teams()
+
+ def teardown_method(self):
+ clear_db_dag_bundles()
+ clear_db_teams()
+
+ def test_get_team_name(self, testing_team: Team, session: Session):
+ bundle = DagBundleModel(name="test_bundle")
+ bundle.teams.append(testing_team)
+ session.add(bundle)
+ session.flush()
+
+ assert DagBundleModel.get_team_name("test_bundle", session=session) ==
"testing"
+
+ def test_get_team_name_no_team(self, session: Session):
+ bundle = DagBundleModel(name="test_bundle")
+ session.add(bundle)
+ session.flush()
+
+ assert DagBundleModel.get_team_name("test_bundle", session=session) is
None
+
+ def test_get_team_name_unknown_bundle(self, session: Session):
+ assert DagBundleModel.get_team_name("does_not_exist", session=session)
is None
diff --git a/airflow-core/tests/unit/models/test_deadline.py
b/airflow-core/tests/unit/models/test_deadline.py
index 94c6977ae0c..c14ea44375e 100644
--- a/airflow-core/tests/unit/models/test_deadline.py
+++ b/airflow-core/tests/unit/models/test_deadline.py
@@ -210,6 +210,22 @@ class TestDeadline:
assert f"needed by {DEFAULT_DATE}" in repr_str
assert TEST_CALLBACK_PATH in repr_str
+ @pytest.mark.db_test
+ def test_bundle_name_propagated_to_callback(self, dagrun, session):
+ """The bundle name is forwarded to the callback so the triggerer can
resolve its team."""
+ deadline = Deadline(
+ deadline_time=DEFAULT_DATE,
+ callback=AsyncCallback(TEST_CALLBACK_PATH, TEST_CALLBACK_KWARGS),
+ dagrun_id=dagrun.id,
+ dag_id=dagrun.dag_id,
+ deadline_alert_id=None,
+ bundle_name="my_bundle",
+ )
+ session.add(deadline)
+ session.flush()
+
+ assert deadline.callback.bundle_name == "my_bundle"
+
@pytest.mark.db_test
def test_handle_miss(self, dagrun, session):
deadline_orm = Deadline(
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 99cdd38f4f3..6e7d0c569f9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -111,6 +111,7 @@ from airflow.utils.state import DagRunState, State,
TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
from tests_common.test_utils import db
+from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.db import clear_db_runs
from tests_common.test_utils.mock_operators import MockOperator
from tests_common.test_utils.taskinstance import (
@@ -2954,6 +2955,72 @@ def test_defer_task_try_number_increment_on_state(
assert ti.try_number == expected_try_number, msg
+def _defer_ti_in_testing_bundle(create_task_instance, session, *, with_team):
+ """Create a deferrable TI whose Dag belongs to the ``testing`` bundle,
then defer it."""
+ from sqlalchemy import update
+
+ from airflow.models.dag import DagModel
+ from airflow.models.dagbundle import DagBundleModel
+ from airflow.models.team import Team
+ from airflow.triggers.base import StartTriggerArgs
+
+ bundle = session.get(DagBundleModel, "testing")
+ bundle.teams = [session.get(Team, "testing")] if with_team else []
+ session.flush()
+
+ ti = create_task_instance(
+ dag_id="test_defer_team",
+ task_id="op",
+ start_from_trigger=True,
+ start_trigger_args=StartTriggerArgs(
+ trigger_cls="airflow.triggers.testing.SuccessTrigger",
+ next_method="execute_complete",
+ trigger_kwargs={"moment": "2024-01-01"},
+ ),
+ session=session,
+ )
+ session.execute(
+ update(DagModel).where(DagModel.dag_id ==
"test_defer_team").values(bundle_name="testing")
+ )
+ session.flush()
+
+ ti.defer_task(session=session)
+ return ti
+
+
[email protected]_test
+@conf_vars({("core", "multi_team"): "True"})
[email protected](
+ ("with_team", "expected_team_name"),
+ [
+ pytest.param(True, "testing", id="bundle-mapped-to-team"),
+ pytest.param(False, None, id="bundle-without-team"),
+ ],
+)
+def test_defer_task_populates_trigger_team_name(
+ create_task_instance, session, testing_dag_bundle, testing_team,
with_team, expected_team_name
+):
+ from airflow.models.trigger import Trigger
+
+ _defer_ti_in_testing_bundle(create_task_instance, session,
with_team=with_team)
+
+ trigger = session.scalars(select(Trigger)).one()
+ assert trigger.team_name == expected_team_name
+
+
[email protected]_test
+@conf_vars({("core", "multi_team"): "False"})
+def test_defer_task_skips_trigger_team_name_when_multi_team_disabled(
+ create_task_instance, session, testing_dag_bundle, testing_team
+):
+ from airflow.models.trigger import Trigger
+
+ _defer_ti_in_testing_bundle(create_task_instance, session, with_team=True)
+
+ trigger = session.scalars(select(Trigger)).one()
+ assert trigger.team_name is None
+
+
class TestTaskInstanceRelationships:
@pytest.mark.parametrize(
"attr",