This is an automated email from the ASF dual-hosted git repository.

phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d3b5b4be8 Enable "airflow tasks test" to run deferrable operator (#37542)
7d3b5b4be8 is described below

commit 7d3b5b4be8a8db675fedeaabc8151fdb5770d38a
Author: Wei Lee <[email protected]>
AuthorDate: Wed Feb 21 16:19:34 2024 +0800

    Enable "airflow tasks test" to run deferrable operator (#37542)
---
 airflow/cli/commands/task_command.py    | 25 +++++++++++++++++++++----
 airflow/models/dag.py                   |  8 ++++----
 airflow/models/taskinstance.py          | 17 +++++++++++++----
 tests/cli/commands/test_dag_command.py  |  8 ++++----
 tests/cli/commands/test_task_command.py | 21 +++++++++++++++++++++
 tests/jobs/test_triggerer_job.py        |  4 ++--
 6 files changed, 65 insertions(+), 18 deletions(-)

diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 5c7c47d69b..8416789f53 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -18,6 +18,7 @@
 """Task sub-commands."""
 from __future__ import annotations
 
+import functools
 import importlib
 import json
 import logging
@@ -34,13 +35,13 @@ from sqlalchemy import select
 from airflow import settings
 from airflow.cli.simple_table import AirflowConsole
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
+from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
 from airflow.executors.executor_loader import ExecutorLoader
 from airflow.jobs.job import Job, run_job
 from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
 from airflow.listeners.listener import get_listener_manager
 from airflow.models import DagPickle, TaskInstance
-from airflow.models.dag import DAG
+from airflow.models.dag import DAG, _run_inline_trigger
 from airflow.models.dagrun import DagRun
 from airflow.models.operator import needs_expansion
 from airflow.models.param import ParamsDict
@@ -588,7 +589,8 @@ def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
 
 
 @cli_utils.action_cli(check_db=False)
-def task_test(args, dag: DAG | None = None) -> None:
+@provide_session
+def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
     """Test task for a given dag_id."""
     # We want to log output from operators etc to show up here. Normally
     # airflow.task would redirect to a file, but here we want it to propagate
@@ -632,7 +634,22 @@ def task_test(args, dag: DAG | None = None) -> None:
             if args.dry_run:
                 ti.dry_run()
             else:
-                ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
+                ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
+    except TaskDeferred as defer:
+        ti.defer_task(defer=defer, session=session)
+        log.info("[TASK TEST] running trigger in line")
+
+        event = _run_inline_trigger(defer.trigger)
+        ti.next_method = defer.method_name
+        ti.next_kwargs = {"event": event.payload} if event else defer.kwargs
+
+        execute_callable = getattr(task, ti.next_method)
+        if ti.next_kwargs:
+            execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
+        context = ti.get_template_context(ignore_param_exceptions=False)
+        execute_callable(context)
+
+        log.info("[TASK TEST] Trigger completed")
     except Exception:
         if args.post_mortem:
             debugger = _guess_debugger()
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 164e83a3f5..dd43568657 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -4057,12 +4057,12 @@ class DagContext:
             return None
 
 
-def _run_trigger(trigger):
-    async def _run_trigger_main():
+def _run_inline_trigger(trigger):
+    async def _run_inline_trigger_main():
         async for event in trigger.run():
             return event
 
-    return asyncio.run(_run_trigger_main())
+    return asyncio.run(_run_inline_trigger_main())
 
 
 def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
@@ -4083,7 +4083,7 @@ def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Sessio
             break
         except TaskDeferred as e:
             log.info("[DAG TEST] running trigger in line")
-            event = _run_trigger(e.trigger)
+            event = _run_inline_trigger(e.trigger)
             ti.next_method = e.method_name
             ti.next_kwargs = {"event": event.payload} if event else e.kwargs
             log.info("[DAG TEST] Trigger completed")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4720361504..5d026a0667 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2378,7 +2378,7 @@ class TaskInstance(Base, LoggingMixin):
                 # a trigger.
                 if raise_on_defer:
                     raise
-                self._defer_task(defer=defer, session=session)
+                self.defer_task(defer=defer, session=session)
                 self.log.info(
                     "Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
                     self.dag_id,
@@ -2565,8 +2565,11 @@ class TaskInstance(Base, LoggingMixin):
         return _execute_task(self, context, task_orig)
 
     @provide_session
-    def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
-        """Mark the task as deferred and sets up the trigger that is needed to resume it."""
+    def defer_task(self, session: Session, defer: TaskDeferred) -> None:
+        """Mark the task as deferred and sets up the trigger that is needed to resume it.
+
+        :meta: private
+        """
         from airflow.models.trigger import Trigger
 
         # First, make the trigger entry
@@ -2625,6 +2628,7 @@ class TaskInstance(Base, LoggingMixin):
         job_id: str | None = None,
         pool: str | None = None,
         session: Session = NEW_SESSION,
+        raise_on_defer: bool = False,
     ) -> None:
         """Run TaskInstance."""
         res = self.check_and_change_state_before_execution(
@@ -2644,7 +2648,12 @@ class TaskInstance(Base, LoggingMixin):
             return
 
         self._run_raw_task(
-            mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
+            mark_success=mark_success,
+            test_mode=test_mode,
+            job_id=job_id,
+            pool=pool,
+            session=session,
+            raise_on_defer=raise_on_defer,
         )
 
     def dry_run(self) -> None:
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 09b8164ee1..0df2c36f7d 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -37,7 +37,7 @@ from airflow.decorators import task
 from airflow.exceptions import AirflowException
 from airflow.models import DagBag, DagModel, DagRun
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _run_trigger
+from airflow.models.dag import _run_inline_trigger
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.triggers.base import TriggerEvent
 from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
@@ -878,15 +878,15 @@ class TestCliDags:
         dag_command.dag_test(cli_args)
         assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs
 
-    def test_dag_test_run_trigger(self, dag_maker):
+    def test_dag_test_run_inline_trigger(self, dag_maker):
         now = timezone.utcnow()
         trigger = DateTimeTrigger(moment=now)
-        e = _run_trigger(trigger)
+        e = _run_inline_trigger(trigger)
         assert isinstance(e, TriggerEvent)
         assert e.payload == now
 
     def test_dag_test_no_triggerer_running(self, dag_maker):
-        with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
+        with mock.patch("airflow.models.dag._run_inline_trigger", wraps=_run_inline_trigger) as mock_run:
             with dag_maker() as dag:
 
                 @task
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 00edf54e83..24c6c950d2 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -385,6 +385,27 @@ class TestCliTasks:
         assert "foo=bar" in output
         assert "AIRFLOW_TEST_MODE=True" in output
 
+    @pytest.mark.asyncio
+    @mock.patch("airflow.triggers.file.os.path.getmtime", return_value=0)
+    @mock.patch("airflow.triggers.file.glob", return_value=["/tmp/test"])
+    @mock.patch("airflow.triggers.file.os.path.isfile", return_value=True)
+    @mock.patch("airflow.sensors.filesystem.FileSensor.poke", return_value=False)
+    def test_cli_test_with_deferrable_operator(self, mock_pock, mock_is_file, mock_glob, mock_getmtime):
+        with redirect_stdout(StringIO()) as stdout:
+            task_command.task_test(
+                self.parser.parse_args(
+                    [
+                        "tasks",
+                        "test",
+                        "example_sensors",
+                        "wait_for_file_async",
+                        DEFAULT_DATE.isoformat(),
+                    ]
+                )
+            )
+        output = stdout.getvalue()
+        assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output
+
     @pytest.mark.parametrize(
         "option",
         [
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 18e1dd7fb9..1cc69062e4 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -266,7 +266,7 @@ def test_trigger_lifecycle(session):
 class TestTriggerRunner:
     @pytest.mark.asyncio
     @patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
-    async def test_run_trigger_canceled(self, session) -> None:
+    async def test_run_inline_trigger_canceled(self, session) -> None:
         trigger_runner = TriggerRunner()
         trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
         mock_trigger = MagicMock()
@@ -278,7 +278,7 @@ class TestTriggerRunner:
 
     @pytest.mark.asyncio
     @patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
-    async def test_run_trigger_timeout(self, session, caplog) -> None:
+    async def test_run_inline_trigger_timeout(self, session, caplog) -> None:
         trigger_runner = TriggerRunner()
         trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
         mock_trigger = MagicMock()

Reply via email to