This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 80840cb5734 Add a fixture to easily replace `ti.run` usage (#48439)
80840cb5734 is described below
commit 80840cb5734630f95d831dd10ae103c3df3a4b2f
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Mar 27 18:53:42 2025 +0530
Add a fixture to easily replace `ti.run` usage (#48439)
As we are replacing BaseOperator usage from Core to Task SDK, we are
running into several issues, one of the most common ones being over-use of
`task.run()`.
While some cases can be easily replaced by `task.execute()`, others need
execution of the tasks, sharing of XComs in between, checking task state,
correct exceptions, etc.
To make this easier I have added `run_task` fixture which I have been using
in https://github.com/apache/airflow/pull/48244 and it has worked out well.
Example:
---
devel-common/src/tests_common/pytest_plugin.py | 427 +++++++++++++++++++++++++
1 file changed, 427 insertions(+)
diff --git a/devel-common/src/tests_common/pytest_plugin.py
b/devel-common/src/tests_common/pytest_plugin.py
index ad885d176c7..7d04f4fe221 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -35,6 +35,8 @@ import pytest
import time_machine
if TYPE_CHECKING:
+ from uuid import UUID
+
from itsdangerous import URLSafeSerializer
from sqlalchemy.orm import Session
@@ -43,6 +45,10 @@ if TYPE_CHECKING:
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.taskinstance import TaskInstance
from airflow.providers.standard.operators.empty import EmptyOperator
+ from airflow.sdk.api.datamodels._generated import IntermediateTIState,
TerminalTIState
+ from airflow.sdk.definitions.baseoperator import BaseOperator as
TaskSDKBaseOperator
+ from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
from airflow.timetables.base import DataInterval
from airflow.typing_compat import Self
from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1872,3 +1878,424 @@ def mock_supervisor_comms():
"airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
) as supervisor_comms:
yield supervisor_comms
+
+
@pytest.fixture
def mocked_parse(spy_agency):
    """
    Fixture to set up an inline DAG and use it in a stubbed `parse` function.

    Use this fixture if you want to isolate and test `parse` or `run` logic without having to
    define a DAG file. In most cases, you should use `create_runtime_ti` fixture instead where
    you can directly pass an operator compared to lower level AIP-72 constructs like
    `StartupDetails`.

    This fixture returns a helper function `set_dag` that:
    1. Creates an inline DAG with the given `dag_id` and `task` (limited to one task)
    2. Constructs a `RuntimeTaskInstance` based on the provided `StartupDetails` and task.
    3. Stubs the `parse` function using `spy_agency`, to return the mocked `RuntimeTaskInstance`.

    After adding the fixture in your test function signature, you can use it like this ::

        mocked_parse(
            StartupDetails(
                ti=TaskInstance(
                    id=uuid7(), task_id="hello", dag_id="super_basic_run", run_id="c", try_number=1
                ),
                dag_rel_path="",
                bundle_info=BundleInfo(name="anything", version="any"),
                requests_fd=0,
                ti_context=ti_context,
                start_date=timezone.utcnow(),
            ),
            "example_dag_id",
            CustomOperator(task_id="hello"),
        )
    """

    def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) -> RuntimeTaskInstance:
        """Create/attach an inline DAG for ``task`` and stub ``parse`` to return the built RTI."""
        from airflow.sdk.definitions.dag import DAG
        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, parse
        from airflow.utils import timezone

        if not task.has_dag():
            # Build a single-task inline DAG and re-fetch the task so it is the
            # DAG-registered copy.
            dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
            task.dag = dag  # type: ignore[assignment]
            task = dag.task_dict[task.task_id]
        else:
            dag = task.dag
        if what.ti_context.dag_run.conf:
            # Surface the dag_run conf as params so templating sees it.
            dag.params = what.ti_context.dag_run.conf  # type: ignore[assignment]
        ti = RuntimeTaskInstance.model_construct(
            **what.ti.model_dump(exclude_unset=True),
            task=task,
            _ti_context_from_server=what.ti_context,
            max_tries=what.ti_context.max_tries,
            start_date=what.start_date,
        )
        # Re-stub `parse` if it was already spied on by an earlier call in the same test.
        if hasattr(parse, "spy"):
            spy_agency.unspy(parse)
        spy_agency.spy_on(parse, call_fake=lambda _: ti)
        return ti

    return set_dag
