This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ef2da6f1efd Port `dag.test` to Task SDK (#50300)
ef2da6f1efd is described below
commit ef2da6f1efd6606e424964dec2a184a3a8521e27
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri May 9 19:03:07 2025 +0530
Port `dag.test` to Task SDK (#50300)
closes https://github.com/apache/airflow/issues/45549
Key changes:
- Moves `dag.test` implementation to Task SDK, leveraging the existing
in-process execution infrastructure
- Adds `JWTBearerTIPathDep` for proper task instance path validation
- Updates `InProcessExecutionAPI` to support task instance validation
- Removes legacy `dag.test` implementation from DAG class
The changes ensure that `dag.test` uses the same execution path as regular
task execution.
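
For reviewers, a minimal sketch of the user-facing entrypoint this change keeps
intact: `dag.test()` retains its previous signature (minus the removed `session`
argument) while executing through the Task SDK. The DAG below is illustrative,
not part of this commit:

    # Illustrative only: one DagRun executed in-process, now routed via the Task SDK.
    from airflow.sdk import DAG, task

    with DAG(dag_id="example_dag_test") as dag:

        @task
        def hello():
            print("hello")

        hello()

    if __name__ == "__main__":
        dag.test()  # runs the single task locally; no scheduler or executor needed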
---
.../src/airflow/api_fastapi/execution_api/app.py | 7 +-
.../src/airflow/api_fastapi/execution_api/deps.py | 3 +
.../execution_api/routes/task_instances.py | 6 +-
.../src/airflow/cli/commands/dag_command.py | 1 -
.../src/airflow/cli/commands/task_command.py | 3 +-
.../src/airflow/dag_processing/processor.py | 6 +-
airflow-core/src/airflow/models/dag.py | 321 +--------------------
airflow-core/src/airflow/models/dagrun.py | 39 ++-
.../tests/unit/cli/commands/test_dag_command.py | 24 +-
airflow-core/tests/unit/models/test_dag.py | 14 +-
.../tests/unit/models/test_mappedoperator.py | 17 +-
task-sdk/src/airflow/sdk/definitions/dag.py | 273 ++++++++++++++++++
.../src/airflow/sdk/execution_time/supervisor.py | 196 ++++++++++++-
.../src/airflow/sdk/execution_time/task_runner.py | 4 +-
14 files changed, 539 insertions(+), 375 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index ef51da98279..691853f322a 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -225,7 +225,11 @@ class InProcessExecutionAPI:
def app(self):
if not self._app:
from airflow.api_fastapi.execution_api.app import create_task_execution_api_app
- from airflow.api_fastapi.execution_api.deps import JWTBearerDep, JWTRefresherDep
+ from airflow.api_fastapi.execution_api.deps import (
+ JWTBearerDep,
+ JWTBearerTIPathDep,
+ JWTRefresherDep,
+ )
from airflow.api_fastapi.execution_api.routes.connections import has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
@@ -235,6 +239,7 @@ class InProcessExecutionAPI:
async def always_allow(): ...
self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow
+ self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow
self._app.dependency_overrides[JWTRefresherDep.dependency] = always_allow
self._app.dependency_overrides[has_connection_access] = always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index 8106a7e81e3..c2161180dbb 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -96,6 +96,9 @@ class JWTBearer(HTTPBearer):
JWTBearerDep: TIToken = Depends(JWTBearer())
+# This checks that the UUID in the url matches the one in the token for us.
+JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
+
class JWTReissuer:
"""Re-issue JWTs to requests when they are about to run out."""
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 1dfc9bb10e0..00f149f1d10 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -25,7 +25,7 @@ from uuid import UUID
import structlog
from cadwyn import VersionedAPIRouter
-from fastapi import Body, Depends, HTTPException, Query, status
+from fastapi import Body, HTTPException, Query, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
@@ -50,7 +50,7 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TISuccessStatePayload,
TITerminalStatePayload,
)
-from airflow.api_fastapi.execution_api.deps import JWTBearer
+from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
from airflow.models.dagbag import DagBag
from airflow.models.dagrun import DagRun as DR
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
@@ -70,7 +70,7 @@ router = VersionedAPIRouter()
ti_id_router = VersionedAPIRouter(
dependencies=[
# This checks that the UUID in the url matches the one in the token for us.
- Depends(JWTBearer(path_param_name="task_instance_id")),
+ JWTBearerTIPathDep
]
)
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py b/airflow-core/src/airflow/cli/commands/dag_command.py
index b1151f34091..1b4b017e6c7 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -644,7 +644,6 @@ def dag_test(args, dag: DAG | None = None, session: Session = NEW_SESSION) -> No
run_conf=run_conf,
use_executor=use_executor,
mark_success_pattern=mark_success_pattern,
- session=session,
)
show_dagrun = args.show_dagrun
imgcat = args.imgcat_dagrun
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index 0a4e771315b..a3492a828c0 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -33,8 +33,9 @@ from airflow.cli.simple_table import AirflowConsole
from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string
from airflow.exceptions import DagRunNotFound, TaskDeferred, TaskInstanceNotFound
from airflow.models import TaskInstance
-from airflow.models.dag import DAG, _run_inline_trigger
+from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
+from airflow.sdk.definitions.dag import _run_inline_trigger
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.execution_time.secrets_masker import RedactedIO
from airflow.ti_deps.dep_context import DepContext
diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index ef3e5cb69ca..73d2c23c7f5 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -161,7 +161,11 @@ def _execute_dag_callbacks(dagbag: DagBag, request: DagCallbackRequest, log: Fil
callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
# TODO:We need a proper context object!
- context: Context = {} # type: ignore[assignment]
+ context: Context = { # type: ignore[assignment]
+ "dag": dag,
+ "run_id": request.run_id,
+ "reason": request.msg,
+ }
for callback in callbacks:
log.info(
diff --git a/airflow-core/src/airflow/models/dag.py b/airflow-core/src/airflow/models/dag.py
index 8024f417ef4..57c1050e3ce 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -17,20 +17,14 @@
# under the License.
from __future__ import annotations
-import asyncio
import copy
import functools
import logging
import re
-import sys
-import time
from collections import defaultdict
from collections.abc import Collection, Generator, Iterable, Sequence
-from contextlib import ExitStack
from datetime import datetime, timedelta
from functools import cache
-from pathlib import Path
-from re import Pattern
from typing import (
TYPE_CHECKING,
Any,
@@ -70,14 +64,12 @@ from sqlalchemy.sql import Select, expression
from airflow import settings, utils
from airflow.assets.evaluation import AssetEvaluator
-from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.configuration import conf as airflow_conf
from airflow.exceptions import (
AirflowException,
- TaskDeferred,
UnknownExecutorException,
)
from airflow.executors.executor_loader import ExecutorLoader
-from airflow.executors.workloads import BundleInfo
from airflow.models.asset import (
AssetDagRunQueue,
AssetModel,
@@ -95,9 +87,7 @@ from airflow.models.tasklog import LogTemplate
from airflow.sdk import TaskGroup
from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, BaseAsset
from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as task_sdk_dag_decorator
-from airflow.secrets.local_filesystem import LocalFilesystemBackend
from airflow.settings import json
-from airflow.stats import Stats
from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, Timetable
from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable
from airflow.timetables.simple import (
@@ -111,7 +101,7 @@ from airflow.utils.dag_cycle_tester import check_cycle
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
if TYPE_CHECKING:
@@ -121,7 +111,6 @@ if TYPE_CHECKING:
from airflow.models.dagbag import DagBag
from airflow.models.operator import Operator
- from airflow.sdk.definitions._internal.abstractoperator import TaskStateChangeCallback
from airflow.serialization.serialized_objects import MaybeSerializedDAG
from airflow.typing_compat import Literal
@@ -777,89 +766,6 @@ class DAG(TaskSDKDag, LoggingMixin):
"""Stringified DAGs and operators contain exactly these fields."""
return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}
- @staticmethod
- @provide_session
- def fetch_callback(
- dag: DAG,
- run_id: str,
- success: bool = True,
- reason: str | None = None,
- *,
- session: Session = NEW_SESSION,
- ) -> tuple[list[TaskStateChangeCallback], Context] | None:
- """
- Fetch the appropriate callbacks depending on the value of success.
-
- This method gets the context of a single TaskInstance part of this DagRun and returns it along
- the list of callbacks.
-
- :param dag: DAG object
- :param run_id: The DAG run ID
- :param success: Flag to specify if failure or success callback should be called
- :param reason: Completion reason
- :param session: Database session
- """
- callbacks = dag.on_success_callback if success else dag.on_failure_callback
- if callbacks:
- dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=run_id, session=session)
- callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
- tis = dagrun.get_task_instances(session=session)
- # tis from a dagrun may not be a part of dag.partial_subset,
- # since dag.partial_subset is a subset of the dag.
- # This ensures that we will only use the accessible TI
- # context for the callback.
- if dag.partial:
- tis = [ti for ti in tis if not ti.state == State.NONE]
- # filter out removed tasks
- tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
- ti = tis[-1] # get first TaskInstance of DagRun
- ti.task = dag.get_task(ti.task_id)
- context = ti.get_template_context(session=session)
- context["reason"] = reason
- return callbacks, context
- return None
-
- @provide_session
- def handle_callback(self, dagrun: DagRun, success=True, reason=None, session=NEW_SESSION):
- """
- Triggers on_failure_callback or on_success_callback as appropriate.
-
- This method gets the context of a single TaskInstance part of this DagRun
- and passes that to the callable along with a 'reason', primarily to
- differentiate DagRun failures.
-
- .. note: The logs end up in ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log``
-
- :param dagrun: DagRun object
- :param success: Flag to specify if failure or success callback should be called
- :param reason: Completion reason
- :param session: Database session
- """
- callbacks, context = DAG.fetch_callback(
- dag=self, run_id=dagrun.run_id, success=success, reason=reason, session=session
- ) or (None, None)
-
- DAG.execute_callback(callbacks, context, self.dag_id)
-
- @classmethod
- def execute_callback(cls, callbacks: list[Callable] | None, context: Context | None, dag_id: str):
- """
- Triggers the callbacks with the given context.
-
- :param callbacks: List of callbacks to call
- :param context: Context to pass to all callbacks
- :param dag_id: The dag_id of the DAG to find.
- """
- if callbacks and context:
- for callback in callbacks:
- cls.logger().info("Executing dag callback function: %s",
callback)
- try:
- callback(context)
- except Exception:
- cls.logger().exception("failed to invoke dag state update
callback")
- Stats.incr("dag.callback_exceptions", tags={"dag_id":
dag_id})
-
def get_active_runs(self):
"""
Return a list of dag run logical dates currently running.
@@ -1603,188 +1509,6 @@ class DAG(TaskSDKDag, LoggingMixin):
args = parser.parse_args()
args.func(args, self)
- @provide_session
- def test(
- self,
- run_after: datetime | None = None,
- logical_date: datetime | None = None,
- run_conf: dict[str, Any] | None = None,
- conn_file_path: str | None = None,
- variable_file_path: str | None = None,
- use_executor: bool = False,
- mark_success_pattern: Pattern | str | None = None,
- session: Session = NEW_SESSION,
- ) -> DagRun:
- """
- Execute one single DagRun for a given DAG and logical date.
-
- :param run_after: the datetime before which to Dag cannot run.
- :param logical_date: logical date for the DAG run
- :param run_conf: configuration to pass to newly created dagrun
- :param conn_file_path: file path to a connection file in either yaml or json
- :param variable_file_path: file path to a variable file in either yaml or json
- :param use_executor: if set, uses an executor to test the DAG
- :param mark_success_pattern: regex of task_ids to mark as success instead of running
- :param session: database connection (optional)
- """
- from airflow.serialization.serialized_objects import SerializedDAG
-
- def add_logger_if_needed(ti: TaskInstance):
- """
- Add a formatted logger to the task instance.
-
- This allows all logs to surface to the command line, instead of into
- a task file. Since this is a local test run, it is much better for
- the user to see logs in the command line, rather than needing to
- search for a log file.
-
- :param ti: The task instance that will receive a logger.
- """
- format = logging.Formatter("[%(asctime)s]
{%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
- handler = logging.StreamHandler(sys.stdout)
- handler.level = logging.INFO
- handler.setFormatter(format)
- # only add log handler once
- if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
- self.log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
- ti.log.addHandler(handler)
-
- exit_stack = ExitStack()
- if conn_file_path or variable_file_path:
- local_secrets = LocalFilesystemBackend(
- variables_file_path=variable_file_path, connections_file_path=conn_file_path
- )
- secrets_backend_list.insert(0, local_secrets)
- exit_stack.callback(lambda: secrets_backend_list.pop(0))
-
- with exit_stack:
- self.validate()
- self.log.debug("Clearing existing task instances for logical date
%s", logical_date)
- self.clear(
- start_date=logical_date,
- end_date=logical_date,
- dag_run_state=False, # type: ignore
- session=session,
- )
- self.log.debug("Getting dagrun for dag %s", self.dag_id)
- logical_date = timezone.coerce_datetime(logical_date)
- run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow())
- data_interval = (
- self.timetable.infer_manual_data_interval(run_after=logical_date) if logical_date else None
- )
- scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self))
-
- dr: DagRun = _get_or_create_dagrun(
- dag=scheduler_dag,
- start_date=logical_date or run_after,
- logical_date=logical_date,
- data_interval=data_interval,
- run_after=run_after,
- run_id=DagRun.generate_run_id(
- run_type=DagRunType.MANUAL,
- logical_date=logical_date,
- run_after=run_after,
- ),
- session=session,
- conf=run_conf,
- triggered_by=DagRunTriggeredByType.TEST,
- )
- # Start a mock span so that one is present and not started downstream. We
- # don't care about otel in dag.test and starting the span during dagrun update
- # is not functioning properly in this context anyway.
- dr.start_dr_spans_if_needed(tis=[])
-
- tasks = self.task_dict
- self.log.debug("starting dagrun")
- # Instead of starting a scheduler, we run the minimal loop possible to check
- # for task readiness and dependency management.
-
- # ``Dag.test()`` works in two different modes depending on ``use_executor``:
- # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task``
- # - if ``use_executor`` is True, sends the task instances to the executor with
- #   ``BaseExecutor.queue_task_instance``
- if use_executor:
- from airflow.models.dagbag import DagBag
-
- dag_bag = DagBag()
- dag_bag.bag_dag(self)
-
- executor = ExecutorLoader.get_default_executor()
- executor.start()
-
- while dr.state == DagRunState.RUNNING:
- session.expire_all()
- schedulable_tis, _ = dr.update_state(session=session)
- for s in schedulable_tis:
- if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
- s.try_number += 1
- s.state = TaskInstanceState.SCHEDULED
- s.scheduled_dttm = timezone.utcnow()
- session.commit()
- # triggerer may mark tasks scheduled so we read from DB
- all_tis = set(dr.get_task_instances(session=session))
- scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
- ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
- if not scheduled_tis and ids_unrunnable:
- self.log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
- time.sleep(1)
-
- triggerer_running = _triggerer_is_healthy(session)
- for ti in scheduled_tis:
- ti.task = tasks[ti.task_id]
-
- mark_success = (
- re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
- if mark_success_pattern is not None
- else False
- )
-
- if use_executor:
- if executor.has_task(ti):
- continue
- # TODO: Task-SDK: This check is transitionary. Remove once all executors are ported over.
- from airflow.executors import workloads
- from airflow.executors.base_executor import BaseExecutor
-
- if executor.queue_workload.__func__ is not BaseExecutor.queue_workload:  # type: ignore[attr-defined]
- workload = workloads.ExecuteTask.make(
- ti,
- dag_rel_path=Path(self.fileloc),
- generator=executor.jwt_generator,
- # For the system test/debug purpose, we use the default bundle which uses
- # local file system. If it turns out to be a feature people want, we could
- # plumb the Bundle to use as a parameter to dag.test
- bundle_info=BundleInfo(name="dags-folder"),
- )
- executor.queue_workload(workload, session=session)
- ti.state = TaskInstanceState.QUEUED
- session.commit()
- else:
- # Send the task to the executor
- executor.queue_task_instance(ti, ignore_ti_state=True)
- else:
- # Run the task locally
- try:
- add_logger_if_needed(ti)
- _run_task(
- ti=ti,
- inline_trigger=not triggerer_running,
- session=session,
- mark_success=mark_success,
- )
- except Exception:
- self.log.exception("Task failed; ti=%s", ti)
- if use_executor:
- executor.heartbeat()
- from airflow.jobs.scheduler_job_runner import SchedulerDagBag, SchedulerJobRunner
-
- SchedulerJobRunner.process_executor_events(
- executor=executor, job_id=None, scheduler_dag_bag=SchedulerDagBag(), session=session
- )
- if use_executor:
- executor.end()
- return dr
-
@provide_session
def create_dagrun(
self,
@@ -2535,47 +2259,6 @@ if STATICA_HACK: # pragma: no cover
""":sphinx-autoapi-skip:"""
-def _run_inline_trigger(trigger):
- async def _run_inline_trigger_main():
- # We can replace it with `return await anext(trigger.run(), default=None)`
- # when we drop support for Python 3.9
- try:
- return await trigger.run().__anext__()
- except StopAsyncIteration:
- return None
-
- return asyncio.run(_run_inline_trigger_main())
-
-
-def _run_task(
- *, ti: TaskInstance, inline_trigger: bool = False, mark_success: bool = False, session: Session
-):
- """
- Run a single task instance, and push result to Xcom for downstream tasks.
-
- Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as
- possible. This function is only meant for the `dag.test` function as a helper function.
-
- Args:
- ti: TaskInstance to run
- """
- log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id,
ti.map_index)
- while True:
- try:
- log.info("[DAG TEST] running task %s", ti)
- ti._run_raw_task(session=session, raise_on_defer=inline_trigger, mark_success=mark_success)
- break
- except TaskDeferred as e:
- log.info("[DAG TEST] running trigger in line")
- event = _run_inline_trigger(e.trigger)
- ti.next_method = e.method_name
- ti.next_kwargs = {"event": event.payload} if event else e.kwargs
- log.info("[DAG TEST] Trigger completed")
- session.merge(ti)
- session.commit()
- log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id,
ti.map_index)
-
-
def _get_or_create_dagrun(
*,
dag: DAG,
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 13848164355..11a65f87055 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -100,6 +100,7 @@ if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.dag_version import DagVersion
from airflow.models.operator import Operator
+ from airflow.sdk import DAG as SDKDAG, Context
from airflow.typing_compat import Literal
from airflow.utils.types import ArgNotSet
@@ -1147,8 +1148,8 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.FAILED)
self.notify_dagrun_state_changed(msg="task_failure")
- if execute_callbacks:
- dag.handle_callback(self, success=False, reason="task_failure", session=session)
+ if execute_callbacks and dag.has_on_failure_callback:
+ self.handle_dag_callback(dag=dag, success=False, reason="task_failure")
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
@@ -1176,8 +1177,8 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.SUCCESS)
self.notify_dagrun_state_changed(msg="success")
- if execute_callbacks:
- dag.handle_callback(self, success=True, reason="success", session=session)
+ if execute_callbacks and dag.has_on_success_callback:
+ self.handle_dag_callback(dag=dag, success=True, reason="success")
elif dag.has_on_success_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
@@ -1195,8 +1196,8 @@ class DagRun(Base, LoggingMixin):
self.set_state(DagRunState.FAILED)
self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")
- if execute_callbacks:
- dag.handle_callback(self, success=False, reason="all_tasks_deadlocked", session=session)
+ if execute_callbacks and dag.has_on_failure_callback:
+ self.handle_dag_callback(dag=dag, success=False, reason="all_tasks_deadlocked")
elif dag.has_on_failure_callback:
callback = DagCallbackRequest(
filepath=self.dag_model.relative_fileloc,
@@ -1316,6 +1317,32 @@ class DagRun(Base, LoggingMixin):
# we can't get all the state changes on SchedulerJob,
# or LocalTaskJob, so we don't want to "falsely advertise" we notify
about that
+ def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: str = "success"):
+ """Only needed for `dag.test` where `execute_callbacks=True` is passed to `update_state`."""
+ context: Context = { # type: ignore[assignment]
+ "dag": dag,
+ "run_id": str(self.run_id),
+ "reason": reason,
+ }
+
+ callbacks = dag.on_success_callback if success else dag.on_failure_callback
+ if not callbacks:
+ self.log.warning("Callback requested, but dag didn't have any for DAG: %s.", dag.dag_id)
+ return
+ callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+
+ for callback in callbacks:
+ self.log.info(
+ "Executing on_%s dag callback: %s",
+ "success" if success else "failure",
+ callback.__name__ if hasattr(callback, "__name__") else
repr(callback),
+ )
+ try:
+ callback(context)
+ except Exception:
+ self.log.exception("Callback failed for %s", dag.dag_id)
+ Stats.incr("dag.callback_exceptions", tags={"dag_id":
dag.dag_id})
+
def _get_ready_tis(
self,
schedulable_tis: list[TI],
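
With this change, callbacks invoked through `handle_dag_callback` receive a
minimal context of just `dag`, `run_id`, and `reason` (see the hunk above),
rather than a full TaskInstance template context. A sketch of a compatible
user callback, assuming a DAG defined with `on_failure_callback`:

    def on_failure(context):
        # `dag` is the Task SDK DAG object; there is no task_instance key anymore.
        dag_id = context["dag"].dag_id
        print(f"dag {dag_id} run {context['run_id']} failed: {context['reason']}")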
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 278248486bd..d1a37fd5392 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -38,10 +38,10 @@ from airflow.cli.commands import dag_command
from airflow.exceptions import AirflowException
from airflow.models import DagBag, DagModel, DagRun
from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _run_inline_trigger
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.standard.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
from airflow.sdk import task
+from airflow.sdk.definitions.dag import _run_inline_trigger
from airflow.triggers.base import TriggerEvent
from airflow.utils import timezone
from airflow.utils.session import create_session
@@ -631,7 +631,6 @@ class TestCliDags:
run_conf=None,
use_executor=False,
mark_success_pattern=None,
- session=mock.ANY,
),
]
)
@@ -665,7 +664,6 @@ class TestCliDags:
logical_date=mock.ANY,
run_conf=None,
use_executor=False,
- session=mock.ANY,
mark_success_pattern=None,
),
]
@@ -693,7 +691,6 @@ class TestCliDags:
logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
run_conf={"dag_run_conf_param": "param_value"},
use_executor=False,
- session=mock.ANY,
mark_success_pattern=None,
),
]
@@ -722,7 +719,6 @@ class TestCliDags:
logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
run_conf=None,
use_executor=False,
- session=mock.ANY,
mark_success_pattern=None,
),
]
@@ -773,7 +769,9 @@ class TestCliDags:
assert e.payload == now
def test_dag_test_no_triggerer_running(self, dag_maker):
- with mock.patch("airflow.models.dag._run_inline_trigger",
wraps=_run_inline_trigger) as mock_run:
+ with mock.patch(
+ "airflow.sdk.definitions.dag._run_inline_trigger",
wraps=_run_inline_trigger
+ ) as mock_run:
with dag_maker() as dag:
@task
@@ -806,12 +804,16 @@ class TestCliDags:
op = MyOp(task_id="abc", tfield=task_two)
task_two >> op
dr = dag.test()
- assert mock_run.call_args_list[0] == ((trigger,), {})
+
+ trigger_arg = mock_run.call_args_list[0].args[0]
+ assert isinstance(trigger_arg, DateTimeTrigger)
+ assert trigger_arg.moment == trigger.moment
+
tis = dr.get_task_instances()
assert next(x for x in tis if x.task_id == "abc").state ==
"success"
- @mock.patch("airflow.models.taskinstance.TaskInstance._execute_task_with_callbacks")
- def test_dag_test_with_mark_success(self, mock__execute_task_with_callbacks):
+ @mock.patch("airflow.sdk.execution_time.task_runner._execute_task")
+ def test_dag_test_with_mark_success(self, mock__execute_task):
"""
option `--mark-success-pattern` should mark matching tasks as success without executing them.
"""
@@ -828,8 +830,8 @@ class TestCliDags:
dag_command.dag_test(cli_args)
# only second operator was actually executed, first one was marked as success
- assert len(mock__execute_task_with_callbacks.call_args_list) == 1
- assert mock__execute_task_with_callbacks.call_args_list[0].kwargs["self"].task_id == "dummy_operator"
+ assert len(mock__execute_task.call_args_list) == 1
+ assert mock__execute_task.call_args_list[0].kwargs["ti"].task_id ==
"dummy_operator"
class TestCliDagsReserialize:
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index e9b42e7dbaa..548461ccdec 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -1031,7 +1031,7 @@ class TestDag:
assert dag_run.state == State.RUNNING
assert dag_run.run_type != DagRunType.MANUAL
- @patch("airflow.models.dag.Stats")
+ @patch("airflow.models.dagrun.Stats")
def test_dag_handle_callback_crash(self, mock_stats):
"""
Tests avoid crashes from calling dag callbacks exceptions
@@ -1062,8 +1062,8 @@ class TestDag:
)
# should not raise any exception
- dag.handle_callback(dag_run, success=False)
- dag.handle_callback(dag_run, success=True)
+ dag_run.handle_dag_callback(dag=dag, success=False)
+ dag_run.handle_dag_callback(dag=dag, success=True)
mock_stats.incr.assert_called_with(
"dag.callback_exceptions",
@@ -1102,8 +1102,8 @@ class TestDag:
assert dag_run.get_task_instance(task_removed.task_id).state == TaskInstanceState.REMOVED
# should not raise any exception
- dag.handle_callback(dag_run, success=True)
- dag.handle_callback(dag_run, success=False)
+ dag_run.handle_dag_callback(dag=dag, success=False)
+ dag_run.handle_dag_callback(dag=dag, success=True)
@pytest.mark.parametrize("catchup,expected_next_dagrun", [(True,
DEFAULT_DATE), (False, None)])
def test_next_dagrun_after_fake_scheduled_previous(self, catchup,
expected_next_dagrun):
@@ -1507,8 +1507,8 @@ class TestDag:
mock_handle_object_1(f"task {ti.task_id} failed...")
def handle_dag_failure(context):
- ti = context["task_instance"]
- mock_handle_object_2(f"dag {ti.dag_id} run failed...")
+ dag_id = context["dag"].dag_id
+ mock_handle_object_2(f"dag {dag_id} run failed...")
dag = DAG(
dag_id="test_local_testing_conn_file",
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index 74453449d97..b116082111a 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -32,7 +32,6 @@ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import setup, task, task_group, teardown
-from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
@@ -1270,21 +1269,7 @@ class TestMappedSetupTeardown:
tg1, tg2 = dag.task_group.children.values()
tg1 >> tg2
- with mock.patch(
- "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", create=True
- ) as supervisor_comms:
- # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should
- # really work without this!
- supervisor_comms.get_message.side_effect = [
- XComCountResponse(len=3),
- XComResult(key="return_value", value=1),
- XComCountResponse(len=3),
- XComResult(key="return_value", value=2),
- XComCountResponse(len=3),
- XComResult(key="return_value", value=3),
- ]
- dr = dag.test()
- assert supervisor_comms.get_message.call_count == 6
+ dr = dag.test()
states = self.get_states(dr)
expected = {
"tg_1.my_pre_setup": "success",
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 8b6ed3e4b19..972fea56241 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -70,6 +70,8 @@ from airflow.utils.decorators import fixup_decorator_warning_stack
from airflow.utils.trigger_rule import TriggerRule
if TYPE_CHECKING:
+ from re import Pattern
+
from pendulum.tz.timezone import FixedTimezone, Timezone
from airflow.decorators import TaskDecoratorCollection
@@ -1014,6 +1016,277 @@ class DAG:
f"Bad formatted links are: {wrong_links}"
)
+ def test(
+ self,
+ run_after: datetime | None = None,
+ logical_date: datetime | None = None,
+ run_conf: dict[str, Any] | None = None,
+ conn_file_path: str | None = None,
+ variable_file_path: str | None = None,
+ use_executor: bool = False,
+ mark_success_pattern: Pattern | str | None = None,
+ ):
+ """
+ Execute one single DagRun for a given DAG and logical date.
+
+ :param run_after: the datetime before which the DAG cannot run.
+ :param logical_date: logical date for the DAG run
+ :param run_conf: configuration to pass to newly created dagrun
+ :param conn_file_path: file path to a connection file in either yaml or json
+ :param variable_file_path: file path to a variable file in either yaml or json
+ :param use_executor: if set, uses an executor to test the DAG
+ :param mark_success_pattern: regex of task_ids to mark as success instead of running
+ """
+ import re
+ import time
+ from contextlib import ExitStack
+
+ from airflow import settings
+ from airflow.configuration import secrets_backend_list
+ from airflow.models.dag import DAG as SchedulerDAG, _get_or_create_dagrun
+ from airflow.models.dagrun import DagRun
+ from airflow.secrets.local_filesystem import LocalFilesystemBackend
+ from airflow.serialization.serialized_objects import SerializedDAG
+ from airflow.utils import timezone
+ from airflow.utils.state import DagRunState, State, TaskInstanceState
+ from airflow.utils.types import DagRunTriggeredByType, DagRunType
+
+ if TYPE_CHECKING:
+ from airflow.models.taskinstance import TaskInstance
+
+ def add_logger_if_needed(ti: TaskInstance):
+ """
+ Add a formatted logger to the task instance.
+
+ This allows all logs to surface to the command line, instead of into
+ a task file. Since this is a local test run, it is much better for
+ the user to see logs in the command line, rather than needing to
+ search for a log file.
+
+ :param ti: The task instance that will receive a logger.
+ """
+ format = logging.Formatter("[%(asctime)s]
{%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
+ handler = logging.StreamHandler(sys.stdout)
+ handler.level = logging.INFO
+ handler.setFormatter(format)
+ # only add log handler once
+ if not any(isinstance(h, logging.StreamHandler) for h in ti.log.handlers):
+ log.debug("Adding Streamhandler to taskinstance %s", ti.task_id)
+ ti.log.addHandler(handler)
+
+ exit_stack = ExitStack()
+
+ if conn_file_path or variable_file_path:
+ local_secrets = LocalFilesystemBackend(
+ variables_file_path=variable_file_path, connections_file_path=conn_file_path
+ )
+ secrets_backend_list.insert(0, local_secrets)
+ exit_stack.callback(lambda: secrets_backend_list.pop(0))
+
+ session = settings.Session()
+
+ with exit_stack:
+ self.validate()
+ log.debug("Clearing existing task instances for logical date %s",
logical_date)
+ # TODO: Replace with calling client.dag_run.clear in Execution API
at some point
+ SchedulerDAG.clear_dags(
+ dags=[self],
+ start_date=logical_date,
+ end_date=logical_date,
+ dag_run_state=False, # type: ignore
+ )
+
+ log.debug("Getting dagrun for dag %s", self.dag_id)
+ logical_date = timezone.coerce_datetime(logical_date)
+ run_after = timezone.coerce_datetime(run_after) or timezone.coerce_datetime(timezone.utcnow())
+ data_interval = (
+ self.timetable.infer_manual_data_interval(run_after=logical_date) if logical_date else None
+ )
+ scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self))  # type: ignore[arg-type]
+
+ dr: DagRun = _get_or_create_dagrun(
+ dag=scheduler_dag,
+ start_date=logical_date or run_after,
+ logical_date=logical_date,
+ data_interval=data_interval,
+ run_after=run_after,
+ run_id=DagRun.generate_run_id(
+ run_type=DagRunType.MANUAL,
+ logical_date=logical_date,
+ run_after=run_after,
+ ),
+ session=session,
+ conf=run_conf,
+ triggered_by=DagRunTriggeredByType.TEST,
+ )
+ # Start a mock span so that one is present and not started downstream. We
+ # don't care about otel in dag.test and starting the span during dagrun update
+ # is not functioning properly in this context anyway.
+ dr.start_dr_spans_if_needed(tis=[])
+ dr.dag = self # type: ignore[assignment]
+
+ tasks = self.task_dict
+ log.debug("starting dagrun")
+ # Instead of starting a scheduler, we run the minimal loop possible to check
+ # for task readiness and dependency management.
+
+ # ``Dag.test()`` works in two different modes depending on ``use_executor``:
+ # - if ``use_executor`` is False, runs the task locally with no executor using ``_run_task``
+ # - if ``use_executor`` is True, sends the task instances to the executor with
+ #   ``BaseExecutor.queue_task_instance``
+ if use_executor:
+ from airflow.executors.executor_loader import ExecutorLoader
+
+ executor = ExecutorLoader.get_default_executor()
+ executor.start()
+
+ while dr.state == DagRunState.RUNNING:
+ session.expire_all()
+ schedulable_tis, _ = dr.update_state(session=session)
+ for s in schedulable_tis:
+ if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+ s.try_number += 1
+ s.state = TaskInstanceState.SCHEDULED
+ s.scheduled_dttm = timezone.utcnow()
+ session.commit()
+ # triggerer may mark tasks scheduled so we read from DB
+ all_tis = set(dr.get_task_instances(session=session))
+ scheduled_tis = {x for x in all_tis if x.state == TaskInstanceState.SCHEDULED}
+ ids_unrunnable = {x for x in all_tis if x.state not in State.finished} - scheduled_tis
+ if not scheduled_tis and ids_unrunnable:
+ log.warning("No tasks to run. unrunnable tasks: %s", ids_unrunnable)
+ time.sleep(1)
+
+ for ti in scheduled_tis:
+ ti.task = tasks[ti.task_id]
+
+ mark_success = (
+ re.compile(mark_success_pattern).fullmatch(ti.task_id) is not None
+ if mark_success_pattern is not None
+ else False
+ )
+
+ if use_executor:
+ if executor.has_task(ti):
+ continue
+
+ from pathlib import Path
+
+ from airflow.executors import workloads
+ from airflow.executors.executor_loader import ExecutorLoader
+ from airflow.executors.workloads import BundleInfo
+
+ workload = workloads.ExecuteTask.make(
+ ti,
+ dag_rel_path=Path(self.fileloc),
+ generator=executor.jwt_generator,
+ # For the system test/debug purpose, we use the default bundle which uses
+ # local file system. If it turns out to be a feature people want, we could
+ # plumb the Bundle to use as a parameter to dag.test
+ bundle_info=BundleInfo(name="dags-folder"),
+ )
+ executor.queue_workload(workload, session=session)
+ ti.state = TaskInstanceState.QUEUED
+ session.commit()
+ else:
+ # Run the task locally
+ try:
+ add_logger_if_needed(ti)
+ if mark_success:
+ ti.set_state(State.SUCCESS)
+ log.info("[DAG TEST] Marking success for %s on
%s", ti.task, ti.logical_date)
+ else:
+ _run_task(ti=ti)
+ except Exception:
+ log.exception("Task failed; ti=%s", ti)
+ if use_executor:
+ executor.heartbeat()
+ from airflow.jobs.scheduler_job_runner import SchedulerDagBag, SchedulerJobRunner
+
+ SchedulerJobRunner.process_executor_events(
+ executor=executor, job_id=None, scheduler_dag_bag=SchedulerDagBag(), session=session
+ )
+ if use_executor:
+ executor.end()
+ return dr
+
+
+def _run_task(*, ti):
+ """
+ Run a single task instance, and push result to Xcom for downstream tasks.
+
+ Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as
+ possible. This function is only meant for the `dag.test` function as a helper function.
+ """
+ from airflow.utils.module_loading import import_string
+ from airflow.utils.state import State
+
+ log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id,
ti.map_index)
+ while True:
+ try:
+ log.info("[DAG TEST] running task %s", ti)
+
+ from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
+ from airflow.sdk.execution_time.comms import DeferTask
+ from airflow.sdk.execution_time.supervisor import run_task_in_process
+
+ # The API Server expects the task instance to be in QUEUED state before
+ # it is run.
+ ti.set_state(State.QUEUED)
+
+ taskrun_result = run_task_in_process(
+ ti=TaskInstanceSDK(
+ id=ti.id,
+ task_id=ti.task_id,
+ dag_id=ti.task.dag_id,
+ run_id=ti.run_id,
+ try_number=ti.try_number,
+ map_index=ti.map_index,
+ ),
+ task=ti.task,
+ )
+
+ msg = taskrun_result.msg
+
+ if taskrun_result.ti.state == State.DEFERRED and isinstance(msg, DeferTask):
+ # API Server expects the task instance to be in QUEUED state before
+ # resuming from deferral.
+ ti.set_state(State.QUEUED)
+
+ log.info("[DAG TEST] running trigger in line")
+ trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
+ event = _run_inline_trigger(trigger)
+ ti.next_method = msg.next_method
+ ti.next_kwargs = {"event": event.payload} if event else
msg.kwargs
+ log.info("[DAG TEST] Trigger completed")
+
+ ti.set_state(State.SUCCESS)
+ break
+ except Exception:
+ log.exception("[DAG TEST] Error running task %s", ti)
+ if ti.state not in State.finished:
+ ti.set_state(State.FAILED)
+ break
+ raise
+
+ log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id,
ti.map_index)
+
+
+def _run_inline_trigger(trigger):
+ import asyncio
+
+ async def _run_inline_trigger_main():
+ # We can replace it with `return await anext(trigger.run(), default=None)`
+ # when we drop support for Python 3.9
+ try:
+ return await trigger.run().__anext__()
+ except StopAsyncIteration:
+ return None
+
+ return asyncio.run(_run_inline_trigger_main())
+
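
A usage sketch of the relocated helper, using TimeDeltaTrigger from the
standard provider (the same module the updated tests import from):

    from datetime import timedelta

    from airflow.providers.standard.triggers.temporal import TimeDeltaTrigger
    from airflow.sdk.definitions.dag import _run_inline_trigger

    # Drives the trigger synchronously and returns its first TriggerEvent,
    # or None if the trigger finishes without emitting one.
    event = _run_inline_trigger(TimeDeltaTrigger(timedelta(seconds=0)))
    print(event)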
# Since we define all the attributes of the class with attrs, we can compute this statically at parse time
DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - {  # type: ignore[attr-defined]
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index b5cf977488b..0e7b7f54cd1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,8 +27,9 @@ import selectors
import signal
import sys
import time
+from collections import deque
from collections.abc import Generator
-from contextlib import suppress
+from contextlib import contextmanager, suppress
from datetime import datetime, timezone
from http import HTTPStatus
from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair
@@ -42,6 +43,7 @@ from typing import (
)
from uuid import UUID
+import aiologic
import attrs
import httpx
import msgspec
@@ -837,6 +839,15 @@ class ActivitySubprocess(WatchedSubprocess):
# If it hasn't, assume it's failed
self._exit_code = self._exit_code if self._exit_code is not None else 1
+ self.update_task_state_if_needed()
+
+ # Now at the last possible moment, when all logs and comms with the subprocess have finished, let's
+ # upload the remote logs
+ self._upload_logs()
+
+ return self._exit_code
+
+ def update_task_state_if_needed(self):
# If the process has finished non-directly patched state (directly means deferred, reschedule, etc.),
# update the state of the TaskInstance to reflect the final state of the process.
# For states like `deferred`, `up_for_reschedule`, the process will exit with 0, but the state will be updated
@@ -849,12 +860,6 @@ class ActivitySubprocess(WatchedSubprocess):
rendered_map_index=self._rendered_map_index,
)
- # Now at the last possible moment, when all logs and comms with the subprocess has finished, lets
- # upload the remote logs
-
- return self._exit_code
-
def _upload_logs(self):
"""
Upload all log files found to the remote storage.
@@ -1155,6 +1160,183 @@ class ActivitySubprocess(WatchedSubprocess):
self.send_msg(resp, **dump_opts)
+def in_process_api_server():
+ from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI
+
+ api = InProcessExecutionAPI()
+ return api
+
+
+@attrs.define
+class InProcessSupervisorComms:
+ """In-process communication handler that uses deques instead of sockets."""
+
+ supervisor: InProcessTestSupervisor
+ messages: deque[BaseModel] = attrs.field(factory=deque)
+ lock: aiologic.Lock = attrs.field(factory=aiologic.Lock)
+
+ def get_message(self) -> BaseModel:
+ """Get a message from the supervisor. Blocks until a message is
available."""
+ return self.messages.popleft()
+
+ def send_request(self, log, msg: BaseModel):
+ """Send a request to the supervisor."""
+ log.debug("Sending request", msg=msg)
+
+ with set_supervisor_comms(None):
+ self.supervisor._handle_request(msg, log) # type: ignore[arg-type]
+
+
+@attrs.define
+class TaskRunResult:
+ """Result of running a task via ``InProcessTestSupervisor``."""
+
+ ti: RuntimeTI
+ state: str
+ msg: BaseModel | None
+ error: BaseException | None
+
+
+@attrs.define(kw_only=True)
+class InProcessTestSupervisor(ActivitySubprocess):
+ """A supervisor that runs tasks in-process for easier testing."""
+
+ comms: InProcessSupervisorComms = attrs.field(init=False)
+ stdin = attrs.field(init=False)
+
+ @classmethod
+ def start( # type: ignore[override]
+ cls,
+ *,
+ what: TaskInstance,
+ task,
+ logger: FilteringBoundLogger | None = None,
+ **kwargs,
+ ) -> TaskRunResult:
+ """
+ Run a task in-process without spawning a new child process.
+
+ This bypasses the standard `ActivitySubprocess.start()` behavior, which expects
+ to launch a subprocess and communicate via stdin/stdout. Instead, it constructs
+ the `RuntimeTaskInstance` directly — useful in contexts like `dag.test()` where the
+ DAG is already parsed in memory.
+
+ Supervisor state and communications are simulated in-memory via `InProcessSupervisorComms`.
+ """
+ # Create supervisor instance
+ supervisor = cls(
+ id=what.id,
+ pid=os.getpid(), # Use current process
+ process=psutil.Process(), # Current process
+ requests_fd=-1, # Not used in in-process mode
+ process_log=logger or structlog.get_logger(logger_name="task").bind(),
+ client=cls._api_client(task.dag),
+ **kwargs,
+ )
+
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, finalize, run
+
+ supervisor.comms = InProcessSupervisorComms(supervisor=supervisor)
+ with set_supervisor_comms(supervisor.comms):
+ supervisor.ti = what # type: ignore[assignment]
+
+ # We avoid calling `task_runner.startup()` because we are already inside a
+ # parsed DAG file (e.g. via dag.test()).
+ # In normal execution, `startup()` parses the DAG based on info in a `StartupDetails` message.
+ # By directly constructing the `RuntimeTaskInstance`,
+ # we skip re-parsing (`task_runner.parse()`) and avoid needing to set DAG Bundle config
+ # and run the task in-process.
+ start_date = datetime.now(tz=timezone.utc)
+ ti_context = supervisor.client.task_instances.start(supervisor.id, supervisor.pid, start_date)
+
+ ti = RuntimeTaskInstance.model_construct(
+ **what.model_dump(exclude_unset=True),
+ task=task,
+ _ti_context_from_server=ti_context,
+ max_tries=ti_context.max_tries,
+ start_date=start_date,
+ state=TaskInstanceState.RUNNING,
+ )
+ context = ti.get_template_context()
+ log = structlog.get_logger(logger_name="task")
+
+ state, msg, error = run(ti, context, log)
+ finalize(ti, state, context, log, error)
+
+ # In the normal subprocess model, the task runner calls this before exiting.
+ # Since we're running in-process, we manually notify the API server that
+ # the task has finished—unless the terminal state was already sent explicitly.
+ supervisor.update_task_state_if_needed()
+
+ return TaskRunResult(ti=ti, state=state, msg=msg, error=error)
+
+ @staticmethod
+ def _api_client(dag=None):
+ from airflow.models.dagbag import DagBag
+ from airflow.sdk.api.client import Client
+
+ api = in_process_api_server()
+ if dag is not None:
+ from airflow.api_fastapi.common.deps import _get_dag_bag
+ from airflow.serialization.serialized_objects import SerializedDAG
+
+ # This is needed since the Execution API server uses the DagBag in its "state".
+ # This `app.state.dag_bag` is used to get some DAG properties like `fail_fast`.
+ dag_bag = DagBag(include_examples=False, collect_dags=False, load_op_links=False)
+
+ # Mimic the behavior of the DagBag in the API server by converting the DAG to a SerializedDAG
+ dag_bag.dags[dag.dag_id] = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+ api.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag
+
+ client = Client(base_url=None, token="", dry_run=True,
transport=api.transport)
+ # Mypy is wrong -- the setter accepts a string on the property setter!
`URLType = URL | str`
+ client.base_url = "http://in-process.invalid./" # type:
ignore[assignment]
+ return client
+
+ def send_msg(self, msg: BaseModel, **dump_opts):
+ """Override to use in-process comms."""
+ self.comms.messages.append(msg)
+
+ @property
+ def final_state(self):
+ """Override to use in-process comms."""
+ # Since we're running in-process, we don't have a final state until the task has finished.
+ # We also don't have a process exit code to determine success/failure.
+ return self._terminal_state
+
+
+@contextmanager
+def set_supervisor_comms(temp_comms):
+ """
+ Temporarily override `SUPERVISOR_COMMS` in the `task_runner` module.
+
+ This is used to simulate task-runner ↔ supervisor communication in-process,
+ by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`)
+ in place of the real inter-process communication layer.
+
+ Some parts of the code (e.g. models.Variable.get) check for the presence
of `task_runner.SUPERVISOR_COMMS` to determine if the code is running in a Task SDK execution context.
+ This override ensures those code paths behave correctly during in-process tests.
+ """
+ from airflow.sdk.execution_time import task_runner
+
+ old = getattr(task_runner, "SUPERVISOR_COMMS", None)
+ task_runner.SUPERVISOR_COMMS = temp_comms
+ try:
+ yield
+ finally:
+ if old is not None:
+ task_runner.SUPERVISOR_COMMS = old
+ else:
+ delattr(task_runner, "SUPERVISOR_COMMS")
+
+
+def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
+ """Run a task in-process for testing."""
+ # Run the task
+ return InProcessTestSupervisor.start(what=ti, task=task)
+
+
# Sockets, even the `.makefile()` function don't correctly do line buffering on reading. If a chunk is read
# and it doesn't contain a new line character, `.readline()` will just return the chunk as is.
#
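
A rough sketch of how the `set_supervisor_comms` helper above can be exercised
with a stub comms object (`StubComms` is hypothetical, for illustration only):

    from airflow.sdk.execution_time import task_runner
    from airflow.sdk.execution_time.supervisor import set_supervisor_comms

    class StubComms:
        def get_message(self):
            raise NotImplementedError

        def send_request(self, log, msg):
            print("request:", msg)

    with set_supervisor_comms(StubComms()):
        # Code under test now sees task_runner.SUPERVISOR_COMMS, so paths that
        # detect a Task SDK execution context behave as they would in-process.
        assert task_runner.SUPERVISOR_COMMS is not None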
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 9092ee86f0b..66a3d02cd8c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -835,7 +835,7 @@ def run(
return state, msg, error
try:
- result = _execute_task(context, ti, log)
+ result = _execute_task(context=context, ti=ti, log=log)
except Exception:
import jinja2
@@ -886,7 +886,7 @@ def run(
)
state = TaskInstanceState.FAILED
error = e
- except (AirflowTaskTimeout, AirflowException) as e:
+ except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
# We should allow retries if the task has defined it.
log.exception("Task failed with exception")
msg, state = _handle_current_task_failed(ti)