This is an automated email from the ASF dual-hosted git repository.
phanikumv pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d3b5b4be8 Enable "airflow tasks test" to run deferrable operator
(#37542)
7d3b5b4be8 is described below
commit 7d3b5b4be8a8db675fedeaabc8151fdb5770d38a
Author: Wei Lee <[email protected]>
AuthorDate: Wed Feb 21 16:19:34 2024 +0800
Enable "airflow tasks test" to run deferrable operator (#37542)
---
airflow/cli/commands/task_command.py | 25 +++++++++++++++++++++----
airflow/models/dag.py | 8 ++++----
airflow/models/taskinstance.py | 17 +++++++++++++----
tests/cli/commands/test_dag_command.py | 8 ++++----
tests/cli/commands/test_task_command.py | 21 +++++++++++++++++++++
tests/jobs/test_triggerer_job.py | 4 ++--
6 files changed, 65 insertions(+), 18 deletions(-)
diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py
index 5c7c47d69b..8416789f53 100644
--- a/airflow/cli/commands/task_command.py
+++ b/airflow/cli/commands/task_command.py
@@ -18,6 +18,7 @@
"""Task sub-commands."""
from __future__ import annotations
+import functools
import importlib
import json
import logging
@@ -34,13 +35,13 @@ from sqlalchemy import select
from airflow import settings
from airflow.cli.simple_table import AirflowConsole
from airflow.configuration import conf
-from airflow.exceptions import AirflowException, DagRunNotFound, TaskInstanceNotFound
+from airflow.exceptions import AirflowException, DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.executors.executor_loader import ExecutorLoader
from airflow.jobs.job import Job, run_job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.listeners.listener import get_listener_manager
from airflow.models import DagPickle, TaskInstance
-from airflow.models.dag import DAG
+from airflow.models.dag import DAG, _run_inline_trigger
from airflow.models.dagrun import DagRun
from airflow.models.operator import needs_expansion
from airflow.models.param import ParamsDict
@@ -588,7 +589,8 @@ def task_states_for_dag_run(args, session: Session = NEW_SESSION) -> None:
@cli_utils.action_cli(check_db=False)
-def task_test(args, dag: DAG | None = None) -> None:
+@provide_session
+def task_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> None:
"""Test task for a given dag_id."""
# We want to log output from operators etc to show up here. Normally
# airflow.task would redirect to a file, but here we want it to propagate
@@ -632,7 +634,22 @@ def task_test(args, dag: DAG | None = None) -> None:
if args.dry_run:
ti.dry_run()
else:
- ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True)
+ ti.run(ignore_task_deps=True, ignore_ti_state=True, test_mode=True, raise_on_defer=True)
+ except TaskDeferred as defer:
+ ti.defer_task(defer=defer, session=session)
+ log.info("[TASK TEST] running trigger in line")
+
+ event = _run_inline_trigger(defer.trigger)
+ ti.next_method = defer.method_name
+ ti.next_kwargs = {"event": event.payload} if event else defer.kwargs
+
+ execute_callable = getattr(task, ti.next_method)
+ if ti.next_kwargs:
+ execute_callable = functools.partial(execute_callable, **ti.next_kwargs)
+ context = ti.get_template_context(ignore_param_exceptions=False)
+ execute_callable(context)
+
+ log.info("[TASK TEST] Trigger completed")
except Exception:
if args.post_mortem:
debugger = _guess_debugger()
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 164e83a3f5..dd43568657 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -4057,12 +4057,12 @@ class DagContext:
return None
-def _run_trigger(trigger):
- async def _run_trigger_main():
+def _run_inline_trigger(trigger):
+ async def _run_inline_trigger_main():
async for event in trigger.run():
return event
- return asyncio.run(_run_trigger_main())
+ return asyncio.run(_run_inline_trigger_main())
def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Session):
@@ -4083,7 +4083,7 @@ def _run_task(*, ti: TaskInstance, inline_trigger: bool = False, session: Sessio
break
except TaskDeferred as e:
log.info("[DAG TEST] running trigger in line")
- event = _run_trigger(e.trigger)
+ event = _run_inline_trigger(e.trigger)
ti.next_method = e.method_name
ti.next_kwargs = {"event": event.payload} if event else e.kwargs
log.info("[DAG TEST] Trigger completed")
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4720361504..5d026a0667 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2378,7 +2378,7 @@ class TaskInstance(Base, LoggingMixin):
# a trigger.
if raise_on_defer:
raise
- self._defer_task(defer=defer, session=session)
+ self.defer_task(defer=defer, session=session)
self.log.info(
"Pausing task as DEFERRED. dag_id=%s, task_id=%s, execution_date=%s, start_date=%s",
self.dag_id,
@@ -2565,8 +2565,11 @@ class TaskInstance(Base, LoggingMixin):
return _execute_task(self, context, task_orig)
@provide_session
- def _defer_task(self, session: Session, defer: TaskDeferred) -> None:
- """Mark the task as deferred and sets up the trigger that is needed to resume it."""
+ def defer_task(self, session: Session, defer: TaskDeferred) -> None:
+ """Mark the task as deferred and sets up the trigger that is needed to resume it.
+
+ :meta: private
+ """
from airflow.models.trigger import Trigger
# First, make the trigger entry
@@ -2625,6 +2628,7 @@ class TaskInstance(Base, LoggingMixin):
job_id: str | None = None,
pool: str | None = None,
session: Session = NEW_SESSION,
+ raise_on_defer: bool = False,
) -> None:
"""Run TaskInstance."""
res = self.check_and_change_state_before_execution(
@@ -2644,7 +2648,12 @@ class TaskInstance(Base, LoggingMixin):
return
self._run_raw_task(
- mark_success=mark_success, test_mode=test_mode, job_id=job_id, pool=pool, session=session
+ mark_success=mark_success,
+ test_mode=test_mode,
+ job_id=job_id,
+ pool=pool,
+ session=session,
+ raise_on_defer=raise_on_defer,
)
def dry_run(self) -> None:
diff --git a/tests/cli/commands/test_dag_command.py b/tests/cli/commands/test_dag_command.py
index 09b8164ee1..0df2c36f7d 100644
--- a/tests/cli/commands/test_dag_command.py
+++ b/tests/cli/commands/test_dag_command.py
@@ -37,7 +37,7 @@ from airflow.decorators import task
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _run_trigger
+from airflow.models.dag import _run_inline_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.triggers.base import TriggerEvent
from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
@@ -878,15 +878,15 @@ class TestCliDags:
dag_command.dag_test(cli_args)
assert "data_interval" in mock__get_or_create_dagrun.call_args.kwargs
- def test_dag_test_run_trigger(self, dag_maker):
+ def test_dag_test_run_inline_trigger(self, dag_maker):
now = timezone.utcnow()
trigger = DateTimeTrigger(moment=now)
- e = _run_trigger(trigger)
+ e = _run_inline_trigger(trigger)
assert isinstance(e, TriggerEvent)
assert e.payload == now
def test_dag_test_no_triggerer_running(self, dag_maker):
- with mock.patch("airflow.models.dag._run_trigger", wraps=_run_trigger) as mock_run:
+ with mock.patch("airflow.models.dag._run_inline_trigger", wraps=_run_inline_trigger) as mock_run:
with dag_maker() as dag:
@task
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 00edf54e83..24c6c950d2 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -385,6 +385,27 @@ class TestCliTasks:
assert "foo=bar" in output
assert "AIRFLOW_TEST_MODE=True" in output
+ @pytest.mark.asyncio
+ @mock.patch("airflow.triggers.file.os.path.getmtime", return_value=0)
+ @mock.patch("airflow.triggers.file.glob", return_value=["/tmp/test"])
+ @mock.patch("airflow.triggers.file.os.path.isfile", return_value=True)
+ @mock.patch("airflow.sensors.filesystem.FileSensor.poke", return_value=False)
+ def test_cli_test_with_deferrable_operator(self, mock_pock, mock_is_file, mock_glob, mock_getmtime):
+ with redirect_stdout(StringIO()) as stdout:
+ task_command.task_test(
+ self.parser.parse_args(
+ [
+ "tasks",
+ "test",
+ "example_sensors",
+ "wait_for_file_async",
+ DEFAULT_DATE.isoformat(),
+ ]
+ )
+ )
+ output = stdout.getvalue()
+ assert "wait_for_file_async completed successfully as /tmp/temporary_file_for_testing found" in output
+
@pytest.mark.parametrize(
"option",
[
diff --git a/tests/jobs/test_triggerer_job.py b/tests/jobs/test_triggerer_job.py
index 18e1dd7fb9..1cc69062e4 100644
--- a/tests/jobs/test_triggerer_job.py
+++ b/tests/jobs/test_triggerer_job.py
@@ -266,7 +266,7 @@ def test_trigger_lifecycle(session):
class TestTriggerRunner:
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
- async def test_run_trigger_canceled(self, session) -> None:
+ async def test_run_inline_trigger_canceled(self, session) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()
@@ -278,7 +278,7 @@ class TestTriggerRunner:
@pytest.mark.asyncio
@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.set_individual_trigger_logging")
- async def test_run_trigger_timeout(self, session, caplog) -> None:
+ async def test_run_inline_trigger_timeout(self, session, caplog) -> None:
trigger_runner = TriggerRunner()
trigger_runner.triggers = {1: {"task": MagicMock(), "name": "mock_name", "events": 0}}
mock_trigger = MagicMock()