This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new da9bb786fe4 fix: TriggerDagRunOperator deferral mode not working for Airflow 3 (#58497)
da9bb786fe4 is described below

commit da9bb786fe4b57135e508bca0f6895eb4ed7c863
Author: Kacper Muda <[email protected]>
AuthorDate: Wed Nov 26 20:31:49 2025 +0100

    fix: TriggerDagRunOperator deferral mode not working for Airflow 3 (#58497)
---
 .../providers/standard/operators/trigger_dagrun.py |  37 ++++++--
 .../providers/standard/triggers/external_task.py   |  19 ++--
 .../unit/standard/operators/test_trigger_dagrun.py |  14 +++
 .../unit/standard/triggers/test_external_task.py   | 102 +++++++++++++++++++++
 4 files changed, 155 insertions(+), 17 deletions(-)
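For context, the core of the bug is an event-shape mismatch: before this change the Airflow 3 code path yielded a bare dict from the trigger, while execute_complete() always expected the (classpath, data) tuple form and indexed event[1]. Below is a minimal, Airflow-free sketch of the fixed event shape and how the operator consumes it; the values and the "failed" check are illustrative only, not the provider code itself.

    from datetime import datetime, timezone as tz

    # Shape the fixed DagStateTrigger yields (illustrative values).
    event = (
        "airflow.providers.standard.triggers.external_task.DagStateTrigger",
        {
            "dag_id": "some_dag",
            "states": ["success", "failed"],
            "poll_interval": 15,
            "run_ids": ["manual__2025-11-19T17:49:20.907083+00:00"],
            "execution_dates": [datetime(2025, 11, 19, 17, 49, 20, tzinfo=tz.utc)],
            # Airflow 3 only: per-run terminal state merged in by the trigger.
            "manual__2025-11-19T17:49:20.907083+00:00": "success",
        },
    )

    # How execute_complete() consumes it after the fix.
    _, event_data = event                     # unpack (classpath, data)
    run_ids = event_data["run_ids"]
    trigger_run_id = run_ids[0] if len(run_ids) == 1 else None
    failed = [r for r in run_ids if event_data.get(r) in ("failed",)]
    print(trigger_run_id, failed)
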

diff --git a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
index 7d99acd8e36..16c420898a8 100644
--- a/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
+++ b/providers/standard/src/airflow/providers/standard/operators/trigger_dagrun.py
@@ -349,21 +349,40 @@ class TriggerDagRunOperator(BaseOperator):
                     return
 
     def execute_complete(self, context: Context, event: tuple[str, dict[str, Any]]):
-        run_ids = event[1]["run_ids"]
+        """
+        Handle task completion after returning from a deferral.
+
+        Args:
+            context: The Airflow context dictionary.
+            event: A tuple containing the class path of the trigger and the trigger event data.
+        """
+        # Example event tuple content:
+        # (
+        #  "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+        #  {
+        #   'dag_id': 'some_dag',
+        #   'states': ['success', 'failed'],
+        #   'poll_interval': 15,
+        #   'run_ids': ['manual__2025-11-19T17:49:20.907083+00:00'],
+        #   'execution_dates': [
+        #    DateTime(2025, 11, 19, 17, 49, 20, 907083, tzinfo=Timezone('UTC'))
+        #   ]
+        #  }
+        # )
+        _, event_data = event
+        run_ids = event_data["run_ids"]
         # Re-set as attribute after coming back from deferral - to be used by listeners.
         # Just a safety check on length, we should always have single run_id here.
         self.trigger_run_id = run_ids[0] if len(run_ids) == 1 else None
         if AIRFLOW_V_3_0_PLUS:
-            self._trigger_dag_run_af_3_execute_complete(event=event)
+            self._trigger_dag_run_af_3_execute_complete(event_data=event_data)
         else:
-            self._trigger_dag_run_af_2_execute_complete(event=event)
+            self._trigger_dag_run_af_2_execute_complete(event_data=event_data)
 
-    def _trigger_dag_run_af_3_execute_complete(self, event: tuple[str, dict[str, Any]]):
-        run_ids = event[1]["run_ids"]
-        event_data = event[1]
+    def _trigger_dag_run_af_3_execute_complete(self, event_data: dict[str, Any]):
         failed_run_id_conditions = []
 
-        for run_id in run_ids:
+        for run_id in event_data["run_ids"]:
             state = event_data.get(run_id)
             if state in self.failed_states:
                 failed_run_id_conditions.append(run_id)
@@ -387,10 +406,10 @@ class TriggerDagRunOperator(BaseOperator):
 
         @provide_session
         def _trigger_dag_run_af_2_execute_complete(
-            self, event: tuple[str, dict[str, Any]], session: Session = NEW_SESSION
+            self, event_data: dict[str, Any], session: Session = NEW_SESSION
         ):
             # This logical_date is parsed from the return trigger event
-            provided_logical_date = event[1]["execution_dates"][0]
+            provided_logical_date = event_data["execution_dates"][0]
             try:
                 # Note: here execution fails on database isolation mode. Needs structural changes for AIP-72
                 dag_run = session.execute(
diff --git a/providers/standard/src/airflow/providers/standard/triggers/external_task.py b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
index 5295963c1d9..8a8e7f9db6b 100644
--- a/providers/standard/src/airflow/providers/standard/triggers/external_task.py
+++ b/providers/standard/src/airflow/providers/standard/triggers/external_task.py
@@ -226,23 +226,26 @@ class DagStateTrigger(BaseTrigger):
         elif self.execution_dates:
             runs_ids_or_dates = len(self.execution_dates)
 
+        cls_path, data = self.serialize()
+
         if AIRFLOW_V_3_0_PLUS:
-            data = await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
-            yield TriggerEvent(data)
+            data.update(  # update with {run_id: run_state} dict
+                await self.validate_count_dags_af_3(runs_ids_or_dates_len=runs_ids_or_dates)
+            )
+            yield TriggerEvent((cls_path, data))
             return
         else:
             while True:
                 num_dags = await self.count_dags()
                 if num_dags == runs_ids_or_dates:
-                    yield TriggerEvent(self.serialize())
+                    yield TriggerEvent((cls_path, data))
                     return
                 await asyncio.sleep(self.poll_interval)
 
-    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, typing.Any]:
+    async def validate_count_dags_af_3(self, runs_ids_or_dates_len: int = 0) -> dict[str, str]:
         from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
 
-        cls_path, data = self.serialize()
-
+        run_states: dict[str, str] = {}  # {run_id: run_state}
         while True:
             num_dags = await sync_to_async(RuntimeTaskInstance.get_dr_count)(
                 dag_id=self.dag_id,
@@ -257,8 +260,8 @@ class DagStateTrigger(BaseTrigger):
                             dag_id=self.dag_id,
                             run_id=run_id,
                         )
-                        data[run_id] = state
-                        return data
+                        run_states[run_id] = state
+                        return run_states
             await asyncio.sleep(self.poll_interval)
 
     if not AIRFLOW_V_3_0_PLUS:
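As a standalone sketch of the Airflow 3 branch above (plain Python, no Airflow imports; names and values are illustrative, not provider code): the serialized trigger kwargs are now merged with the per-run states before being wrapped in the TriggerEvent payload.

    # Sketch only: mirrors data.update(...) / TriggerEvent((cls_path, data)) above.
    cls_path = "airflow.providers.standard.triggers.external_task.DagStateTrigger"
    data = {"dag_id": "some_dag", "run_ids": ["run_1"], "states": ["success", "failed"], "poll_interval": 15}
    run_states = {"run_1": "success"}  # what validate_count_dags_af_3() now returns
    data.update(run_states)            # merged into the serialized kwargs
    event_payload = (cls_path, data)   # payload the operator unpacks in execute_complete()
    print(event_payload)
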
diff --git a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
index 42427662911..3a2e7c6b199 100644
--- a/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
+++ b/providers/standard/tests/unit/standard/operators/test_trigger_dagrun.py
@@ -282,6 +282,20 @@ class TestDagRunOperator:
 
         assert operator.trigger_run_id == "run_id_1"
 
+    def test_trigger_dag_run_execute_complete_fails_with_dict_as_input_type(self):
+        operator = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            wait_for_completion=True,
+            poke_interval=10,
+            failed_states=[],
+        )
+
+        with pytest.raises(ValueError, match="too many values to unpack"):
+            operator.execute_complete(
+                {}, {"dag_id": "dag_id", "run_ids": ["run_id_1"], 
"poll_interval": 15, "run_id_1": "success"}
+            )
+
     @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Implementation is 
different for Airflow 2 & 3")
     def test_trigger_dag_run_with_fail_when_dag_is_paused_should_fail(self):
         with pytest.raises(
diff --git a/providers/standard/tests/unit/standard/triggers/test_external_task.py b/providers/standard/tests/unit/standard/triggers/test_external_task.py
index 4a970fd752e..840afb343e2 100644
--- a/providers/standard/tests/unit/standard/triggers/test_external_task.py
+++ b/providers/standard/tests/unit/standard/triggers/test_external_task.py
@@ -713,6 +713,108 @@ class TestDagStateTrigger:
         # Prevents error when task is destroyed while in "pending" state
         asyncio.get_event_loop().stop()
 
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @pytest.mark.skipif(not AIRFLOW_V_3_0_PLUS, reason="Airflow 2 had a 
different implementation")
+    
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dr_count")
+    
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_dagrun_state")
+    async def test_dag_state_trigger_af_3_return_type(
+        self, mock_get_dagrun_state, mock_get_dag_run_count, session
+    ):
+        """
+        Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+        """
+        mock_get_dag_run_count.return_value = 1
+        mock_get_dagrun_state.return_value = DagRunState.SUCCESS
+        dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+        dag_run = DagRun(
+            dag_id=dag.dag_id,
+            run_type="manual",
+            run_id="external_task_run_id",
+            logical_date=timezone.datetime(2022, 1, 1),
+        )
+        dag_run.state = DagRunState.SUCCESS
+        session.add(dag_run)
+        session.commit()
+
+        trigger = DagStateTrigger(
+            dag_id=dag.dag_id,
+            states=self.STATES,
+            run_ids=["external_task_run_id"],
+            poll_interval=0.2,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is True
+
+        result = task.result()
+        assert isinstance(result, TriggerEvent)
+        assert result.payload == (
+            "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+            {
+                "dag_id": "test_dag_state_trigger_return_type",
+                "execution_dates": [
+                    timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+                ],
+                "external_task_run_id": DagRunState.SUCCESS,
+                "poll_interval": 0.2,
+                "run_ids": ["external_task_run_id"],
+                "states": ["success", "fail"],
+            },
+        )
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    @pytest.mark.skipif(AIRFLOW_V_3_0_PLUS, reason="Test only AF2 
implementation.")
+    async def test_dag_state_trigger_af_2_return_type(self, session):
+        """
+        Assert that the DagStateTrigger returns a tuple with classpath and event_data.
+        """
+        dag = DAG(f"{self.DAG_ID}_return_type", schedule=None, start_date=timezone.datetime(2022, 1, 1))
+
+        dag_run = DagRun(
+            dag_id=dag.dag_id,
+            run_type="manual",
+            run_id="external_task_run_id",
+            execution_date=timezone.datetime(2022, 1, 1),
+        )
+        dag_run.state = DagRunState.SUCCESS
+        session.add(dag_run)
+        session.commit()
+
+        trigger = DagStateTrigger(
+            dag_id=dag.dag_id,
+            states=self.STATES,
+            run_ids=["external_task_run_id"],
+            poll_interval=0.2,
+            execution_dates=[timezone.datetime(2022, 1, 1)],
+        )
+
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+        assert task.done() is True
+
+        result = task.result()
+        assert isinstance(result, TriggerEvent)
+        assert result.payload == (
+            "airflow.providers.standard.triggers.external_task.DagStateTrigger",
+            {
+                "dag_id": "test_dag_state_trigger_return_type",
+                "execution_dates": [
+                    timezone.datetime(2022, 1, 1, 0, 0, tzinfo=timezone.utc),
+                ],
+                # 'external_task_run_id': DagRunState.SUCCESS,  # This is only appended in AF3
+                "poll_interval": 0.2,
+                "run_ids": ["external_task_run_id"],
+                "states": ["success", "fail"],
+            },
+        )
+        asyncio.get_event_loop().stop()
+
     def test_serialization(self):
         """Asserts that the DagStateTrigger correctly serializes its arguments 
and classpath."""
         trigger = DagStateTrigger(
