This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 28897f7a424f0c44b340a17591d54854aa192dd5
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Nov 27 06:48:17 2023 -0800

    Run triggers inline with dag test (#34642)
    
    No need to have the triggerer running -- deferred tasks' triggers are
    simply run inline with asyncio when no healthy triggerer is found.
    
    (cherry picked from commit 7b37a785d0b74d1e83c7ce84729febffd6e26821)
---
 airflow/models/dag.py                  | 68 +++++++++++++---------------
 airflow/models/taskinstance.py         |  3 ++
 tests/cli/commands/test_dag_command.py | 81 ++++++++++++++++++++--------------
 tests/models/test_mappedoperator.py    |  2 +-
 4 files changed, 81 insertions(+), 73 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 26c83754a8..27e8258a6d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -17,7 +17,8 @@
 # under the License.
 from __future__ import annotations
 
-import collections.abc
+import asyncio
+import collections
 import copy
 import functools
 import itertools
@@ -82,11 +83,11 @@ from airflow.datasets.manager import dataset_manager
 from airflow.exceptions import (
     AirflowDagInconsistent,
     AirflowException,
-    AirflowSkipException,
     DuplicateTaskIdFound,
     FailStopDagInvalidTriggerRule,
     ParamValidationError,
     RemovedInAirflow3Warning,
+    TaskDeferred,
     TaskNotFound,
 )
 from airflow.jobs.job import run_job
@@ -101,7 +102,6 @@ from airflow.models.taskinstance import (
     Context,
     TaskInstance,
     TaskInstanceKey,
-    TaskReturnCode,
     clear_task_instances,
 )
 from airflow.secrets.local_filesystem import LocalFilesystemBackend
@@ -285,12 +285,11 @@ def get_dataset_triggered_next_run_info(
     }
 
 
-class _StopDagTest(Exception):
-    """
-    Raise when DAG.test should stop immediately.
+def _triggerer_is_healthy():
+    from airflow.jobs.triggerer_job_runner import TriggererJobRunner
 
-    :meta private:
-    """
+    job = TriggererJobRunner.most_recent_job()
+    return job and job.is_alive()
 
 
 @functools.total_ordering
@@ -2844,21 +2843,12 @@ class DAG(LoggingMixin):
             if not scheduled_tis and ids_unrunnable:
                 self.log.warning("No tasks to run. unrunnable tasks: %s", 
ids_unrunnable)
                 time.sleep(1)
+            triggerer_running = _triggerer_is_healthy()
             for ti in scheduled_tis:
                 try:
                     add_logger_if_needed(ti)
                     ti.task = tasks[ti.task_id]
-                    ret = _run_task(ti, session=session)
-                    if ret is TaskReturnCode.DEFERRED:
-                        if not _triggerer_is_healthy():
-                            raise _StopDagTest(
-                                "Task has deferred but triggerer component is 
not running. "
-                                "You can start the triggerer by running 
`airflow triggerer` in a terminal."
-                            )
-                except _StopDagTest:
-                    # Let this exception bubble out and not be swallowed by the
-                    # except block below.
-                    raise
+                    _run_task(ti=ti, inline_trigger=not triggerer_running, 
session=session)
                 except Exception:
                     self.log.exception("Task failed; ti=%s", ti)
         if conn_file_path or variable_file_path:
@@ -3992,14 +3982,15 @@ class DagContext:
             return None
 
 
-def _triggerer_is_healthy():
-    from airflow.jobs.triggerer_job_runner import TriggererJobRunner
+def _run_trigger(trigger):
+    async def _run_trigger_main():
+        async for event in trigger.run():
+            return event
 
-    job = TriggererJobRunner.most_recent_job()
-    return job and job.is_alive()
+    return asyncio.run(_run_trigger_main())
 
 
-def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
+def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: 
Session):
     """
     Run a single task instance, and push result to Xcom for downstream tasks.
 
@@ -4009,20 +4000,21 @@ def _run_task(ti: TaskInstance, session) -> TaskReturnCode | None:
     Args:
         ti: TaskInstance to run
     """
