This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new da9bb786fe4 fix: TriggerDagRunOperator deferral mode not working for Airflow 3 (#58497)
da9bb786fe4 is described below
commit da9bb786fe4b57135e508bca0f6895eb4ed7c863
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Nov 26 20:31:49 2025 +0100
fix: TriggerDagRunOperator deferral mode not working for Airflow 3 (#58497)
---
.../providers/standard/operators/trigger_dagrun.py | 37 ++++++--
.../providers/standard/triggers/external_task.py | 19 ++--
.../unit/standard/operators/test_trigger_dagrun.py | 14 +++
.../unit/standard/triggers/test_external_task.py | 102 +++++++++++++++++++++
4 files changed, 155 insertions(+), 17 deletions(-)
diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
index 7d99acd8e36..16c420898a8 100644
--- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
+++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
@@ -349,21 +349,40 @@ class TriggerDagRunOperator(BaseOperator):
return
def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]):
- run_ids = event[1]["run_ids"]
+ """
+ Handle task completion after returning from a deferral.
+
+ Args:
+ context: The Airflow context dictionary.
+ event: A tuple containing the class path of the trigger and the trigger event data.
+ """
+ # Example event tuple content:
+ # (
+ # "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+ # {
+ # 'dag_id': 'some_dag',
+ # 'states': ['success', 'failed'],
+ # 'poll_interval': 15,
+ # 'run_ids': ['manual__2025-11-19T17:49:20.907083+00:00'],
+ # 'execution_dates': [
+ # DateTime(2025, 11, 19, 17, 49, 20, 907083, tzinfo=Timezone('UTC'))
+ # ]
+ # }
+ # )
+ _, event_data = event
+ run_ids = event_data["run_ids"]
# Re-set as attribute after coming back from deferral - to be used by listeners.
# Just a safety check on length, we should always have single run_id here.
self.trigger_run_id = run_ids[0] if len(run_ids) == 1 else None
if AIRFLOW_V_3_0_PLUS:
- self._trigger_dag_run_af_3_execute_complete(event=event)
+ self._trigger_dag_run_af_3_execute_complete(event_data=event_data)
else:
- self._trigger_dag_run_af_2_execute_complete(event=event)
+ self._trigger_dag_run_af_2_execute_complete(event_data=event_data)
- def _trigger_dag_run_af_3_execute_complete(self, event: tuple[str, dict[str, Any]]):
- run_ids = event[1]["run_ids"]
- event_data = event[1]
+ def _trigger_dag_run_af_3_execute_complete(self, event_data: dict[str, Any]):
failed_run_id_conditions = []
- for run_id in run_ids:
+ for run_id in event_data["run_ids"]:
state = event_data.get(run_id)
if state in self.failed_states:
failed_run_id_conditions.append(run_id)
@@ -387,10 +406,10 @@ class TriggerDagRunOperator(BaseOperator):
@provide_session
def _trigger_dag_run_af_2_execute_complete(
- self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION
+ self, event_data: dict[str, Any], session: Session = NEW_SESSION
):
# This logical_date is parsed from the return trigger event
- provided_logical_date = event[1]["execution_dates"][0]
+ provided_logical_date = event_data["execution_dates"][0]
try:
# Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
dag_run = session.execute(
diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
index 5295963c1d9..8a8e7f9db6b 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
@@ -226,23 +226,26 @@ class DagStateTrigger(BaseTrigger):
elif self.execution_dates:
runs_ids_or_dates = len(self.execution_dates)
+ cls_path, data = self.serialize()
+
if AIRFLOW_V_3_0_PLUS:
- data = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
- yield TriggerEvent(data)
+ data.update( # update with {run_id: run_state} dict
+ await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+ )
+ yield TriggerEvent((cls_path, data))
return
else:
while True:
num_dags = await self.count_dags()
if num_dags == runs_ids_or_dates:
- yield TriggerEvent(self.serialize())
+ yield TriggerEvent((cls_path, data))
return
await asyncio.sleep(self.poll_interval)
- async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, typing.Any]:
+ async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, str]:
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
- cls_path, data = self.serialize()
-
+ run_states: dict[str, str] = {} # {run_id: run_state}
while True:
num_dags = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
dag_id=self.dag_id,
@@ -257,8 +260,8 @@ class DagStateTrigger(BaseTrigger):
dag_id=self.dag_id,
run_id=run_id,
)
- data[run_id] = state
- return data
+ run_states[run_id] = state
+ return run_states
await asyncio.sleep(self.poll_interval)
if not AIRFLOW_V_3_0_PLUS:
diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
index 42427662911..3a2e7c6b199 100644
--- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
+++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
@@ -282,6 +282,20 @@ class TestDagRunOperator:
assert operator.trigger_run_id == "run_id_1"
+ def test_trigger_dag_run_execute_complete_fails_with_dict_as_input_type(self):
+ operator = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ wait_for_completion=True,
+ poke_interval=10,
+ failed_states=[],
+ )
+
+ with pytest.raises(ValueError, match="too many values to unpack"):
+ operator.execute_complete(
+ {}, {"dag_id": "dag_id", "run_ids": ["run_id_1"], "poll_interval": 15, "run_id_1": "success"}
+ )
+
@pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is different for Airflow 2 & 3")
def test_trigger_dag_run_with_fail_when_dag_is_paused_should_fail(self):
with pytest.raises(
diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py
index 4a970fd752e..840afb343e2 100644
--- a/providers/standard/tests/unit/standard/triggers/test_external_task.py
+++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py
@@ -713,6 +713,108 @@ class TestDagStateTrigger:
# Prevents error when task is destroyed while in "pending" state
asyncio.get_event_loop().stop()
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 had a different implementation")
+ @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dr_count")
+ @mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dagrun_state")
+ async def test_dag_state_trigger_af_3_return_type(
+ self, mock_get_dagrun_state, mock_get_dag_run_count, session
+ ):
+ """
+ Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+ """
+ mock_get_dag_run_count.return_value = 1
+ mock_get_dagrun_state.return_value = DagRunState.SUCCESS
+ dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ run_type="manual",
+ run_id="external_task_run_id",
+ logical_date=timezone.datetime(2022, 1, 1),
+ )
+ dag_run.state = DagRunState.SUCCESS
+ session.add(dag_run)
+ session.commit()
+
+ trigger = DagStateTrigger(
+ dag_id=dag.dag_id,
+ states=self.STATES,
+ run_ids=["external_task_run_id"],
+ poll_interval=0.2,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ )
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert task.done() is True
+
+ result = task.result()
+ assert isinstance(result, TriggerEvent)
+ assert result.payload == (
+ "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+ {
+ "dag_id": "test_dag_state_trigger_return_type",
+ "execution_dates": [
+ timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+ ],
+ "external_task_run_id": DagRunState.SUCCESS,
+ "poll_interval": 0.2,
+ "run_ids": ["external_task_run_id"],
+ "states": ["success", "fail"],
+ },
+ )
+ asyncio.get_event_loop().stop()
+
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+ @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only AF2 implementation.")
+ async def test_dag_state_trigger_af_2_return_type(self, session):
+ """
+ Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+ """
+ dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+ dag_run = DagRun(
+ dag_id=dag.dag_id,
+ run_type="manual",
+ run_id="external_task_run_id",
+ execution_date=timezone.datetime(2022, 1, 1),
+ )
+ dag_run.state = DagRunState.SUCCESS
+ session.add(dag_run)
+ session.commit()
+
+ trigger = DagStateTrigger(
+ dag_id=dag.dag_id,
+ states=self.STATES,
+ run_ids=["external_task_run_id"],
+ poll_interval=0.2,
+ execution_dates=[timezone.datetime(2022, 1, 1)],
+ )
+
+ task = asyncio.create_task(trigger.run().__anext__())
+ await asyncio.sleep(0.5)
+ assert task.done() is True
+
+ result = task.result()
+ assert isinstance(result, TriggerEvent)
+ assert result.payload == (
+ "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+ {
+ "dag_id": "test_dag_state_trigger_return_type",
+ "execution_dates": [
+ timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+ ],
+ # 'external_task_run_id': DagRunState.SUCCESS, # This is only appended in AF3
+ "poll_interval": 0.2,
+ "run_ids": ["external_task_run_id"],
+ "states": ["success", "fail"],
+ },
+ )
+ asyncio.get_event_loop().stop()
+
def test_serialization(self):
"""Asserts that the DagStateTrigger correctly serializes its arguments and classpath."""
trigger = DagStateTrigger(