This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v2-8-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 28897f7a424f0c44b340a17591d54854aa192dd5 Author: Daniel Standish <[email protected]> AuthorDate: Mon Nov 27 06:48:17 2023 -0800 Run triggers inline with dag test (#34642) No need to have trigger running -- will just run them async. (cherry picked from commit 7b37a785d0b74d1e83c7ce84729febffd6e26821) --- airflow/models/dag.py | 68 +++++++++++++--------------- airflow/models/taskinstance.py | 3 ++ tests/cli/commands/test_dag_command.py | 81 ++++++++++++++++++++-------------- tests/models/test_mappedoperator.py | 2 +- 4 files changed, 81 insertions(+), 73 deletions(-) diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 26c83754a8..27e8258a6d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -17,7 +17,8 @@ # under the License. from __future__ import annotations -import collections.abc +import asyncio +import collections import copy import functools import itertools @@ -82,11 +83,11 @@ from airflow.datasets.manager import dataset_manager from airflow.exceptions import ( AirflowDagInconsistent, AirflowException, - AirflowSkipException, DuplicateTaskIdFound, FailStopDagInvalidTriggerRule, ParamValidationError, RemovedInAirflow3Warning, + TaskDeferred, TaskNotFound, ) from airflow.jobs.job import run_job @@ -101,7 +102,6 @@ from airflow.models.taskinstance import ( Context, TaskInstance, TaskInstanceKey, - TaskReturnCode, clear_task_instances, ) from airflow.secrets.local_filesystem import LocalFilesystemBackend @@ -285,12 +285,11 @@ def get_dataset_triggered_next_run_info( } -class _StopDagTest(Exception): - """ - Raise when DAG.test should stop immediately. +def _triggerer_is_healthy(): + from airflow.jobs.triggerer_job_runner import TriggererJobRunner - :meta private: - """ + job = TriggererJobRunner.most_recent_job() + return job and job.is_alive() @functools.total_ordering @@ -2844,21 +2843,12 @@ class DAG(LoggingMixin): if not scheduled_tis and ids_unrunnable: self.log.warning("No tasks to run. 
unrunnable tasks: %s", ids_unrunnable) time.sleep(1) + triggerer_running = _triggerer_is_healthy() for ti in scheduled_tis: try: add_logger_if_needed(ti) ti.task = tasks[ti.task_id] - ret = _run_task(ti, session=session) - if ret is TaskReturnCode.DEFERRED: - if not _triggerer_is_healthy(): - raise _StopDagTest( - "Task has deferred but triggerer component is not running. " - "You can start the triggerer by running `airflow triggerer` in a terminal." - ) - except _StopDagTest: - # Let this exception bubble out and not be swallowed by the - # except block below. - raise + _run_task(ti=ti, inline_trigger=not triggerer_running, session=session) except Exception: self.log.exception("Task failed; ti=%s", ti) if conn_file_path or variable_file_path: @@ -3992,14 +3982,15 @@ class DagContext: return None -def _triggerer_is_healthy(): - from airflow.jobs.triggerer_job_runner import TriggererJobRunner +def _run_trigger(trigger): + async def _run_trigger_main(): + async for event in trigger.run(): + return event - job = TriggererJobRunner.most_recent_job() - return job and job.is_alive() + return asyncio.run(_run_trigger_main()) -def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: +def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session): """ Run a single task instance, and push result to Xcom for downstream tasks. 
@@ -4009,20 +4000,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None: Args: ti: TaskInstance to run """ - ret = None - log.info("*****************************************************") - if ti.map_index > 0: - log.info("Running task %s index %d", ti.task_id, ti.map_index) - else: - log.info("Running task %s", ti.task_id) - try: - ret = ti._run_raw_task(session=session) - session.flush() - log.info("%s ran successfully!", ti.task_id) - except AirflowSkipException: - log.info("Task Skipped, continuing") - log.info("*****************************************************") - return ret + log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) + while True: + try: + log.info("[DAG TEST] running task %s", ti) + ti._run_raw_task(session=session, raise_on_defer=inline_trigger) + break + except TaskDeferred as e: + log.info("[DAG TEST] running trigger in line") + event = _run_trigger(e.trigger) + ti.next_method = e.method_name + ti.next_kwargs = {"event": event.payload} if event else e.kwargs + log.info("[DAG TEST] Trigger completed") + session.merge(ti) + session.commit() + log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) def _get_or_create_dagrun( diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 95a2f5945f..f041dcf208 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2284,6 +2284,7 @@ class TaskInstance(Base, LoggingMixin): test_mode: bool = False, job_id: str | None = None, pool: str | None = None, + raise_on_defer: bool = False, session: Session = NEW_SESSION, ) -> TaskReturnCode | None: """ @@ -2338,6 +2339,8 @@ class TaskInstance(Base, LoggingMixin): except TaskDeferred as defer: # The task has signalled it wants to defer execution based on # a trigger. + if raise_on_defer: + raise self._defer_task(defer=defer, session=session) self.log.info( "Pausing task as DEFERRED. 
dag_id=%s, task_id=%s, execution_date=%s, start_date=%s", diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py index 78b7fd4525..30b5c475ea 100644 --- a/tests/cli/commands/test_dag_command.py +++ b/tests/cli/commands/test_dag_command.py @@ -37,9 +37,10 @@ from airflow.decorators import task from airflow.exceptions import AirflowException from airflow.models import DagBag, DagModel, DagRun from airflow.models.baseoperator import BaseOperator -from airflow.models.dag import _StopDagTest +from airflow.models.dag import _run_trigger from airflow.models.serialized_dag import SerializedDagModel -from airflow.triggers.temporal import TimeDeltaTrigger +from airflow.triggers.base import TriggerEvent +from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.types import DagRunType @@ -824,35 +825,47 @@ class TestCliDags: dag_command.dag_test(cli_args) assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs - def test_dag_test_no_triggerer(self, dag_maker): - with dag_maker() as dag: - - @task - def one(): - return 1 - - @task - def two(val): - return val + 1 - - class MyOp(BaseOperator): - template_fields = ("tfield",) - - def __init__(self, tfield, **kwargs): - self.tfield = tfield - super().__init__(**kwargs) - - def execute(self, context, event=None): - if event is None: - print("I AM DEFERRING") - self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), method_name="execute") - return - print("RESUMING") - return self.tfield + 1 - - task_one = one() - task_two = two(task_one) - op = MyOp(task_id="abc", tfield=str(task_two)) - task_two >> op - with pytest.raises(_StopDagTest, match="Task has deferred but triggerer component is not running"): - dag.test() + def test_dag_test_run_trigger(self, dag_maker): + now = timezone.utcnow() + trigger = DateTimeTrigger(moment=now) + e = _run_trigger(trigger) + 
assert isinstance(e, TriggerEvent) + assert e.payload == now + + def test_dag_test_no_triggerer_running(self, dag_maker): + with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run: + with dag_maker() as dag: + + @task + def one(): + return 1 + + @task + def two(val): + return val + 1 + + trigger = TimeDeltaTrigger(timedelta(seconds=0)) + + class MyOp(BaseOperator): + template_fields = ("tfield",) + + def __init__(self, tfield, **kwargs): + self.tfield = tfield + super().__init__(**kwargs) + + def execute(self, context, event=None): + if event is None: + print("I AM DEFERRING") + self.defer(trigger=trigger, method_name="execute") + return + print("RESUMING") + return self.tfield + 1 + + task_one = one() + task_two = two(task_one) + op = MyOp(task_id="abc", tfield=task_two) + task_two >> op + dr = dag.test() + assert mock_run.call_args_list[0] == ((trigger,), {}) + tis = dr.get_task_instances() + assert [x for x in tis if x.task_id == "abc"][0].state == "success" diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 7244c55774..78f0a0d271 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -95,7 +95,7 @@ def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template mapped = CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values) task1 >> mapped dag.test() - assert caplog.text.count("task_2 ran successfully") == 2 + assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2 assert ( "Unable to check if the value of type 'UnrenderableClass' is False for task 'task_2', field 'arg'" in caplog.text