This is an automated email from the ASF dual-hosted git repository.

utkarsharma pushed a commit to branch v2-9-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 133f5e4738fe33e69fefb9d57359570cd6807aa1
Author: Josh Fell <[email protected]>
AuthorDate: Sat Jun 1 23:02:40 2024 -0400

    Pass triggered or existing DAG Run logical date to DagStateTrigger (#39960)
    
    Closes: #38353
    
    When using the TriggerDagRunOperator in `deferrable=True` mode, the 
DagStateTrigger is passed an incorrect logical date to poll for. The 
trigger uses a logical date that is recalculated on every execution rather 
than the logical date from either the triggered DAG run or an existing DAG run (if 
the task is configured to not fail for existing DAG runs).
    
    This change corrects the logical date being used by the DagStateTrigger to 
poll for the triggered (or reset) DAG run.
    
    (cherry picked from commit ae648e6884a91d200ee63418563e32a2f78874c3)
---
 airflow/operators/trigger_dagrun.py    |  6 +--
 tests/operators/test_trigger_dagrun.py | 86 ++++++++++++++++++++++++++++++++--
 2 files changed, 85 insertions(+), 7 deletions(-)

diff --git a/airflow/operators/trigger_dagrun.py 
b/airflow/operators/trigger_dagrun.py
index f8cfa5256a..cebd18ae96 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -184,7 +184,8 @@ class TriggerDagRunOperator(BaseOperator):
 
         except DagRunAlreadyExists as e:
             if self.reset_dag_run:
-                self.log.info("Clearing %s on %s", self.trigger_dag_id, 
parsed_logical_date)
+                dag_run = e.dag_run
+                self.log.info("Clearing %s on %s", self.trigger_dag_id, 
dag_run.logical_date)
 
                 # Get target dag object and call clear()
                 dag_model = DagModel.get_current(self.trigger_dag_id)
@@ -193,7 +194,6 @@ class TriggerDagRunOperator(BaseOperator):
 
                 dag_bag = DagBag(dag_folder=dag_model.fileloc, 
read_dags_from_db=True)
                 dag = dag_bag.get_dag(self.trigger_dag_id)
-                dag_run = e.dag_run
                 dag.clear(start_date=dag_run.logical_date, 
end_date=dag_run.logical_date)
             else:
                 raise e
@@ -212,7 +212,7 @@ class TriggerDagRunOperator(BaseOperator):
                     trigger=DagStateTrigger(
                         dag_id=self.trigger_dag_id,
                         states=self.allowed_states + self.failed_states,
-                        execution_dates=[parsed_logical_date],
+                        execution_dates=[dag_run.logical_date],
                         poll_interval=self.poke_interval,
                     ),
                     method_name="execute_complete",
diff --git a/tests/operators/test_trigger_dagrun.py 
b/tests/operators/test_trigger_dagrun.py
index 9eed9b786e..90f7827f6f 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -22,9 +22,10 @@ import tempfile
 from datetime import datetime
 from unittest import mock
 
+import pendulum
 import pytest
 
-from airflow.exceptions import AirflowException, DagRunAlreadyExists
+from airflow.exceptions import AirflowException, DagRunAlreadyExists, 
TaskDeferred
 from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
@@ -35,7 +36,7 @@ from airflow.operators.trigger_dagrun import 
TriggerDagRunOperator
 from airflow.triggers.external_task import DagStateTrigger
 from airflow.utils import timezone
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState, State
 from airflow.utils.types import DagRunType
 
 pytestmark = pytest.mark.db_test
@@ -50,8 +51,8 @@ from airflow.operators.empty import EmptyOperator
 
 dag = DAG(
     dag_id='{TRIGGERED_DAG_ID}',
-    default_args={{'start_date': datetime(2019, 1, 1)}},
-    schedule_interval=None
+    schedule=None,
+    start_date=datetime(2019, 1, 1),
 )
 
 task = EmptyOperator(task_id='test', dag=dag)
@@ -547,3 +548,80 @@ class TestDagRunOperator:
             assert dagrun.logical_date == custom_execution_date
             assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, 
custom_execution_date)
             self.assert_extra_link(dagrun, task, session)
+
+    @pytest.mark.parametrize(
+        argnames=["trigger_logical_date"],
+        argvalues=[
+            pytest.param(DEFAULT_DATE, id=f"logical_date={DEFAULT_DATE}"),
+            pytest.param(None, id="logical_date=None"),
+        ],
+    )
+    def test_dagstatetrigger_execution_dates(self, trigger_logical_date):
+        """Ensure that the DagStateTrigger is called with the triggered DAG's 
logical date."""
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            logical_date=trigger_logical_date,
+            wait_for_completion=True,
+            poke_interval=5,
+            allowed_states=[DagRunState.QUEUED],
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        mock_task_defer = mock.MagicMock(side_effect=task.defer)
+        with mock.patch.object(TriggerDagRunOperator, "defer", 
mock_task_defer), pytest.raises(TaskDeferred):
+            task.execute({"task_instance": mock.MagicMock()})
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+        assert 
mock_task_defer.call_args_list[0].kwargs["trigger"].execution_dates == [
+            pendulum.instance(dagruns[0].logical_date)
+        ]
+
+    def test_dagstatetrigger_execution_dates_with_clear_and_reset(self):
+        """Check DagStateTrigger is called with the triggered DAG's logical 
date on subsequent defers."""
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            trigger_run_id="custom_run_id",
+            wait_for_completion=True,
+            poke_interval=5,
+            allowed_states=[DagRunState.QUEUED],
+            deferrable=True,
+            reset_dag_run=True,
+            dag=self.dag,
+        )
+
+        mock_task_defer = mock.MagicMock(side_effect=task.defer)
+        with mock.patch.object(TriggerDagRunOperator, "defer", 
mock_task_defer), pytest.raises(TaskDeferred):
+            task.execute({"task_instance": mock.MagicMock()})
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            triggered_logical_date = dagruns[0].logical_date
+            assert len(dagruns) == 1
+
+        assert 
mock_task_defer.call_args_list[0].kwargs["trigger"].execution_dates == [
+            pendulum.instance(triggered_logical_date)
+        ]
+
+        # Simulate the TriggerDagRunOperator task being cleared (aka executed 
again). A DagRunAlreadyExists
+        # exception should be raised because of the previous DAG run.
+        with mock.patch.object(TriggerDagRunOperator, "defer", 
mock_task_defer), pytest.raises(
+            (DagRunAlreadyExists, TaskDeferred)
+        ):
+            task.execute({"task_instance": mock.MagicMock()})
+
+        # Still only one DAG run should exist for the triggered DAG since the 
DAG will be cleared since the
+        # TriggerDagRunOperator task is configured with `reset_dag_run=True`.
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+        # The second DagStateTrigger call should still use the original 
`logical_date` value.
+        assert 
mock_task_defer.call_args_list[1].kwargs["trigger"].execution_dates == [
+            pendulum.instance(triggered_logical_date)
+        ]

Reply via email to