+
+
class _XComHelperProtocol(Protocol):
    """Structural type of the ``xcom`` helper exposed by the ``run_task`` fixture."""

    def get(
        self,
        key: str,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
    ) -> Any:
        """Return the XCom value for ``key``; unset identifiers default to the last-run TI."""
        ...

    def assert_pushed(
        self,
        key: str,
        value: Any,
        task_id: str | None = None,
        dag_id: str | None = None,
        run_id: str | None = None,
        map_index: int | None = None,
        **kwargs,
    ) -> None:
        """Assert that ``XCom.set`` was called with ``key``/``value`` (and the given identifiers)."""
        ...

    def clear(self):
        """Remove the XCom spies installed by the fixture."""
        ...
+
+
class RunTaskCallable(Protocol):
    """Protocol for better type hints for the fixture `run_task`."""

    @property
    def state(self) -> IntermediateTIState | TerminalTIState:
        """State the task finished in on the last ``run_task(...)`` call."""
        ...

    @property
    def msg(self) -> ToSupervisor | None:
        """Message (if any) the task runner produced for the supervisor."""
        ...

    @property
    def error(self) -> BaseException | None:
        """Exception raised by the task, or ``None`` on success."""
        ...

    # XCom helper bound to the most recently run task instance.
    xcom: _XComHelperProtocol

    def __call__(
        self,
        task: BaseOperator,
        dag_id: str = ...,
        run_id: str = ...,
        logical_date: datetime | None = None,
        start_date: datetime | None = None,
        run_type: str = ...,
        try_number: int = ...,
        map_index: int | None = ...,
        ti_id: UUID | None = None,
        max_tries: int | None = None,
        context_update: dict[str, Any] | None = None,
    ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]:
        """Run ``task`` and return ``(state, msg, error)``."""
        ...
+
+
@pytest.fixture
def create_runtime_ti(mocked_parse):
    """
    Fixture to create a Runtime TaskInstance for testing purposes without defining a dag file.

    It mimics the behavior of the `parse` function by creating a `RuntimeTaskInstance` based
    on the provided `StartupDetails` (formed from arguments) and task. This allows you to
    test the logic of a task without having to define a DAG file, parse it, get context from
    the server, etc.

    Example usage: ::

        def test_custom_task_instance(create_runtime_ti):
            class MyTaskOperator(BaseOperator):
                def execute(self, context):
                    assert context["dag_run"].run_id == "test_run"

            task = MyTaskOperator(task_id="test_task")
            ti = create_runtime_ti(task)
            # Further test logic...
    """
    from uuid6 import uuid7

    from airflow.sdk.api.datamodels._generated import TaskInstance
    from airflow.sdk.definitions.dag import DAG
    from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
    from airflow.utils import timezone

    def _make_runtime_ti(
        task: BaseOperator,
        dag_id: str = "test_dag",
        run_id: str = "test_run",
        logical_date: str | datetime = "2024-12-01T01:00:00Z",
        start_date: str | datetime = "2024-12-01T01:00:00Z",
        run_type: str = "manual",
        try_number: int = 1,
        map_index: int | None = -1,
        upstream_map_indexes: dict[str, int] | None = None,
        task_reschedule_count: int = 0,
        ti_id: UUID | None = None,
        conf: dict[str, Any] | None = None,
        should_retry: bool | None = None,
        max_tries: int | None = None,
    ) -> RuntimeTaskInstance:
        from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext

        ti_id = ti_id or uuid7()

        # Attach the task to a fresh inline DAG when it does not already belong to one,
        # and use the DAG-registered copy from then on.
        if not task.has_dag():
            inline_dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
            task.dag = inline_dag  # type: ignore[assignment]
            task = inline_dag.task_dict[task.task_id]

        timetable = task.dag.timetable
        if timetable:
            data_interval_start, data_interval_end = timetable.infer_manual_data_interval(
                run_after=logical_date  # type: ignore
            )
        else:
            data_interval_start = data_interval_end = None

        dag_id = task.dag.dag_id
        task_retries = task.retries or 0
        run_after = data_interval_end or logical_date or timezone.utcnow()

        effective_max_tries = task_retries if max_tries is None else max_tries
        effective_should_retry = (
            should_retry if should_retry is not None else try_number <= task_retries
        )

        dag_run = DagRun(
            dag_id=dag_id,
            run_id=run_id,
            logical_date=logical_date,  # type: ignore
            data_interval_start=data_interval_start,
            data_interval_end=data_interval_end,
            start_date=start_date,  # type: ignore
            run_type=run_type,  # type: ignore
            run_after=run_after,  # type: ignore
            conf=conf,
        )
        ti_context = TIRunContext(
            dag_run=dag_run,
            task_reschedule_count=task_reschedule_count,
            max_tries=effective_max_tries,
            should_retry=effective_should_retry,
        )

        if upstream_map_indexes is not None:
            ti_context.upstream_map_indexes = upstream_map_indexes

        startup_details = StartupDetails(
            ti=TaskInstance(
                id=ti_id,
                task_id=task.task_id,
                dag_id=dag_id,
                run_id=run_id,
                try_number=try_number,
                map_index=map_index,
            ),
            dag_rel_path="",
            bundle_info=BundleInfo(name="anything", version="any"),
            requests_fd=0,
            ti_context=ti_context,
            start_date=start_date,  # type: ignore
        )

        # Delegate to mocked_parse so `parse` is stubbed to hand back this RTI.
        return mocked_parse(startup_details, dag_id, task)

    return _make_runtime_ti
