This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 80840cb5734 Add a fixture to easily replace `ti.run` usage (#48439)
80840cb5734 is described below

commit 80840cb5734630f95d831dd10ae103c3df3a4b2f
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Mar 27 18:53:42 2025 +0530

    Add a fixture to easily replace `ti.run` usage (#48439)
    
    As we are moving BaseOperator usage from Core to the Task SDK, we are running into several issues; one of the most common is the over-use of `task.run()`.
    
    While some cases can easily be replaced by `task.execute()`, others need the task to actually be executed: sharing XComs between tasks, checking the task state, asserting the correct exception, etc.
    
    To make this easier, I have added a `run_task` fixture, which I have been using in https://github.com/apache/airflow/pull/48244 and which has worked out well.
    
    Example:
---
 devel-common/src/tests_common/pytest_plugin.py | 427 +++++++++++++++++++++++++
 1 file changed, 427 insertions(+)

diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index ad885d176c7..7d04f4fe221 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -35,6 +35,8 @@ import pytest
 import time_machine
 
 if TYPE_CHECKING:
+    from uuid import UUID
+
     from itsdangerous import URLSafeSerializer
     from sqlalchemy.orm import Session
 
@@ -43,6 +45,10 @@ if TYPE_CHECKING:
     from airflow.models.dagrun import DagRun, DagRunType
     from airflow.models.taskinstance import TaskInstance
     from airflow.providers.standard.operators.empty import EmptyOperator
+    from airflow.sdk.api.datamodels._generated import IntermediateTIState, 
TerminalTIState
+    from airflow.sdk.definitions.baseoperator import BaseOperator as 
TaskSDKBaseOperator
+    from airflow.sdk.execution_time.comms import StartupDetails, ToSupervisor
+    from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
     from airflow.timetables.base import DataInterval
     from airflow.typing_compat import Self
     from airflow.utils.state import DagRunState, TaskInstanceState
@@ -1872,3 +1878,424 @@ def mock_supervisor_comms():
         "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
     ) as supervisor_comms:
         yield supervisor_comms
+
+
[email protected]
+def mocked_parse(spy_agency):
+    """
+    Fixture to set up an inline DAG and use it in a stubbed `parse` function.
+
+    Use this fixture if you want to isolate and test `parse` or `run` logic 
without having to define a DAG file.
+    In most cases, you should use `create_runtime_ti` fixture instead where 
you can directly pass an operator
+    compared to lower level AIP-72 constructs like `StartupDetails`.
+
+    This fixture returns a helper function `set_dag` that:
+    1. Creates an in line DAG with the given `dag_id` and `task` (limited to 
one task)
+    2. Constructs a `RuntimeTaskInstance` based on the provided 
`StartupDetails` and task.
+    3. Stubs the `parse` function using `spy_agency`, to return the mocked 
`RuntimeTaskInstance`.
+
+    After adding the fixture in your test function signature, you can use it 
like this ::
+
+            mocked_parse(
+                StartupDetails(
+                    ti=TaskInstance(id=uuid7(), task_id="hello", 
dag_id="super_basic_run", run_id="c", try_number=1),
+                    file="",
+                    requests_fd=0,
+                ),
+                "example_dag_id",
+                CustomOperator(task_id="hello"),
+            )
+    """
+
+    def set_dag(what: StartupDetails, dag_id: str, task: TaskSDKBaseOperator) 
-> RuntimeTaskInstance:
+        from airflow.sdk.definitions.dag import DAG
+        from airflow.sdk.execution_time.task_runner import 
RuntimeTaskInstance, parse
+        from airflow.utils import timezone
+
+        if not task.has_dag():
+            dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+            task.dag = dag  # type: ignore[assignment]
+            task = dag.task_dict[task.task_id]
+        else:
+            dag = task.dag
+        if what.ti_context.dag_run.conf:
+            dag.params = what.ti_context.dag_run.conf  # type: 
ignore[assignment]
+        ti = RuntimeTaskInstance.model_construct(
+            **what.ti.model_dump(exclude_unset=True),
+            task=task,
+            _ti_context_from_server=what.ti_context,
+            max_tries=what.ti_context.max_tries,
+            start_date=what.start_date,
+        )
+        if hasattr(parse, "spy"):
+            spy_agency.unspy(parse)
+        spy_agency.spy_on(parse, call_fake=lambda _: ti)
+        return ti
+
+    return set_dag
+
+
+class _XComHelperProtocol(Protocol):
+    def get(
+        self,
+        key: str,
+        task_id: str | None = None,
+        dag_id: str | None = None,
+        run_id: str | None = None,
+        map_index: int | None = None,
+    ) -> Any: ...
+
+    def assert_pushed(
+        self,
+        key: str,
+        value: Any,
+        task_id: str | None = None,
+        dag_id: str | None = None,
+        run_id: str | None = None,
+        map_index: int | None = None,
+        **kwargs,
+    ) -> None: ...
+
+    def clear(self): ...
+
+
+class RunTaskCallable(Protocol):
+    """Protocol for better type hints for the fixture `run_task`."""
+
+    @property
+    def state(self) -> IntermediateTIState | TerminalTIState: ...
+
+    @property
+    def msg(self) -> ToSupervisor | None: ...
+
+    @property
+    def error(self) -> BaseException | None: ...
+
+    xcom: _XComHelperProtocol
+
+    def __call__(
+        self,
+        task: BaseOperator,
+        dag_id: str = ...,
+        run_id: str = ...,
+        logical_date: datetime | None = None,
+        start_date: datetime | None = None,
+        run_type: str = ...,
+        try_number: int = ...,
+        map_index: int | None = ...,
+        ti_id: UUID | None = None,
+        max_tries: int | None = None,
+        context_update: dict[str, Any] | None = None,
+    ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, 
BaseException | None]: ...
+
+
[email protected]
+def create_runtime_ti(mocked_parse):
+    """
+    Fixture to create a Runtime TaskInstance for testing purposes without 
defining a dag file.
+
+    It mimics the behavior of the `parse` function by creating a 
`RuntimeTaskInstance` based on the provided
+    `StartupDetails` (formed from arguments) and task. This allows you to test 
the logic of a task without
+    having to define a DAG file, parse it, get context from the server, etc.
+
+    Example usage: ::
+
+        def test_custom_task_instance(create_runtime_ti):
+            class MyTaskOperator(BaseOperator):
+                def execute(self, context):
+                    assert context["dag_run"].run_id == "test_run"
+
+            task = MyTaskOperator(task_id="test_task")
+            ti = create_runtime_ti(task)
+            # Further test logic...
+    """
+    from uuid6 import uuid7
+
+    from airflow.sdk.api.datamodels._generated import TaskInstance
+    from airflow.sdk.definitions.dag import DAG
+    from airflow.sdk.execution_time.comms import BundleInfo, StartupDetails
+    from airflow.utils import timezone
+
+    def _create_task_instance(
+        task: BaseOperator,
+        dag_id: str = "test_dag",
+        run_id: str = "test_run",
+        logical_date: str | datetime = "2024-12-01T01:00:00Z",
+        start_date: str | datetime = "2024-12-01T01:00:00Z",
+        run_type: str = "manual",
+        try_number: int = 1,
+        map_index: int | None = -1,
+        upstream_map_indexes: dict[str, int] | None = None,
+        task_reschedule_count: int = 0,
+        ti_id: UUID | None = None,
+        conf: dict[str, Any] | None = None,
+        should_retry: bool | None = None,
+        max_tries: int | None = None,
+    ) -> RuntimeTaskInstance:
+        from airflow.sdk.api.datamodels._generated import DagRun, TIRunContext
+
+        if not ti_id:
+            ti_id = uuid7()
+
+        if not task.has_dag():
+            dag = DAG(dag_id=dag_id, start_date=timezone.datetime(2024, 12, 3))
+            task.dag = dag  # type: ignore[assignment]
+            task = dag.task_dict[task.task_id]
+
+        if task.dag.timetable:
+            data_interval_start, data_interval_end = 
task.dag.timetable.infer_manual_data_interval(
+                run_after=logical_date  # type: ignore
+            )
+        else:
+            data_interval_start = None
+            data_interval_end = None
+
+        dag_id = task.dag.dag_id
+        task_retries = task.retries or 0
+        run_after = data_interval_end or logical_date or timezone.utcnow()
+
+        ti_context = TIRunContext(
+            dag_run=DagRun(
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,  # type: ignore
+                data_interval_start=data_interval_start,
+                data_interval_end=data_interval_end,
+                start_date=start_date,  # type: ignore
+                run_type=run_type,  # type: ignore
+                run_after=run_after,  # type: ignore
+                conf=conf,
+            ),
+            task_reschedule_count=task_reschedule_count,
+            max_tries=task_retries if max_tries is None else max_tries,
+            should_retry=should_retry if should_retry is not None else 
try_number <= task_retries,
+        )
+
+        if upstream_map_indexes is not None:
+            ti_context.upstream_map_indexes = upstream_map_indexes
+
+        startup_details = StartupDetails(
+            ti=TaskInstance(
+                id=ti_id,
+                task_id=task.task_id,
+                dag_id=dag_id,
+                run_id=run_id,
+                try_number=try_number,
+                map_index=map_index,
+            ),
+            dag_rel_path="",
+            bundle_info=BundleInfo(name="anything", version="any"),
+            requests_fd=0,
+            ti_context=ti_context,
+            start_date=start_date,  # type: ignore
+        )
+
+        ti = mocked_parse(startup_details, dag_id, task)
+        return ti
+
+    return _create_task_instance
+
+
[email protected]
+def run_task(create_runtime_ti, mock_supervisor_comms, spy_agency) -> 
RunTaskCallable:
+    """
+    Fixture to run a task without defining a dag file.
+
+    This fixture builds on top of create_runtime_ti to provide a convenient 
way to execute tasks and get their results.
+
+    The fixture provides:
+    - run_task.state - Get the task state
+    - run_task.msg - Get the task message
+    - run_task.error - Get the task error
+    - run_task.xcom.get(key) - Get an XCom value
+    - run_task.xcom.assert_pushed(key, value, ...) - Assert an XCom was pushed
+
+    Example usage: ::
+
+        def test_custom_task(run_task):
+            class MyTaskOperator(BaseOperator):
+                def execute(self, context):
+                    return "hello"
+
+            task = MyTaskOperator(task_id="test_task")
+            run_task(task)
+            assert run_task.state == TerminalTIState.SUCCESS
+            assert run_task.error is None
+    """
+    import structlog
+
+    from airflow.sdk.execution_time.task_runner import run
+    from airflow.sdk.execution_time.xcom import XCom
+    from airflow.utils import timezone
+
+    # Set up spies once at fixture level
+    if hasattr(XCom.set, "spy"):
+        spy_agency.unspy(XCom.set)
+    if hasattr(XCom.get_one, "spy"):
+        spy_agency.unspy(XCom.get_one)
+    spy_agency.spy_on(XCom.set, call_original=True)
+    spy_agency.spy_on(
+        XCom.get_one, call_fake=lambda cls, *args, **kwargs: 
_get_one_from_set_calls(*args, **kwargs)
+    )
+
+    def _get_one_from_set_calls(*args, **kwargs) -> Any | None:
+        """Get the most recent value from XCom.set calls that matches the 
criteria."""
+        key = kwargs.get("key")
+        task_id = kwargs.get("task_id")
+        dag_id = kwargs.get("dag_id")
+        run_id = kwargs.get("run_id")
+        map_index = kwargs.get("map_index") or -1
+
+        for call in reversed(XCom.set.calls):
+            if (
+                call.kwargs.get("task_id") == task_id
+                and call.kwargs.get("dag_id") == dag_id
+                and call.kwargs.get("run_id") == run_id
+                and call.kwargs.get("map_index") == map_index
+            ):
+                if call.args and len(call.args) >= 2:
+                    call_key, value = call.args
+                    if call_key == key:
+                        return value
+        return None
+
+    class XComHelper:
+        def __init__(self):
+            self._ti = None
+
+        def get(
+            self,
+            key: str,
+            task_id: str | None = None,
+            dag_id: str | None = None,
+            run_id: str | None = None,
+            map_index: int | None = None,
+        ) -> Any:
+            # Use task instance values as defaults
+            task_id = task_id or self._ti.task_id
+            dag_id = dag_id or self._ti.dag_id
+            run_id = run_id or self._ti.run_id
+            map_index = map_index if map_index is not None else 
self._ti.map_index
+
+            return XCom.get_one(
+                key=key,
+                task_id=task_id,
+                dag_id=dag_id,
+                run_id=run_id,
+                map_index=map_index,
+            )
+
+        def assert_pushed(
+            self,
+            key: str,
+            value: Any,
+            task_id: str | None = None,
+            dag_id: str | None = None,
+            run_id: str | None = None,
+            map_index: int | None = None,
+            **kwargs,
+        ):
+            """Assert that an XCom was pushed with the given key and value."""
+            task_id = task_id or self._ti.task_id
+            dag_id = dag_id or self._ti.dag_id
+            run_id = run_id or self._ti.run_id
+            map_index = map_index if map_index is not None else 
self._ti.map_index
+
+            spy_agency.assert_spy_called_with(
+                XCom.set,
+                key,
+                value,
+                task_id=task_id,
+                dag_id=dag_id,
+                run_id=run_id,
+                map_index=map_index,
+                **kwargs,
+            )
+
+        def clear(self):
+            """Clear all XCom calls."""
+            if hasattr(XCom.set, "spy"):
+                spy_agency.unspy(XCom.set)
+            if hasattr(XCom.get_one, "spy"):
+                spy_agency.unspy(XCom.get_one)
+
+    class RunTaskWithXCom:
+        def __init__(self, create_runtime_ti):
+            self.create_runtime_ti = create_runtime_ti
+            self.xcom = XComHelper()
+            self._state = None
+            self._msg = None
+            self._error = None
+
+        @property
+        def state(self) -> IntermediateTIState | TerminalTIState:
+            """Get the task state."""
+            return self._state
+
+        @property
+        def msg(self) -> ToSupervisor | None:
+            """Get the task message to send to supervisor."""
+            return self._msg
+
+        @property
+        def error(self) -> BaseException | None:
+            """Get the error message if there was any."""
+            return self._error
+
+        def __call__(
+            self,
+            task: BaseOperator,
+            dag_id: str = "test_dag",
+            run_id: str = "test_run",
+            logical_date: datetime | None = None,
+            start_date: datetime | None = None,
+            run_type: str = "manual",
+            try_number: int = 1,
+            map_index: int | None = -1,
+            ti_id: UUID | None = None,
+            max_tries: int | None = None,
+            context_update: dict[str, Any] | None = None,
+        ) -> tuple[IntermediateTIState | TerminalTIState, ToSupervisor | None, 
BaseException | None]:
+            now = timezone.utcnow()
+            if logical_date is None:
+                logical_date = now
+
+            if start_date is None:
+                start_date = now
+
+            ti = self.create_runtime_ti(
+                task=task,
+                dag_id=dag_id,
+                run_id=run_id,
+                logical_date=logical_date,
+                start_date=start_date,
+                run_type=run_type,
+                try_number=try_number,
+                map_index=map_index,
+                ti_id=ti_id,
+                max_tries=max_tries,
+            )
+
+            context = ti.get_template_context()
+            if context_update:
+                context.update(context_update)
+            log = structlog.get_logger(logger_name="task")
+
+            # Store the task instance for XCom operations
+            self.xcom._ti = ti
+
+            # Run the task
+            state, msg, error = run(ti, context, log)
+            self._state = state
+            self._msg = msg
+            self._error = error
+
+            return state, msg, error
+
+    return RunTaskWithXCom(create_runtime_ti)
+
+
[email protected]
+def mock_xcom_backend():
+    with mock.patch("airflow.sdk.execution_time.task_runner.XCom", 
create=True) as xcom_backend:
+        yield xcom_backend

Reply via email to