This is an automated email from the ASF dual-hosted git repository.
basph pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1e7849e341 Deferrable `TriggerDagRunOperator` (#30292)
1e7849e341 is described below
commit 1e7849e341c0c33bb93058a0f0805cfb13d5f4ac
Author: Dylan Storey <[email protected]>
AuthorDate: Wed Mar 29 05:33:56 2023 -0400
Deferrable `TriggerDagRunOperator` (#30292)
* need to test the deferral loop
* need to test the deferral loop
* TriggerDagRun is deferrable
* Update airflow/operators/trigger_dagrun.py
Co-authored-by: Bas Harenslak <[email protected]>
* Update airflow/operators/trigger_dagrun.py
Co-authored-by: Bas Harenslak <[email protected]>
* Update airflow/operators/trigger_dagrun.py
Co-authored-by: Bas Harenslak <[email protected]>
* Update airflow/operators/trigger_dagrun.py
Co-authored-by: Bas Harenslak <[email protected]>
* incorporating feedback
* feedback in and pre-commit run
* feedback in and pre-commit run
* feedback in and pre-commit run
* incorporating feedback
* one day I'll run pre-commit before pushing
---------
Co-authored-by: Bas Harenslak <[email protected]>
---
airflow/operators/trigger_dagrun.py | 53 ++++++++++++++++++++
tests/operators/test_trigger_dagrun.py | 88 ++++++++++++++++++++++++++++++++++
2 files changed, 141 insertions(+)
diff --git a/airflow/operators/trigger_dagrun.py
b/airflow/operators/trigger_dagrun.py
index 256923e30a..9a84bfac97 100644
--- a/airflow/operators/trigger_dagrun.py
+++ b/airflow/operators/trigger_dagrun.py
@@ -22,6 +22,8 @@ import json
import time
from typing import TYPE_CHECKING, Sequence, cast
+from sqlalchemy.orm.exc import NoResultFound
+
from airflow.api.common.trigger_dag import trigger_dag
from airflow.exceptions import AirflowException, DagNotFound,
DagRunAlreadyExists
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
@@ -29,9 +31,11 @@ from airflow.models.dag import DagModel
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun
from airflow.models.xcom import XCom
+from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.context import Context
from airflow.utils.helpers import build_airflow_url_with_query
+from airflow.utils.session import provide_session
from airflow.utils.state import State
from airflow.utils.types import DagRunType
@@ -40,6 +44,8 @@ XCOM_RUN_ID = "trigger_run_id"
if TYPE_CHECKING:
+ from sqlalchemy.orm.session import Session
+
from airflow.models.taskinstance import TaskInstanceKey
@@ -79,6 +85,8 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is
``None``.
+ :param deferrable: If waiting for completion, whether or not to defer the
task until done,
+ default is ``False``.
"""
template_fields: Sequence[str] = ("trigger_dag_id", "trigger_run_id",
"execution_date", "conf")
@@ -98,6 +106,7 @@ class TriggerDagRunOperator(BaseOperator):
poke_interval: int = 60,
allowed_states: list | None = None,
failed_states: list | None = None,
+ deferrable: bool = False,
**kwargs,
) -> None:
super().__init__(**kwargs)
@@ -109,6 +118,7 @@ class TriggerDagRunOperator(BaseOperator):
self.poke_interval = poke_interval
self.allowed_states = allowed_states or [State.SUCCESS]
self.failed_states = failed_states or [State.FAILED]
+ self._defer = deferrable
if execution_date is not None and not isinstance(execution_date, (str,
datetime.datetime)):
raise TypeError(
@@ -118,6 +128,7 @@ class TriggerDagRunOperator(BaseOperator):
self.execution_date = execution_date
def execute(self, context: Context):
+
if isinstance(self.execution_date, datetime.datetime):
parsed_execution_date = self.execution_date
elif isinstance(self.execution_date, str):
@@ -134,6 +145,7 @@ class TriggerDagRunOperator(BaseOperator):
run_id = self.trigger_run_id
else:
run_id = DagRun.generate_run_id(DagRunType.MANUAL,
parsed_execution_date)
+
try:
dag_run = trigger_dag(
dag_id=self.trigger_dag_id,
@@ -168,6 +180,18 @@ class TriggerDagRunOperator(BaseOperator):
ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id)
if self.wait_for_completion:
+
+ # Kick off the deferral process
+ if self._defer:
+ self.defer(
+ trigger=DagStateTrigger(
+ dag_id=self.trigger_dag_id,
+ states=self.allowed_states + self.failed_states,
+ execution_dates=[parsed_execution_date],
+ poll_interval=self.poke_interval,
+ ),
+ method_name="execute_complete",
+ )
# wait for dag to complete
while True:
self.log.info(
@@ -185,3 +209,32 @@ class TriggerDagRunOperator(BaseOperator):
if state in self.allowed_states:
self.log.info("%s finished with allowed state %s",
self.trigger_dag_id, state)
return
+
+ @provide_session
+ def execute_complete(self, context: Context, session: Session, **kwargs):
+ parsed_execution_date = context["execution_date"]
+
+ try:
+ dag_run = (
+ session.query(DagRun)
+ .filter(DagRun.dag_id == self.trigger_dag_id,
DagRun.execution_date == parsed_execution_date)
+ .one()
+ )
+
+ except NoResultFound:
+ raise AirflowException(
+ f"No DAG run found for DAG {self.trigger_dag_id} and execution
date {self.execution_date}"
+ )
+
+ state = dag_run.state
+
+ if state in self.failed_states:
+ raise AirflowException(f"{self.trigger_dag_id} failed with failed
state {state}")
+ if state in self.allowed_states:
+ self.log.info("%s finished with allowed state %s",
self.trigger_dag_id, state)
+ return
+
+ raise AirflowException(
+ f"{self.trigger_dag_id} return {state} which is not in
{self.failed_states}"
+ f" or {self.allowed_states}"
+ )
diff --git a/tests/operators/test_trigger_dagrun.py
b/tests/operators/test_trigger_dagrun.py
index fdaa7263b2..cb2d75e84c 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -371,3 +371,91 @@ class TestDagRunOperator:
)
with pytest.raises(DagRunAlreadyExists):
task.run(start_date=execution_date, end_date=execution_date)
+
+ def test_trigger_dagrun_with_wait_for_completion_true_defer_false(self):
+ """Test TriggerDagRunOperator with wait_for_completion."""
+ execution_date = DEFAULT_DATE
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ wait_for_completion=True,
+ poke_interval=10,
+ allowed_states=[State.QUEUED],
+ deferrable=False,
+ dag=self.dag,
+ )
+ task.run(start_date=execution_date, end_date=execution_date)
+
+ with create_session() as session:
+ dagruns = session.query(DagRun).filter(DagRun.dag_id ==
TRIGGERED_DAG_ID).all()
+ assert len(dagruns) == 1
+
+ def test_trigger_dagrun_with_wait_for_completion_true_defer_true(self):
+ """Test TriggerDagRunOperator with wait_for_completion."""
+ execution_date = DEFAULT_DATE
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ wait_for_completion=True,
+ poke_interval=10,
+ allowed_states=[State.QUEUED],
+ deferrable=True,
+ dag=self.dag,
+ )
+
+ task.run(start_date=execution_date, end_date=execution_date)
+
+ with create_session() as session:
+ dagruns = session.query(DagRun).filter(DagRun.dag_id ==
TRIGGERED_DAG_ID).all()
+ assert len(dagruns) == 1
+
+ task.execute_complete(context={"execution_date": execution_date,
"logical_date": execution_date})
+
+ def
test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure(self):
+ """Test TriggerDagRunOperator with wait_for_completion."""
+ execution_date = DEFAULT_DATE
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ wait_for_completion=True,
+ poke_interval=10,
+ allowed_states=[State.SUCCESS],
+ deferrable=True,
+ dag=self.dag,
+ )
+
+ task.run(start_date=execution_date, end_date=execution_date)
+
+ with create_session() as session:
+ dagruns = session.query(DagRun).filter(DagRun.dag_id ==
TRIGGERED_DAG_ID).all()
+ assert len(dagruns) == 1
+
+ with pytest.raises(AirflowException):
+ task.execute_complete(context={"execution_date": execution_date,
"logical_date": execution_date})
+
+ def
test_trigger_dagrun_with_wait_for_completion_true_defer_true_failure_2(self):
+ """Test TriggerDagRunOperator with wait_for_completion."""
+ execution_date = DEFAULT_DATE
+ task = TriggerDagRunOperator(
+ task_id="test_task",
+ trigger_dag_id=TRIGGERED_DAG_ID,
+ execution_date=execution_date,
+ wait_for_completion=True,
+ poke_interval=10,
+ allowed_states=[State.SUCCESS],
+ failed_states=[State.QUEUED],
+ deferrable=True,
+ dag=self.dag,
+ )
+
+ task.run(start_date=execution_date, end_date=execution_date)
+
+ with create_session() as session:
+ dagruns = session.query(DagRun).filter(DagRun.dag_id ==
TRIGGERED_DAG_ID).all()
+ assert len(dagruns) == 1
+
+ with pytest.raises(AirflowException):
+ task.execute_complete(context={"execution_date": execution_date,
"logical_date": execution_date})