-    ret = None
-    log.info("*****************************************************")
-    if ti.map_index > 0:
-        log.info("Running task %s index %d", ti.task_id, ti.map_index)
-    else:
-        log.info("Running task %s", ti.task_id)
-    try:
-        ret = ti._run_raw_task(session=session)
-        session.flush()
-        log.info("%s ran successfully!", ti.task_id)
-    except AirflowSkipException:
-        log.info("Task Skipped, continuing")
-    log.info("*****************************************************")
-    return ret
+    log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
+    while True:
+        try:
+            log.info("[DAG TEST] running task %s", ti)
+            ti._run_raw_task(session=session, raise_on_defer=inline_trigger)
+            break
+        except TaskDeferred as e:
+            log.info("[DAG TEST] running trigger in line")
+            event = _run_trigger(e.trigger)
+            ti.next_method = e.method_name
+            ti.next_kwargs = {"event": event.payload} if event else e.kwargs
+            log.info("[DAG TEST] Trigger completed")
+        session.merge(ti)
+        session.commit()
+    log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
 
 
 def _get_or_create_dagrun(
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 95a2f5945f..f041dcf208 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2284,6 +2284,7 @@ class TaskInstance(Base, LoggingMixin):
         test_mode: bool = False,
         job_id: str | None = None,
         pool: str | None = None,
+        raise_on_defer: bool = False,
         session: Session = NEW_SESSION,
     ) -> TaskReturnCode | None:
         """
@@ -2338,6 +2339,8 @@ class TaskInstance(Base, LoggingMixin):
             except TaskDeferred as defer:
                 # The task has signalled it wants to defer execution based on
                 # a trigger.
+                if raise_on_defer:
+                    raise
                 self._defer_task(defer=defer, session=session)
                 self.log.info(
                     "Pausing task as DEFERRED. dag_id=%s, task_id=%s, 
execution_date=%s, start_date=%s",
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 78b7fd4525..30b5c475ea 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -37,9 +37,10 @@ from airflow.decorators import task
 from airflow.exceptions import AirflowException
 from airflow.models import DagBag, DagModel, DagRun
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _StopDagTest
+from airflow.models.dag import _run_trigger
 from airflow.models.serialized_dag import SerializedDagModel
-from airflow.triggers.temporal import TimeDeltaTrigger
+from airflow.triggers.base import TriggerEvent
+from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
 from airflow.utils import timezone
 from airflow.utils.session import create_session
 from airflow.utils.types import DagRunType
@@ -824,35 +825,47 @@ class TestCliDags:
         dag_command.dag_test(cli_args)
         assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs
 
-    def test_dag_test_no_triggerer(self, dag_maker):
-        with dag_maker() as dag:
-
-            @task
-            def one():
-                return 1
-
-            @task
-            def two(val):
-                return val + 1
-
-            class MyOp(BaseOperator):
-                template_fields = ("tfield",)
-
-                def __init__(self, tfield, **kwargs):
-                    self.tfield = tfield
-                    super().__init__(**kwargs)
-
-                def execute(self, context, event=None):
-                    if event is None:
-                        print("I AM DEFERRING")
-                        
self.defer(trigger=TimeDeltaTrigger(timedelta(seconds=20)), 
method_name="execute")
-                        return
-                    print("RESUMING")
-                    return self.tfield + 1
-
-            task_one = one()
-            task_two = two(task_one)
-            op = MyOp(task_id="abc", tfield=str(task_two))
-            task_two >> op
-        with pytest.raises(_StopDagTest, match="Task has deferred but 
triggerer component is not running"):
-            dag.test()
+    def test_dag_test_run_trigger(self, dag_maker):
+        now = timezone.utcnow()
+        trigger = DateTimeTrigger(moment=now)
+        e = _run_trigger(trigger)
+        assert isinstance(e, TriggerEvent)
+        assert e.payload == now
+
+    def test_dag_test_no_triggerer_running(self, dag_maker):
+        with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) 
as mock_run:
+            with dag_maker() as dag:
+
+                @task
+                def one():
+                    return 1
+
+                @task
+                def two(val):
+                    return val + 1
+
+                trigger = TimeDeltaTrigger(timedelta(seconds=0))
+
+                class MyOp(BaseOperator):
+                    template_fields = ("tfield",)
+
+                    def __init__(self, tfield, **kwargs):
+                        self.tfield = tfield
+                        super().__init__(**kwargs)
+
+                    def execute(self, context, event=None):
+                        if event is None:
+                            print("I AM DEFERRING")
+                            self.defer(trigger=trigger, method_name="execute")
+                            return
+                        print("RESUMING")
+                        return self.tfield + 1
+
+                task_one = one()
+                task_two = two(task_one)
+                op = MyOp(task_id="abc", tfield=task_two)
+                task_two >> op
+            dr = dag.test()
+            assert mock_run.call_args_list[0] == ((trigger,), {})
+            tis = dr.get_task_instances()
+            assert [x for x in tis if x.task_id == "abc"][0].state == "success"
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 7244c55774..78f0a0d271 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -95,7 +95,7 @@ def test_task_mapping_with_dag_and_list_of_pandas_dataframe(mock_render_template
         mapped = 
CustomOperator.partial(task_id="task_2").expand(arg=unrenderable_values)
         task1 >> mapped
     dag.test()
-    assert caplog.text.count("task_2 ran successfully") == 2
+    assert caplog.text.count("[DAG TEST] end task task_id=task_2") == 2
     assert (
         "Unable to check if the value of type 'UnrenderableClass' is False for 
task 'task_2', field 'arg'"
         in caplog.text

Reply via email to