+
+
@pytest.fixture
def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> RunTaskCallable:
    """
    Fixture to run a task without defining a dag file.

    This fixture builds on top of create_runtime_ti to provide a convenient way to execute
    tasks and get their results.

    The fixture provides:
    - run_task.state - Get the task state
    - run_task.msg - Get the task message
    - run_task.error - Get the task error
    - run_task.xcom.get(key) - Get an XCom value
    - run_task.xcom.assert_pushed(key, value, ...) - Assert an XCom was pushed

    Example usage: ::

        def test_custom_task(run_task):
            class MyTaskOperator(BaseOperator):
                def execute(self, context):
                    return "hello"

            task = MyTaskOperator(task_id="test_task")
            run_task(task)
            assert run_task.state == TerminalTIState.SUCCESS
            assert run_task.error is None
    """
    import structlog

    from airflow.sdk.execution_time.task_runner import run
    from airflow.sdk.execution_time.xcom import XCom
    from airflow.utils import timezone

    # Set up spies once at fixture level; drop stale spies left behind by earlier tests.
    if hasattr(XCom.set, "spy"):
        spy_agency.unspy(XCom.set)
    if hasattr(XCom.get_one, "spy"):
        spy_agency.unspy(XCom.get_one)
    spy_agency.spy_on(XCom.set, call_original=True)
    spy_agency.spy_on(
        XCom.get_one, call_fake=lambda cls, *args, **kwargs: _get_one_from_set_calls(*args, **kwargs)
    )

    def _get_one_from_set_calls(*args, **kwargs) -> Any | None:
        """Get the most recent value from XCom.set calls that matches the criteria."""
        key = kwargs.get("key")
        task_id = kwargs.get("task_id")
        dag_id = kwargs.get("dag_id")
        run_id = kwargs.get("run_id")
        # BUGFIX: explicit None check. `kwargs.get("map_index") or -1` silently turned a
        # valid map_index of 0 into -1, breaking lookups for the first mapped task.
        map_index = kwargs.get("map_index")
        if map_index is None:
            map_index = -1

        for call in reversed(XCom.set.calls):
            if (
                call.kwargs.get("task_id") == task_id
                and call.kwargs.get("dag_id") == dag_id
                and call.kwargs.get("run_id") == run_id
                and call.kwargs.get("map_index") == map_index
            ):
                # XCom.set is called as XCom.set(key, value, ...); index instead of a
                # 2-tuple unpack so extra positional args cannot raise ValueError.
                if call.args and len(call.args) >= 2:
                    call_key, value = call.args[0], call.args[1]
                    if call_key == key:
                        return value
        return None

    class XComHelper:
        """XCom accessor bound to the most recently run task instance."""

        def __init__(self):
            self._ti = None  # set by RunTaskWithXCom.__call__ before the task runs

        def get(
            self,
            key: str,
            task_id: str | None = None,
            dag_id: str | None = None,
            run_id: str | None = None,
            map_index: int | None = None,
        ) -> Any:
            """Return the XCom value for ``key``; unset identifiers default to the last-run TI."""
            # Use task instance values as defaults
            task_id = task_id or self._ti.task_id
            dag_id = dag_id or self._ti.dag_id
            run_id = run_id or self._ti.run_id
            map_index = map_index if map_index is not None else self._ti.map_index

            return XCom.get_one(
                key=key,
                task_id=task_id,
                dag_id=dag_id,
                run_id=run_id,
                map_index=map_index,
            )

        def assert_pushed(
            self,
            key: str,
            value: Any,
            task_id: str | None = None,
            dag_id: str | None = None,
            run_id: str | None = None,
            map_index: int | None = None,
            **kwargs,
        ):
            """Assert that an XCom was pushed with the given key and value."""
            task_id = task_id or self._ti.task_id
            dag_id = dag_id or self._ti.dag_id
            run_id = run_id or self._ti.run_id
            map_index = map_index if map_index is not None else self._ti.map_index

            spy_agency.assert_spy_called_with(
                XCom.set,
                key,
                value,
                task_id=task_id,
                dag_id=dag_id,
                run_id=run_id,
                map_index=map_index,
                **kwargs,
            )

        def clear(self):
            """Remove the XCom spies so set/get_one calls are no longer recorded or faked."""
            if hasattr(XCom.set, "spy"):
                spy_agency.unspy(XCom.set)
            if hasattr(XCom.get_one, "spy"):
                spy_agency.unspy(XCom.get_one)

    class RunTaskWithXCom:
        """Callable implementing the RunTaskCallable protocol; remembers the last run's result."""

        def __init__(self, create_runtime_ti):
            self.create_runtime_ti = create_runtime_ti
            self.xcom = XComHelper()
            self._state = None
            self._msg = None
            self._error = None

        @property
        def state(self) -> IntermediateTIState | TerminalTIState:
            """Get the task state."""
            return self._state

        @property
        def msg(self) -> ToSupervisor | None:
            """Get the task message to send to supervisor."""
            return self._msg

        @property
        def error(self) -> BaseException | None:
            """Get the error message if there was any."""
            return self._error

        def __call__(
            self,
            task: BaseOperator,
            dag_id: str = "test_dag",
            run_id: str = "test_run",
            logical_date: datetime | None = None,
            start_date: datetime | None = None,
            run_type: str = "manual",
            try_number: int = 1,
            map_index: int | None = -1,
            ti_id: UUID | None = None,
            max_tries: int | None = None,
            context_update: dict[str, Any] | None = None,
        ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, BaseException | None]:
            """Build a runtime TI for ``task``, run it, and return ``(state, msg, error)``."""
            now = timezone.utcnow()
            if logical_date is None:
                logical_date = now

            if start_date is None:
                start_date = now

            ti = self.create_runtime_ti(
                task=task,
                dag_id=dag_id,
                run_id=run_id,
                logical_date=logical_date,
                start_date=start_date,
                run_type=run_type,
                try_number=try_number,
                map_index=map_index,
                ti_id=ti_id,
                max_tries=max_tries,
            )

            context = ti.get_template_context()
            if context_update:
                context.update(context_update)
            log = structlog.get_logger(logger_name="task")

            # Store the task instance for XCom operations
            self.xcom._ti = ti

            # Run the task
            state, msg, error = run(ti, context, log)
            self._state = state
            self._msg = msg
            self._error = error

            return state, msg, error

    return RunTaskWithXCom(create_runtime_ti)
+
+
@pytest.fixture
def mock_xcom_backend():
    """Yield a mock patched in as the task runner's XCom backend for the test's duration."""
    with mock.patch("airflow.sdk.execution_time.task_runner.XCom", create=True) as xcom_backend:
        yield xcom_backend