This is an automated email from the ASF dual-hosted git repository.

basph pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e7849e341 Deferrable `TriggerDagRunOperator` (#30292)
1e7849e341 is described below

commit 1e7849e341c0c33bb93058a0f0805cfb13d5f4ac
Author: Dylan Storey <[email protected]>
AuthorDate: Wed Mar 29 05:33:56 2023 -0400

    Deferrable `TriggerDagRunOperator` (#30292)
    
    * need to test the deferral loop
    
    * need to test the deferral loop
    
    * TriggerDagRun is deferrable
    
    * Update airflow/operators/trigger_dagrun.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * Update airflow/operators/trigger_dagrun.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * Update airflow/operators/trigger_dagrun.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * Update airflow/operators/trigger_dagrun.py
    
    Co-authored-by: Bas Harenslak <[email protected]>
    
    * incorporating feedback
    
    * feedback in and pre-commit run
    
    * feedback in and pre-commit run
    
    * feedback in and pre-commit run
    
    * incorporating feedback
    
    * one day i'll run pre-commit before pushing
    
    ---------
    
    Co-authored-by: Bas Harenslak <[email protected]>
---
 airflow/operators/trigger_dagrun.py    | 53 ++++++++++++++++++++
 tests/operators/test_trigger_dagrun.py | 88 ++++++++++++++++++++++++++++++++++
 2 files changed, 141 insertions(+)

diff --git a/airflow/operators/trigger_dagrun.py 
b/airflow/operators/trigger_dagrun.py
index 256923e30a..9a84bfac97 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -22,6 +22,8 @@ import json
 import time
 from typing import TYPE_CHECKING, Sequence, cast
 
+from sqlalchemy.orm.exc import NoResultFound
+
 from airflow.api.common.trigger_dag import trigger_dag
 from airflow.exceptions import AirflowException, DagNotFound, 
DagRunAlreadyExists
 from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
@@ -29,9 +31,11 @@ from airflow.models.dag import DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun
 from airflow.models.xcom import XCom
+from airflow.triggers.external_task import DagStateTrigger
 from airflow.utils import timezone
 from airflow.utils.context import Context
 from airflow.utils.helpers import build_airflow_url_with_query
+from airflow.utils.session import provide_session
 from airflow.utils.state import State
 from airflow.utils.types import DagRunType
 
@@ -40,6 +44,8 @@ XCOM_RUN_ID = "trigger_run_id"
 
 
 if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
     from airflow.models.taskinstance import TaskInstanceKey
 
 
@@ -79,6 +85,8 @@ class TriggerDagRunOperator(BaseOperator):
         (default: 60)
     :param allowed_states: List of allowed states, default is ``['success']``.
     :param failed_states: List of failed or dis-allowed states, default is 
``None``.
+    :param deferrable: If waiting for completion, whether or not to defer the 
task until done,
+        default is ``False``.
     """
 
     template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id", 
"execution_date", "conf")
@@ -98,6 +106,7 @@ class TriggerDagRunOperator(BaseOperator):
         poke_interval: int = 60,
         allowed_states: list | None = None,
         failed_states: list | None = None,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -109,6 +118,7 @@ class TriggerDagRunOperator(BaseOperator):
         self.poke_interval = poke_interval
         self.allowed_states = allowed_states or [State.SUCCESS]
         self.failed_states = failed_states or [State.FAILED]
+        self._defer = deferrable
 
         if execution_date is not None and not isinstance(execution_date, (str, 
datetime.datetime)):
             raise TypeError(
@@ -118,6 +128,7 @@ class TriggerDagRunOperator(BaseOperator):
         self.execution_date = execution_date
 
     def execute(self, context: Context):
+
         if isinstance(self.execution_date, datetime.datetime):
             parsed_execution_date = self.execution_date
         elif isinstance(self.execution_date, str):
@@ -134,6 +145,7 @@ class TriggerDagRunOperator(BaseOperator):
             run_id = self.trigger_run_id
         else:
             run_id = DagRun.generate_run_id(DagRunType.MANUAL, 
parsed_execution_date)
+
         try:
             dag_run = trigger_dag(
                 dag_id=self.trigger_dag_id,
@@ -168,6 +180,18 @@ class TriggerDagRunOperator(BaseOperator):
         ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
 
         if self.wait_for_completion:
+
+            # Kick off the deferral process
+            if self._defer:
+                self.defer(
+                    trigger=DagStateTrigger(
+                        dag_id=self.trigger_dag_id,
+                        states=self.allowed_states + self.failed_states,
+                        execution_dates=[parsed_execution_date],
+                        poll_interval=self.poke_interval,
+                    ),
+                    method_name="execute_complete",
+                )
             # wait for dag to complete
             while True:
                 self.log.info(
@@ -185,3 +209,32 @@ class TriggerDagRunOperator(BaseOperator):
                 if state in self.allowed_states:
                     self.log.info("%s finished with allowed state %s", 
self.trigger_dag_id, state)
                     return
+
+    @provide_session
+    def execute_complete(self, context: Context, session: Session, **kwargs):
+        parsed_execution_date = context["execution_date"]
+
+        try:
+            dag_run = (
+                session.query(DagRun)
+                .filter(DagRun.dag_id == self.trigger_dag_id, 
DagRun.execution_date == parsed_execution_date)
+                .one()
+            )
+
+        except NoResultFound:
+            raise AirflowException(
+                f"No DAG run found for DAG {self.trigger_dag_id} and execution 
date {self.execution_date}"
+            )
+
+        state = dag_run.state
+
+        if state in self.failed_states:
+            raise AirflowException(f"{self.trigger_dag_id} failed with failed 
state {state}")
+        if state in self.allowed_states:
+            self.log.info("%s finished with allowed state %s", 
self.trigger_dag_id, state)
+            return
+
+        raise AirflowException(
+            f"{self.trigger_dag_id} return {state} which is not in 
{self.failed_states}"
+            f" or {self.allowed_states}"
+        )
diff --git a/tests/operators/test_trigger_dagrun.py 
b/tests/operators/test_trigger_dagrun.py
index fdaa7263b2..cb2d75e84c 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -371,3 +371,91 @@ class TestDagRunOperator:
         )
         with pytest.raises(DagRunAlreadyExists):
             task.run(start_date=execution_date, end_date=execution_date)
+
+    def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self):
+        """Test TriggerDagRunOperator with wait_for_completion."""
+        execution_date = DEFAULT_DATE
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            execution_date=execution_date,
+            wait_for_completion=True,
+            poke_interval=10,
+            allowed_states=[State.QUEUED],
+            deferrable=False,
+            dag=self.dag,
+        )
+        task.run(start_date=execution_date, end_date=execution_date)
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+    def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self):
+        """Test TriggerDagRunOperator with wait_for_completion."""
+        execution_date = DEFAULT_DATE
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            execution_date=execution_date,
+            wait_for_completion=True,
+            poke_interval=10,
+            allowed_states=[State.QUEUED],
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        task.run(start_date=execution_date, end_date=execution_date)
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+        task.execute_complete(context={"execution_date": execution_date, 
"logical_date": execution_date})
+
+    def 
test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self):
+        """Test TriggerDagRunOperator with wait_for_completion."""
+        execution_date = DEFAULT_DATE
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            execution_date=execution_date,
+            wait_for_completion=True,
+            poke_interval=10,
+            allowed_states=[State.SUCCESS],
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        task.run(start_date=execution_date, end_date=execution_date)
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+        with pytest.raises(AirflowException):
+            task.execute_complete(context={"execution_date": execution_date, 
"logical_date": execution_date})
+
+    def 
test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self):
+        """Test TriggerDagRunOperator with wait_for_completion."""
+        execution_date = DEFAULT_DATE
+        task = TriggerDagRunOperator(
+            task_id="test_task",
+            trigger_dag_id=TRIGGERED_DAG_ID,
+            execution_date=execution_date,
+            wait_for_completion=True,
+            poke_interval=10,
+            allowed_states=[State.SUCCESS],
+            failed_states=[State.QUEUED],
+            deferrable=True,
+            dag=self.dag,
+        )
+
+        task.run(start_date=execution_date, end_date=execution_date)
+
+        with create_session() as session:
+            dagruns = session.query(DagRun).filter(DagRun.dag_id == 
TRIGGERED_DAG_ID).all()
+            assert len(dagruns) == 1
+
+        with pytest.raises(AirflowException):
+            task.execute_complete(context={"execution_date": execution_date, 
"logical_date": execution_date})

Reply via email to