This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ef2da6f1efd Port `dag.test` to Task SDK (#50300)
ef2da6f1efd is described below

commit ef2da6f1efd6606e424964dec2a184a3a8521e27
Author: Kaxil Naik <[email protected]>
AuthorDate: Fri May 9 19:03:07 2025 +0530

    Port `dag.test` to Task SDK (#50300)
    
    closes https://github.com/apache/airflow/issues/45549
    
    Key changes:
    
    - Moves the `dag.test` implementation to the Task SDK, reusing the existing
      in-process execution infrastructure
    - Adds `JWTBearerTIPathDep`, a reusable dependency that checks that the
      `task_instance_id` path parameter matches the task instance ID in the JWT
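    - Updates `InProcessExecutionAPI` to override `JWTBearerTIPathDep` with an
      always-allow dependency, so task-instance routes can be called in-process
      without token validation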
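    - Removes the legacy `dag.test` implementation and the `fetch_callback`,
      `handle_callback` and `execute_callback` helpers from the core `DAG` class;
      DAG-level callbacks for `dag.test` now run via the new
      `DagRun.handle_dag_callback`, and `_run_inline_trigger` moves to
      `airflow.sdk.definitions.dag`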
    
    With these changes, `dag.test` uses the same execution path as regular
    task execution.
---
 .../src/airflow/api_fastapi/execution_api/app.py   |   7 +-
 .../src/airflow/api_fastapi/execution_api/deps.py  |   3 +
 .../execution_api/routes/task_instances.py         |   6 +-
 .../src/airflow/cli/commands/dag_command.py        |   1 -
 .../src/airflow/cli/commands/task_command.py       |   3 +-
 .../src/airflow/dag_processing/processor.py        |   6 +-
 airflow-core/src/airflow/models/dag.py             | 321 +--------------------
 airflow-core/src/airflow/models/dagrun.py          |  39 ++-
 .../tests/unit/cli/commands/test_dag_command.py    |  24 +-
 airflow-core/tests/unit/models/test_dag.py         |  14 +-
 .../tests/unit/models/test_mappedoperator.py       |  17 +-
 task-sdk/src/airflow/sdk/definitions/dag.py        | 273 ++++++++++++++++++
 .../src/airflow/sdk/execution_time/supervisor.py   | 196 ++++++++++++-
 .../src/airflow/sdk/execution_time/task_runner.py  |   4 +-
 14 files changed, 539 insertions(+), 375 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index ef51da98279..691853f322a 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -225,7 +225,11 @@ class InProcessExecutionAPI:
     def app(self):
         if not self._app:
             from airflow.api_fastapi.execution_api.app import 
create_task_execution_api_app
-            from airflow.api_fastapi.execution_api.deps import JWTBearerDep, 
JWTRefresherDep
+            from airflow.api_fastapi.execution_api.deps import (
+                JWTBearerDep,
+                JWTBearerTIPathDep,
+                JWTRefresherDep,
+            )
             from airflow.api_fastapi.execution_api.routes.connections import 
has_connection_access
             from airflow.api_fastapi.execution_api.routes.variables import 
has_variable_access
             from airflow.api_fastapi.execution_api.routes.xcoms import 
has_xcom_access
@@ -235,6 +239,7 @@ class InProcessExecutionAPI:
             async def always_allow(): ...
 
             self._app.dependency_overrides[JWTBearerDep.dependency] = 
always_allow
+            self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = 
always_allow
             self._app.dependency_overrides[JWTRefresherDep.dependency] = 
always_allow
             self._app.dependency_overrides[has_connection_access] = 
always_allow
             self._app.dependency_overrides[has_variable_access] = always_allow
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index 8106a7e81e3..c2161180dbb 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -96,6 +96,9 @@ class JWTBearer(HTTPBearer):
 
 JWTBearerDep: TIToken = Depends(JWTBearer())
 
+# This checks that the UUID in the url matches the one in the token for us.
+JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
+
 
 class JWTReissuer:
     """Re-issue JWTs to requests when they are about to run out."""
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 1dfc9bb10e0..00f149f1d10 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -25,7 +25,7 @@ from uuid import UUID
 
 import structlog
 from cadwyn import VersionedAPIRouter
-from fastapi import Body, Depends, HTTPException, Query, status
+from fastapi import Body, HTTPException, Query, status
 from pydantic import JsonValue
 from sqlalchemy import func, or_, tuple_, update
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
@@ -50,7 +50,7 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
     TISuccessStatePayload,
     TITerminalStatePayload,
 )
-from airflow.api_fastapi.execution_api.deps import JWTBearer
+from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
 from airflow.models.dagbag import DagBag
 from airflow.models.dagrun import DagRun as DR
 from airflow.models.taskinstance import TaskInstance as TI, 
_stop_remaining_tasks
@@ -70,7 +70,7 @@ router = VersionedAPIRouter()
 ti_id_router = VersionedAPIRouter(
     dependencies=[
         # This checks that the UUID in the url matches the one in the token 
for us.
-        Depends(JWTBearer(path_param_name="task_instance_id")),
+        JWTBearerTIPathDep
     ]
 )
 
diff --git a/airflow-core/src/airflow/cli/commands/dag_command.py 
b/airflow-core/src/airflow/cli/commands/dag_command.py
index b1151f34091..1b4b017e6c7 100644
--- a/airflow-core/src/airflow/cli/commands/dag_command.py
+++ b/airflow-core/src/airflow/cli/commands/dag_command.py
@@ -644,7 +644,6 @@ def dag_test(args, dag: DAG | None = None, session: Session 
= NEW_SESSION) -> No
         run_conf=run_conf,
         use_executor=use_executor,
         mark_success_pattern=mark_success_pattern,
-        session=session,
     )
     show_dagrun = args.show_dagrun
     imgcat = args.imgcat_dagrun
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py 
b/airflow-core/src/airflow/cli/commands/task_command.py
index 0a4e771315b..a3492a828c0 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -33,8 +33,9 @@ from airflow.cli.simple_table import AirflowConsole
 from airflow.cli.utils import fetch_dag_run_from_run_id_or_logical_date_string
 from airflow.exceptions import DagRunNotFound, TaskDeferred, 
TaskInstanceNotFound
 from airflow.models import TaskInstance
-from airflow.models.dag import DAG, _run_inline_trigger
+from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
+from airflow.sdk.definitions.dag import _run_inline_trigger
 from airflow.sdk.definitions.param import ParamsDict
 from airflow.sdk.execution_time.secrets_masker import RedactedIO
 from airflow.ti_deps.dep_context import DepContext
diff --git a/airflow-core/src/airflow/dag_processing/processor.py 
b/airflow-core/src/airflow/dag_processing/processor.py
index ef3e5cb69ca..73d2c23c7f5 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -161,7 +161,11 @@ def _execute_dag_callbacks(dagbag: DagBag, request: 
DagCallbackRequest, log: Fil
 
     callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
     # TODO:We need a proper context object!
-    context: Context = {}  # type: ignore[assignment]
+    context: Context = {  # type: ignore[assignment]
+        "dag": dag,
+        "run_id": request.run_id,
+        "reason": request.msg,
+    }
 
     for callback in callbacks:
         log.info(
diff --git a/airflow-core/src/airflow/models/dag.py 
b/airflow-core/src/airflow/models/dag.py
index 8024f417ef4..57c1050e3ce 100644
--- a/airflow-core/src/airflow/models/dag.py
+++ b/airflow-core/src/airflow/models/dag.py
@@ -17,20 +17,14 @@
 # under the License.
 from __future__ import annotations
 
-import asyncio
 import copy
 import functools
 import logging
 import re
-import sys
-import time
 from collections import defaultdict
 from collections.abc import Collection, Generator, Iterable, Sequence
-from contextlib import ExitStack
 from datetime import datetime, timedelta
 from functools import cache
-from pathlib import Path
-from re import Pattern
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -70,14 +64,12 @@ from sqlalchemy.sql import Select, expression
 
 from airflow import settings, utils
 from airflow.assets.evaluation import AssetEvaluator
-from airflow.configuration import conf as airflow_conf, secrets_backend_list
+from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import (
     AirflowException,
-    TaskDeferred,
     UnknownExecutorException,
 )
 from airflow.executors.executor_loader import ExecutorLoader
-from airflow.executors.workloads import BundleInfo
 from airflow.models.asset import (
     AssetDagRunQueue,
     AssetModel,
@@ -95,9 +87,7 @@ from airflow.models.tasklog import LogTemplate
 from airflow.sdk import TaskGroup
 from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetUniqueKey, 
BaseAsset
 from airflow.sdk.definitions.dag import DAG as TaskSDKDag, dag as 
task_sdk_dag_decorator
-from airflow.secrets.local_filesystem import LocalFilesystemBackend
 from airflow.settings import json
-from airflow.stats import Stats
 from airflow.timetables.base import DagRunInfo, DataInterval, TimeRestriction, 
Timetable
 from airflow.timetables.interval import CronDataIntervalTimetable, 
DeltaDataIntervalTimetable
 from airflow.timetables.simple import (
@@ -111,7 +101,7 @@ from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks
-from airflow.utils.state import DagRunState, State, TaskInstanceState
+from airflow.utils.state import DagRunState, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 if TYPE_CHECKING:
@@ -121,7 +111,6 @@ if TYPE_CHECKING:
 
     from airflow.models.dagbag import DagBag
     from airflow.models.operator import Operator
-    from airflow.sdk.definitions._internal.abstractoperator import 
TaskStateChangeCallback
     from airflow.serialization.serialized_objects import MaybeSerializedDAG
     from airflow.typing_compat import Literal
 
@@ -777,89 +766,6 @@ class DAG(TaskSDKDag, LoggingMixin):
         """Stringified DAGs and operators contain exactly these fields."""
         return TaskSDKDag.get_serialized_fields() | {"_processor_dags_folder"}
 
-    @staticmethod
-    @provide_session
-    def fetch_callback(
-        dag: DAG,
-        run_id: str,
-        success: bool = True,
-        reason: str | None = None,
-        *,
-        session: Session = NEW_SESSION,
-    ) -> tuple[list[TaskStateChangeCallback], Context] | None:
-        """
-        Fetch the appropriate callbacks depending on the value of success.
-
-        This method gets the context of a single TaskInstance part of this 
DagRun and returns it along
-        the list of callbacks.
-
-        :param dag: DAG object
-        :param run_id: The DAG run ID
-        :param success: Flag to specify if failure or success callback should 
be called
-        :param reason: Completion reason
-        :param session: Database session
-        """
-        callbacks = dag.on_success_callback if success else 
dag.on_failure_callback
-        if callbacks:
-            dagrun = DAG.fetch_dagrun(dag_id=dag.dag_id, run_id=run_id, 
session=session)
-            callbacks = callbacks if isinstance(callbacks, list) else 
[callbacks]
-            tis = dagrun.get_task_instances(session=session)
-            # tis from a dagrun may not be a part of dag.partial_subset,
-            # since dag.partial_subset is a subset of the dag.
-            # This ensures that we will only use the accessible TI
-            # context for the callback.
-            if dag.partial:
-                tis = [ti for ti in tis if not ti.state == State.NONE]
-            # filter out removed tasks
-            tis = [ti for ti in tis if ti.state != TaskInstanceState.REMOVED]
-            ti = tis[-1]  # get first TaskInstance of DagRun
-            ti.task = dag.get_task(ti.task_id)
-            context = ti.get_template_context(session=session)
-            context["reason"] = reason
-            return callbacks, context
-        return None
-
-    @provide_session
-    def handle_callback(self, dagrun: DagRun, success=True, reason=None, 
session=NEW_SESSION):
-        """
-        Triggers on_failure_callback or on_success_callback as appropriate.
-
-        This method gets the context of a single TaskInstance part of this 
DagRun
-        and passes that to the callable along with a 'reason', primarily to
-        differentiate DagRun failures.
-
-        .. note: The logs end up in
-            ``$AIRFLOW_HOME/logs/scheduler/latest/PROJECT/DAG_FILE.py.log``
-
-        :param dagrun: DagRun object
-        :param success: Flag to specify if failure or success callback should 
be called
-        :param reason: Completion reason
-        :param session: Database session
-        """
-        callbacks, context = DAG.fetch_callback(
-            dag=self, run_id=dagrun.run_id, success=success, reason=reason, 
session=session
-        ) or (None, None)
-
-        DAG.execute_callback(callbacks, context, self.dag_id)
-
-    @classmethod
-    def execute_callback(cls, callbacks: list[Callable] | None, context: 
Context | None, dag_id: str):
-        """
-        Triggers the callbacks with the given context.
-
-        :param callbacks: List of callbacks to call
-        :param context: Context to pass to all callbacks
-        :param dag_id: The dag_id of the DAG to find.
-        """
-        if callbacks and context:
-            for callback in callbacks:
-                cls.logger().info("Executing dag callback function: %s", 
callback)
-                try:
-                    callback(context)
-                except Exception:
-                    cls.logger().exception("failed to invoke dag state update 
callback")
-                    Stats.incr("dag.callback_exceptions", tags={"dag_id": 
dag_id})
-
     def get_active_runs(self):
         """
         Return a list of dag run logical dates currently running.
@@ -1603,188 +1509,6 @@ class DAG(TaskSDKDag, LoggingMixin):
         args = parser.parse_args()
         args.func(args, self)
 
-    @provide_session
-    def test(
-        self,
-        run_after: datetime | None = None,
-        logical_date: datetime | None = None,
-        run_conf: dict[str, Any] | None = None,
-        conn_file_path: str | None = None,
-        variable_file_path: str | None = None,
-        use_executor: bool = False,
-        mark_success_pattern: Pattern | str | None = None,
-        session: Session = NEW_SESSION,
-    ) -> DagRun:
-        """
-        Execute one single DagRun for a given DAG and logical date.
-
-        :param run_after: the datetime before which to Dag cannot run.
-        :param logical_date: logical date for the DAG run
-        :param run_conf: configuration to pass to newly created dagrun
-        :param conn_file_path: file path to a connection file in either yaml 
or json
-        :param variable_file_path: file path to a variable file in either yaml 
or json
-        :param use_executor: if set, uses an executor to test the DAG
-        :param mark_success_pattern: regex of task_ids to mark as success 
instead of running
-        :param session: database connection (optional)
-        """
-        from airflow.serialization.serialized_objects import SerializedDAG
-
-        def add_logger_if_needed(ti: TaskInstance):
-            """
-            Add a formatted logger to the task instance.
-
-            This allows all logs to surface to the command line, instead of 
into
-            a task file. Since this is a local test run, it is much better for
-            the user to see logs in the command line, rather than needing to
-            search for a log file.
-
-            :param ti: The task instance that will receive a logger.
-            """
-            format = logging.Formatter("[%(asctime)s] 
{%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
-            handler = logging.StreamHandler(sys.stdout)
-            handler.level = logging.INFO
-            handler.setFormatter(format)
-            # only add log handler once
-            if not any(isinstance(h, logging.StreamHandler) for h in 
ti.log.handlers):
-                self.log.debug("Adding Streamhandler to taskinstance %s", 
ti.task_id)
-                ti.log.addHandler(handler)
-
-        exit_stack = ExitStack()
-        if conn_file_path or variable_file_path:
-            local_secrets = LocalFilesystemBackend(
-                variables_file_path=variable_file_path, 
connections_file_path=conn_file_path
-            )
-            secrets_backend_list.insert(0, local_secrets)
-            exit_stack.callback(lambda: secrets_backend_list.pop(0))
-
-        with exit_stack:
-            self.validate()
-            self.log.debug("Clearing existing task instances for logical date 
%s", logical_date)
-            self.clear(
-                start_date=logical_date,
-                end_date=logical_date,
-                dag_run_state=False,  # type: ignore
-                session=session,
-            )
-            self.log.debug("Getting dagrun for dag %s", self.dag_id)
-            logical_date = timezone.coerce_datetime(logical_date)
-            run_after = timezone.coerce_datetime(run_after) or 
timezone.coerce_datetime(timezone.utcnow())
-            data_interval = (
-                
self.timetable.infer_manual_data_interval(run_after=logical_date) if 
logical_date else None
-            )
-            scheduler_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self))
-
-            dr: DagRun = _get_or_create_dagrun(
-                dag=scheduler_dag,
-                start_date=logical_date or run_after,
-                logical_date=logical_date,
-                data_interval=data_interval,
-                run_after=run_after,
-                run_id=DagRun.generate_run_id(
-                    run_type=DagRunType.MANUAL,
-                    logical_date=logical_date,
-                    run_after=run_after,
-                ),
-                session=session,
-                conf=run_conf,
-                triggered_by=DagRunTriggeredByType.TEST,
-            )
-            # Start a mock span so that one is present and not started 
downstream. We
-            # don't care about otel in dag.test and starting the span during 
dagrun update
-            # is not functioning properly in this context anyway.
-            dr.start_dr_spans_if_needed(tis=[])
-
-            tasks = self.task_dict
-            self.log.debug("starting dagrun")
-            # Instead of starting a scheduler, we run the minimal loop 
possible to check
-            # for task readiness and dependency management.
-
-            # ``Dag.test()`` works in two different modes depending on 
``use_executor``:
-            # - if ``use_executor`` is False, runs the task locally with no 
executor using ``_run_task``
-            # - if ``use_executor`` is True, sends the task instances to the 
executor with
-            #   ``BaseExecutor.queue_task_instance``
-            if use_executor:
-                from airflow.models.dagbag import DagBag
-
-                dag_bag = DagBag()
-                dag_bag.bag_dag(self)
-
-                executor = ExecutorLoader.get_default_executor()
-                executor.start()
-
-            while dr.state == DagRunState.RUNNING:
-                session.expire_all()
-                schedulable_tis, _ = dr.update_state(session=session)
-                for s in schedulable_tis:
-                    if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
-                        s.try_number += 1
-                    s.state = TaskInstanceState.SCHEDULED
-                    s.scheduled_dttm = timezone.utcnow()
-                session.commit()
-                # triggerer may mark tasks scheduled so we read from DB
-                all_tis = set(dr.get_task_instances(session=session))
-                scheduled_tis = {x for x in all_tis if x.state == 
TaskInstanceState.SCHEDULED}
-                ids_unrunnable = {x for x in all_tis if x.state not in 
State.finished} - scheduled_tis
-                if not scheduled_tis and ids_unrunnable:
-                    self.log.warning("No tasks to run. unrunnable tasks: %s", 
ids_unrunnable)
-                    time.sleep(1)
-
-                triggerer_running = _triggerer_is_healthy(session)
-                for ti in scheduled_tis:
-                    ti.task = tasks[ti.task_id]
-
-                    mark_success = (
-                        re.compile(mark_success_pattern).fullmatch(ti.task_id) 
is not None
-                        if mark_success_pattern is not None
-                        else False
-                    )
-
-                    if use_executor:
-                        if executor.has_task(ti):
-                            continue
-                        # TODO: Task-SDK: This check is transitionary. Remove 
once all executors are ported over.
-                        from airflow.executors import workloads
-                        from airflow.executors.base_executor import 
BaseExecutor
-
-                        if executor.queue_workload.__func__ is not 
BaseExecutor.queue_workload:  # type: ignore[attr-defined]
-                            workload = workloads.ExecuteTask.make(
-                                ti,
-                                dag_rel_path=Path(self.fileloc),
-                                generator=executor.jwt_generator,
-                                # For the system test/debug purpose, we use 
the default bundle which uses
-                                # local file system. If it turns out to be a 
feature people want, we could
-                                # plumb the Bundle to use as a parameter to 
dag.test
-                                bundle_info=BundleInfo(name="dags-folder"),
-                            )
-                            executor.queue_workload(workload, session=session)
-                            ti.state = TaskInstanceState.QUEUED
-                            session.commit()
-                        else:
-                            # Send the task to the executor
-                            executor.queue_task_instance(ti, 
ignore_ti_state=True)
-                    else:
-                        # Run the task locally
-                        try:
-                            add_logger_if_needed(ti)
-                            _run_task(
-                                ti=ti,
-                                inline_trigger=not triggerer_running,
-                                session=session,
-                                mark_success=mark_success,
-                            )
-                        except Exception:
-                            self.log.exception("Task failed; ti=%s", ti)
-                if use_executor:
-                    executor.heartbeat()
-                    from airflow.jobs.scheduler_job_runner import 
SchedulerDagBag, SchedulerJobRunner
-
-                    SchedulerJobRunner.process_executor_events(
-                        executor=executor, job_id=None, 
scheduler_dag_bag=SchedulerDagBag(), session=session
-                    )
-            if use_executor:
-                executor.end()
-        return dr
-
     @provide_session
     def create_dagrun(
         self,
@@ -2535,47 +2259,6 @@ if STATICA_HACK:  # pragma: no cover
     """:sphinx-autoapi-skip:"""
 
 
-def _run_inline_trigger(trigger):
-    async def _run_inline_trigger_main():
-        # We can replace it with `return await anext(trigger.run(), 
default=None)`
-        # when we drop support for Python 3.9
-        try:
-            return await trigger.run().__anext__()
-        except StopAsyncIteration:
-            return None
-
-    return asyncio.run(_run_inline_trigger_main())
-
-
-def _run_task(
-    *, ti: TaskInstance, inline_trigger: bool = False, mark_success: bool = 
False, session: Session
-):
-    """
-    Run a single task instance, and push result to Xcom for downstream tasks.
-
-    Bypasses a lot of extra steps used in `task.run` to keep our local running 
as fast as
-    possible.  This function is only meant for the `dag.test` function as a 
helper function.
-
-    Args:
-        ti: TaskInstance to run
-    """
-    log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
-    while True:
-        try:
-            log.info("[DAG TEST] running task %s", ti)
-            ti._run_raw_task(session=session, raise_on_defer=inline_trigger, 
mark_success=mark_success)
-            break
-        except TaskDeferred as e:
-            log.info("[DAG TEST] running trigger in line")
-            event = _run_inline_trigger(e.trigger)
-            ti.next_method = e.method_name
-            ti.next_kwargs = {"event": event.payload} if event else e.kwargs
-            log.info("[DAG TEST] Trigger completed")
-        session.merge(ti)
-        session.commit()
-    log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
-
-
 def _get_or_create_dagrun(
     *,
     dag: DAG,
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 13848164355..11a65f87055 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -100,6 +100,7 @@ if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.dag_version import DagVersion
     from airflow.models.operator import Operator
+    from airflow.sdk import DAG as SDKDAG, Context
     from airflow.typing_compat import Literal
     from airflow.utils.types import ArgNotSet
 
@@ -1147,8 +1148,8 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="task_failure")
 
-            if execute_callbacks:
-                dag.handle_callback(self, success=False, 
reason="task_failure", session=session)
+            if execute_callbacks and dag.has_on_failure_callback:
+                self.handle_dag_callback(dag=dag, success=False, 
reason="task_failure")
             elif dag.has_on_failure_callback:
                 callback = DagCallbackRequest(
                     filepath=self.dag_model.relative_fileloc,
@@ -1176,8 +1177,8 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.SUCCESS)
             self.notify_dagrun_state_changed(msg="success")
 
-            if execute_callbacks:
-                dag.handle_callback(self, success=True, reason="success", 
session=session)
+            if execute_callbacks and dag.has_on_success_callback:
+                self.handle_dag_callback(dag=dag, success=True, 
reason="success")
             elif dag.has_on_success_callback:
                 callback = DagCallbackRequest(
                     filepath=self.dag_model.relative_fileloc,
@@ -1195,8 +1196,8 @@ class DagRun(Base, LoggingMixin):
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="all_tasks_deadlocked")
 
-            if execute_callbacks:
-                dag.handle_callback(self, success=False, 
reason="all_tasks_deadlocked", session=session)
+            if execute_callbacks and dag.has_on_failure_callback:
+                self.handle_dag_callback(dag=dag, success=False, 
reason="all_tasks_deadlocked")
             elif dag.has_on_failure_callback:
                 callback = DagCallbackRequest(
                     filepath=self.dag_model.relative_fileloc,
@@ -1316,6 +1317,32 @@ class DagRun(Base, LoggingMixin):
         # we can't get all the state changes on SchedulerJob,
         # or LocalTaskJob, so we don't want to "falsely advertise" we notify 
about that
 
+    def handle_dag_callback(self, dag: SDKDAG, success: bool = True, reason: 
str = "success"):
+        """Only needed for `dag.test` where `execute_callbacks=True` is passed 
to `update_state`."""
+        context: Context = {  # type: ignore[assignment]
+            "dag": dag,
+            "run_id": str(self.run_id),
+            "reason": reason,
+        }
+
+        callbacks = dag.on_success_callback if success else 
dag.on_failure_callback
+        if not callbacks:
+            self.log.warning("Callback requested, but dag didn't have any for 
DAG: %s.", dag.dag_id)
+            return
+        callbacks = callbacks if isinstance(callbacks, list) else [callbacks]
+
+        for callback in callbacks:
+            self.log.info(
+                "Executing on_%s dag callback: %s",
+                "success" if success else "failure",
+                callback.__name__ if hasattr(callback, "__name__") else 
repr(callback),
+            )
+            try:
+                callback(context)
+            except Exception:
+                self.log.exception("Callback failed for %s", dag.dag_id)
+                Stats.incr("dag.callback_exceptions", tags={"dag_id": 
dag.dag_id})
+
     def _get_ready_tis(
         self,
         schedulable_tis: list[TI],
diff --git a/airflow-core/tests/unit/cli/commands/test_dag_command.py 
b/airflow-core/tests/unit/cli/commands/test_dag_command.py
index 278248486bd..d1a37fd5392 100644
--- a/airflow-core/tests/unit/cli/commands/test_dag_command.py
+++ b/airflow-core/tests/unit/cli/commands/test_dag_command.py
@@ -38,10 +38,10 @@ from airflow.cli.commands import dag_command
 from airflow.exceptions import AirflowException
 from airflow.models import DagBag, DagModel, DagRun
 from airflow.models.baseoperator import BaseOperator
-from airflow.models.dag import _run_inline_trigger
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.standard.triggers.temporal import DateTimeTrigger, 
TimeDeltaTrigger
 from airflow.sdk import task
+from airflow.sdk.definitions.dag import _run_inline_trigger
 from airflow.triggers.base import TriggerEvent
 from airflow.utils import timezone
 from airflow.utils.session import create_session
@@ -631,7 +631,6 @@ class TestCliDags:
                     run_conf=None,
                     use_executor=False,
                     mark_success_pattern=None,
-                    session=mock.ANY,
                 ),
             ]
         )
@@ -665,7 +664,6 @@ class TestCliDags:
                     logical_date=mock.ANY,
                     run_conf=None,
                     use_executor=False,
-                    session=mock.ANY,
                     mark_success_pattern=None,
                 ),
             ]
@@ -693,7 +691,6 @@ class TestCliDags:
                     logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
                     run_conf={"dag_run_conf_param": "param_value"},
                     use_executor=False,
-                    session=mock.ANY,
                     mark_success_pattern=None,
                 ),
             ]
@@ -722,7 +719,6 @@ class TestCliDags:
                     logical_date=timezone.parse(DEFAULT_DATE.isoformat()),
                     run_conf=None,
                     use_executor=False,
-                    session=mock.ANY,
                     mark_success_pattern=None,
                 ),
             ]
@@ -773,7 +769,9 @@ class TestCliDags:
         assert e.payload == now
 
     def test_dag_test_no_triggerer_running(self, dag_maker):
-        with mock.patch("airflow.models.dag._run_inline_trigger", 
wraps=_run_inline_trigger) as mock_run:
+        with mock.patch(
+            "airflow.sdk.definitions.dag._run_inline_trigger", 
wraps=_run_inline_trigger
+        ) as mock_run:
             with dag_maker() as dag:
 
                 @task
@@ -806,12 +804,16 @@ class TestCliDags:
                 op = MyOp(task_id="abc", tfield=task_two)
                 task_two >> op
             dr = dag.test()
-            assert mock_run.call_args_list[0] == ((trigger,), {})
+
+            trigger_arg = mock_run.call_args_list[0].args[0]
+            assert isinstance(trigger_arg, DateTimeTrigger)
+            assert trigger_arg.moment == trigger.moment
+
             tis = dr.get_task_instances()
             assert next(x for x in tis if x.task_id == "abc").state == 
"success"
 
-    
@mock.patch("airflow.models.taskinstance.TaskInstance._execute_task_with_callbacks")
-    def test_dag_test_with_mark_success(self, 
mock__execute_task_with_callbacks):
+    @mock.patch("airflow.sdk.execution_time.task_runner._execute_task")
+    def test_dag_test_with_mark_success(self, mock__execute_task):
         """
         option `--mark-success-pattern` should mark matching tasks as success 
without executing them.
         """
@@ -828,8 +830,8 @@ class TestCliDags:
         dag_command.dag_test(cli_args)
 
         # only second operator was actually executed, first one was marked as 
success
-        assert len(mock__execute_task_with_callbacks.call_args_list) == 1
-        assert 
mock__execute_task_with_callbacks.call_args_list[0].kwargs["self"].task_id == 
"dummy_operator"
+        assert len(mock__execute_task.call_args_list) == 1
+        assert mock__execute_task.call_args_list[0].kwargs["ti"].task_id == 
"dummy_operator"
 
 
 class TestCliDagsReserialize:
diff --git a/airflow-core/tests/unit/models/test_dag.py 
b/airflow-core/tests/unit/models/test_dag.py
index e9b42e7dbaa..548461ccdec 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -1031,7 +1031,7 @@ class TestDag:
         assert dag_run.state == State.RUNNING
         assert dag_run.run_type != DagRunType.MANUAL
 
-    @patch("airflow.models.dag.Stats")
+    @patch("airflow.models.dagrun.Stats")
     def test_dag_handle_callback_crash(self, mock_stats):
         """
         Tests avoid crashes from calling dag callbacks exceptions
@@ -1062,8 +1062,8 @@ class TestDag:
             )
 
             # should not raise any exception
-            dag.handle_callback(dag_run, success=False)
-            dag.handle_callback(dag_run, success=True)
+        dag_run.handle_dag_callback(dag=dag, success=False)
+        dag_run.handle_dag_callback(dag=dag, success=True)
 
         mock_stats.incr.assert_called_with(
             "dag.callback_exceptions",
@@ -1102,8 +1102,8 @@ class TestDag:
             assert dag_run.get_task_instance(task_removed.task_id).state == 
TaskInstanceState.REMOVED
 
             # should not raise any exception
-            dag.handle_callback(dag_run, success=True)
-            dag.handle_callback(dag_run, success=False)
+            dag_run.handle_dag_callback(dag=dag, success=False)
+            dag_run.handle_dag_callback(dag=dag, success=True)
 
     @pytest.mark.parametrize("catchup,expected_next_dagrun", [(True, 
DEFAULT_DATE), (False, None)])
     def test_next_dagrun_after_fake_scheduled_previous(self, catchup, 
expected_next_dagrun):
@@ -1507,8 +1507,8 @@ class TestDag:
             mock_handle_object_1(f"task {ti.task_id} failed...")
 
         def handle_dag_failure(context):
-            ti = context["task_instance"]
-            mock_handle_object_2(f"dag {ti.dag_id} run failed...")
+            dag_id = context["dag"].dag_id
+            mock_handle_object_2(f"dag {dag_id} run failed...")
 
         dag = DAG(
             dag_id="test_local_testing_conn_file",
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py 
b/airflow-core/tests/unit/models/test_mappedoperator.py
index 74453449d97..b116082111a 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -32,7 +32,6 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import setup, task, task_group, teardown
-from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
@@ -1270,21 +1269,7 @@ class TestMappedSetupTeardown:
         tg1, tg2 = dag.task_group.children.values()
         tg1 >> tg2
 
-        with mock.patch(
-            "airflow.sdk.execution_time.task_runner.SUPERVISOR_COMMS", 
create=True
-        ) as supervisor_comms:
-            # TODO: TaskSDK: this is a bit of a hack that we need to stub this 
at all. `dag.test()` should
-            # really work without this!
-            supervisor_comms.get_message.side_effect = [
-                XComCountResponse(len=3),
-                XComResult(key="return_value", value=1),
-                XComCountResponse(len=3),
-                XComResult(key="return_value", value=2),
-                XComCountResponse(len=3),
-                XComResult(key="return_value", value=3),
-            ]
-            dr = dag.test()
-            assert supervisor_comms.get_message.call_count == 6
+        dr = dag.test()
         states = self.get_states(dr)
         expected = {
             "tg_1.my_pre_setup": "success",
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py 
b/task-sdk/src/airflow/sdk/definitions/dag.py
index 8b6ed3e4b19..972fea56241 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -70,6 +70,8 @@ from airflow.utils.decorators import 
fixup_decorator_warning_stack
 from airflow.utils.trigger_rule import TriggerRule
 
 if TYPE_CHECKING:
+    from re import Pattern
+
     from pendulum.tz.timezone import FixedTimezone, Timezone
 
     from airflow.decorators import TaskDecoratorCollection
@@ -1014,6 +1016,277 @@ class DAG:
                 f"Bad formatted links are: {wrong_links}"
             )
 
+    def test(
+        self,
+        run_after: datetime | None = None,
+        logical_date: datetime | None = None,
+        run_conf: dict[str, Any] | None = None,
+        conn_file_path: str | None = None,
+        variable_file_path: str | None = None,
+        use_executor: bool = False,
+        mark_success_pattern: Pattern | str | None = None,
+    ):
+        """
+        Execute one single DagRun for a given DAG and logical date.
+
+        :param run_after: the datetime before which to Dag cannot run.
+        :param logical_date: logical date for the DAG run
+        :param run_conf: configuration to pass to newly created dagrun
+        :param conn_file_path: file path to a connection file in either yaml 
or json
+        :param variable_file_path: file path to a variable file in either yaml 
or json
+        :param use_executor: if set, uses an executor to test the DAG
+        :param mark_success_pattern: regex of task_ids to mark as success 
instead of running
+        """
+        import re
+        import time
+        from contextlib import ExitStack
+
+        from airflow import settings
+        from airflow.configuration import secrets_backend_list
+        from airflow.models.dag import DAG as SchedulerDAG, 
_get_or_create_dagrun
+        from airflow.models.dagrun import DagRun
+        from airflow.secrets.local_filesystem import LocalFilesystemBackend
+        from airflow.serialization.serialized_objects import SerializedDAG
+        from airflow.utils import timezone
+        from airflow.utils.state import DagRunState, State, TaskInstanceState
+        from airflow.utils.types import DagRunTriggeredByType, DagRunType
+
+        if TYPE_CHECKING:
+            from airflow.models.taskinstance import TaskInstance
+
+        def add_logger_if_needed(ti: TaskInstance):
+            """
+            Add a formatted logger to the task instance.
+
+            This allows all logs to surface to the command line, instead of 
into
+            a task file. Since this is a local test run, it is much better for
+            the user to see logs in the command line, rather than needing to
+            search for a log file.
+
+            :param ti: The task instance that will receive a logger.
+            """
+            format = logging.Formatter("[%(asctime)s] 
{%(filename)s:%(lineno)d} %(levelname)s - %(message)s")
+            handler = logging.StreamHandler(sys.stdout)
+            handler.level = logging.INFO
+            handler.setFormatter(format)
+            # only add log handler once
+            if not any(isinstance(h, logging.StreamHandler) for h in 
ti.log.handlers):
+                log.debug("Adding Streamhandler to taskinstance %s", 
ti.task_id)
+                ti.log.addHandler(handler)
+
+        exit_stack = ExitStack()
+
+        if conn_file_path or variable_file_path:
+            local_secrets = LocalFilesystemBackend(
+                variables_file_path=variable_file_path, 
connections_file_path=conn_file_path
+            )
+            secrets_backend_list.insert(0, local_secrets)
+            exit_stack.callback(lambda: secrets_backend_list.pop(0))
+
+        session = settings.Session()
+
+        with exit_stack:
+            self.validate()
+            log.debug("Clearing existing task instances for logical date %s", 
logical_date)
+            # TODO: Replace with calling client.dag_run.clear in Execution API 
at some point
+            SchedulerDAG.clear_dags(
+                dags=[self],
+                start_date=logical_date,
+                end_date=logical_date,
+                dag_run_state=False,  # type: ignore
+            )
+
+            log.debug("Getting dagrun for dag %s", self.dag_id)
+            logical_date = timezone.coerce_datetime(logical_date)
+            run_after = timezone.coerce_datetime(run_after) or 
timezone.coerce_datetime(timezone.utcnow())
+            data_interval = (
+                
self.timetable.infer_manual_data_interval(run_after=logical_date) if 
logical_date else None
+            )
+            scheduler_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(self))  # type: 
ignore[arg-type]
+
+            dr: DagRun = _get_or_create_dagrun(
+                dag=scheduler_dag,
+                start_date=logical_date or run_after,
+                logical_date=logical_date,
+                data_interval=data_interval,
+                run_after=run_after,
+                run_id=DagRun.generate_run_id(
+                    run_type=DagRunType.MANUAL,
+                    logical_date=logical_date,
+                    run_after=run_after,
+                ),
+                session=session,
+                conf=run_conf,
+                triggered_by=DagRunTriggeredByType.TEST,
+            )
+            # Start a mock span so that one is present and not started 
downstream. We
+            # don't care about otel in dag.test and starting the span during 
dagrun update
+            # is not functioning properly in this context anyway.
+            dr.start_dr_spans_if_needed(tis=[])
+            dr.dag = self  # type: ignore[assignment]
+
+            tasks = self.task_dict
+            log.debug("starting dagrun")
+            # Instead of starting a scheduler, we run the minimal loop 
possible to check
+            # for task readiness and dependency management.
+            # Instead of starting a scheduler, we run the minimal loop 
possible to check
+            # for task readiness and dependency management.
+
+            # ``Dag.test()`` works in two different modes depending on 
``use_executor``:
+            # - if ``use_executor`` is False, runs the task locally with no 
executor using ``_run_task``
+            # - if ``use_executor`` is True, sends the task instances to the 
executor with
+            #   ``BaseExecutor.queue_task_instance``
+            if use_executor:
+                from airflow.executors.base_executor import ExecutorLoader
+
+                executor = ExecutorLoader.get_default_executor()
+                executor.start()
+
+            while dr.state == DagRunState.RUNNING:
+                session.expire_all()
+                schedulable_tis, _ = dr.update_state(session=session)
+                for s in schedulable_tis:
+                    if s.state != TaskInstanceState.UP_FOR_RESCHEDULE:
+                        s.try_number += 1
+                    s.state = TaskInstanceState.SCHEDULED
+                    s.scheduled_dttm = timezone.utcnow()
+                session.commit()
+                # triggerer may mark tasks scheduled so we read from DB
+                all_tis = set(dr.get_task_instances(session=session))
+                scheduled_tis = {x for x in all_tis if x.state == 
TaskInstanceState.SCHEDULED}
+                ids_unrunnable = {x for x in all_tis if x.state not in 
State.finished} - scheduled_tis
+                if not scheduled_tis and ids_unrunnable:
+                    log.warning("No tasks to run. unrunnable tasks: %s", 
ids_unrunnable)
+                    time.sleep(1)
+
+                for ti in scheduled_tis:
+                    ti.task = tasks[ti.task_id]
+
+                    mark_success = (
+                        re.compile(mark_success_pattern).fullmatch(ti.task_id) 
is not None
+                        if mark_success_pattern is not None
+                        else False
+                    )
+
+                    if use_executor:
+                        if executor.has_task(ti):
+                            continue
+
+                        from pathlib import Path
+
+                        from airflow.executors import workloads
+                        from airflow.executors.base_executor import 
ExecutorLoader
+                        from airflow.executors.workloads import BundleInfo
+
+                        workload = workloads.ExecuteTask.make(
+                            ti,
+                            dag_rel_path=Path(self.fileloc),
+                            generator=executor.jwt_generator,
+                            # For the system test/debug purpose, we use the 
default bundle which uses
+                            # local file system. If it turns out to be a 
feature people want, we could
+                            # plumb the Bundle to use as a parameter to 
dag.test
+                            bundle_info=BundleInfo(name="dags-folder"),
+                        )
+                        executor.queue_workload(workload, session=session)
+                        ti.state = TaskInstanceState.QUEUED
+                        session.commit()
+                    else:
+                        # Run the task locally
+                        try:
+                            add_logger_if_needed(ti)
+                            if mark_success:
+                                ti.set_state(State.SUCCESS)
+                                log.info("[DAG TEST] Marking success for %s on 
%s", ti.task, ti.logical_date)
+                            else:
+                                _run_task(ti=ti)
+                        except Exception:
+                            log.exception("Task failed; ti=%s", ti)
+                if use_executor:
+                    executor.heartbeat()
+                    from airflow.jobs.scheduler_job_runner import 
SchedulerDagBag, SchedulerJobRunner
+
+                    SchedulerJobRunner.process_executor_events(
+                        executor=executor, job_id=None, 
scheduler_dag_bag=SchedulerDagBag(), session=session
+                    )
+            if use_executor:
+                executor.end()
+        return dr
+
+
def _run_task(*, ti):
    """
    Run a single task instance, and push result to Xcom for downstream tasks.

    Bypasses a lot of extra steps used in `task.run` to keep our local running as fast as
    possible.  This function is only meant for the `dag.test` function as a helper function.

    If the task defers, its trigger is run inline (no triggerer is needed). The resulting
    event is stored in ``next_method``/``next_kwargs``, and the task is run again so that
    it resumes from where it deferred instead of being marked successful without finishing.

    :param ti: the task instance to run; ``ti.task`` must already be set by the caller.
    """
    from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK
    from airflow.sdk.execution_time.comms import DeferTask
    from airflow.sdk.execution_time.supervisor import run_task_in_process
    from airflow.utils.module_loading import import_string
    from airflow.utils.state import State

    log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index)
    while True:
        try:
            log.info("[DAG TEST] running task %s", ti)

            # The API Server expects the task instance to be in QUEUED state before
            # it is run.
            ti.set_state(State.QUEUED)

            taskrun_result = run_task_in_process(
                ti=TaskInstanceSDK(
                    id=ti.id,
                    task_id=ti.task_id,
                    dag_id=ti.task.dag_id,
                    run_id=ti.run_id,
                    try_number=ti.try_number,
                    map_index=ti.map_index,
                ),
                task=ti.task,
            )

            msg = taskrun_result.msg

            if taskrun_result.ti.state == State.DEFERRED and isinstance(msg, DeferTask):
                log.info("[DAG TEST] running trigger in line")
                trigger = import_string(msg.classpath)(**msg.trigger_kwargs)
                event = _run_inline_trigger(trigger)
                ti.next_method = msg.next_method
                # ``DeferTask`` carries the resume kwargs as ``next_kwargs``; it has no ``kwargs`` field.
                ti.next_kwargs = {"event": event.payload} if event else msg.next_kwargs
                log.info("[DAG TEST] Trigger completed")

                # Persist next_method/next_kwargs and move the TI out of DEFERRED so the task can
                # resume. The in-memory state is still QUEUED from the call above, so go through
                # SCHEDULED (presumably ``set_state`` skips writes when the state is unchanged —
                # confirm); the QUEUED transition at the top of the loop then lets the API server
                # start the task again.
                ti.set_state(State.SCHEDULED)
                continue
            break
        except Exception:
            log.exception("[DAG TEST] Error running task %s", ti)
            if ti.state not in State.finished:
                ti.set_state(State.FAILED)
                break
            raise

    log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index)
+
+
+def _run_inline_trigger(trigger):
+    import asyncio
+
+    async def _run_inline_trigger_main():
+        # We can replace it with `return await anext(trigger.run(), 
default=None)`
+        # when we drop support for Python 3.9
+        try:
+            return await trigger.run().__anext__()
+        except StopAsyncIteration:
+            return None
+
+    return asyncio.run(_run_inline_trigger_main())
+
 
 # Since we define all the attributes of the class with attrs, we can compute 
this statically at parse time
 DAG._DAG__serialized_fields = frozenset(a.name for a in attrs.fields(DAG)) - { 
 # type: ignore[attr-defined]
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index b5cf977488b..0e7b7f54cd1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -27,8 +27,9 @@ import selectors
 import signal
 import sys
 import time
+from collections import deque
 from collections.abc import Generator
-from contextlib import suppress
+from contextlib import contextmanager, suppress
 from datetime import datetime, timezone
 from http import HTTPStatus
 from socket import SO_SNDBUF, SOL_SOCKET, SocketIO, socket, socketpair
@@ -42,6 +43,7 @@ from typing import (
 )
 from uuid import UUID
 
+import aiologic
 import attrs
 import httpx
 import msgspec
@@ -837,6 +839,15 @@ class ActivitySubprocess(WatchedSubprocess):
         # If it hasn't, assume it's failed
         self._exit_code = self._exit_code if self._exit_code is not None else 1
 
+        self.update_task_state_if_needed()
+
+        # Now at the last possible moment, when all logs and comms with the 
subprocess has finished, lets
+        # upload the remote logs
+        self._upload_logs()
+
+        return self._exit_code
+
+    def update_task_state_if_needed(self):
         # If the process has finished non-directly patched state (directly 
means deferred, reschedule, etc.),
         # update the state of the TaskInstance to reflect the final state of 
the process.
         # For states like `deferred`, `up_for_reschedule`, the process will 
exit with 0, but the state will be updated
@@ -849,12 +860,6 @@ class ActivitySubprocess(WatchedSubprocess):
                 rendered_map_index=self._rendered_map_index,
             )
 
-        # Now at the last possible moment, when all logs and comms with the 
subprocess has finished, lets
-        # upload the remote logs
-        self._upload_logs()
-
-        return self._exit_code
-
     def _upload_logs(self):
         """
         Upload all log files found to the remote storage.
@@ -1155,6 +1160,183 @@ class ActivitySubprocess(WatchedSubprocess):
             self.send_msg(resp, **dump_opts)
 
 
def in_process_api_server():
    """Create and return a new in-process Execution API server (``InProcessExecutionAPI``)."""
    from airflow.api_fastapi.execution_api.app import InProcessExecutionAPI

    return InProcessExecutionAPI()
+
+
@attrs.define
class InProcessSupervisorComms:
    """In-process communication handler that uses deques instead of sockets."""

    # The in-process supervisor that requests are forwarded to.
    supervisor: InProcessTestSupervisor
    # Messages sent by the supervisor (via ``send_msg``), consumed in FIFO order.
    messages: deque[BaseModel] = attrs.field(factory=deque)
    lock: aiologic.Lock = attrs.field(factory=aiologic.Lock)

    def get_message(self) -> BaseModel:
        """
        Return the oldest message queued by the supervisor.

        This does not block: the supervisor handles each request synchronously before this is
        called, so a response is expected to be queued already. Raises ``IndexError`` if no
        message is queued.
        """
        return self.messages.popleft()

    def send_request(self, log, msg: BaseModel):
        """Send a request to the supervisor."""
        log.debug("Sending request", msg=msg)

        # Handle the request with SUPERVISOR_COMMS unset, so that code executed on the
        # "supervisor side" does not think it is running inside the task runner.
        with set_supervisor_comms(None):
            self.supervisor._handle_request(msg, log)  # type: ignore[arg-type]
+
+
@attrs.define
class TaskRunResult:
    """Result of running a task via ``InProcessTestSupervisor``."""

    # The runtime task instance that was executed.
    ti: RuntimeTI
    # Final state returned by the task runner's ``run``.
    state: str
    # Last message produced by ``run`` (e.g. a ``DeferTask`` when the task deferred), if any.
    msg: BaseModel | None
    # Exception raised by the task, if any.
    error: BaseException | None
+
+
@attrs.define(kw_only=True)
class InProcessTestSupervisor(ActivitySubprocess):
    """A supervisor that runs tasks in-process for easier testing."""

    # Assigned in ``start`` once the supervisor exists (the comms object needs a reference back to it).
    comms: InProcessSupervisorComms = attrs.field(init=False)
    stdin = attrs.field(init=False)

    @classmethod
    def start(  # type: ignore[override]
        cls,
        *,
        what: TaskInstance,
        task,
        logger: FilteringBoundLogger | None = None,
        **kwargs,
    ) -> TaskRunResult:
        """
        Run a task in-process without spawning a new child process.

        This bypasses the standard `ActivitySubprocess.start()` behavior, which expects
        to launch a subprocess and communicate via stdin/stdout. Instead, it constructs
        the `RuntimeTaskInstance` directly — useful in contexts like `dag.test()` where the
        DAG is already parsed in memory.

        Supervisor state and communications are simulated in-memory via `InProcessSupervisorComms`.

        :param what: the task instance to run.
        :param task: the already-parsed task object to execute.
        :param logger: logger used as the supervisor's ``process_log``; defaults to a ``task`` structlog logger.
        :param kwargs: extra keyword arguments forwarded to the supervisor constructor.
        :return: the runtime TI together with the final state, last message and error from ``run``.
        """
        # Create supervisor instance
        supervisor = cls(
            id=what.id,
            pid=os.getpid(),  # Use current process
            process=psutil.Process(),  # Current process
            requests_fd=-1,  # Not used in in-process mode
            process_log=logger or structlog.get_logger(logger_name="task").bind(),
            client=cls._api_client(task.dag),
            **kwargs,
        )

        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, finalize, run

        supervisor.comms = InProcessSupervisorComms(supervisor=supervisor)
        with set_supervisor_comms(supervisor.comms):
            supervisor.ti = what  # type: ignore[assignment]

            # We avoid calling `task_runner.startup()` because we are already inside a
            # parsed DAG file (e.g. via dag.test()).
            # In normal execution, `startup()` parses the DAG based on info in a `StartupDetails` message.
            # By directly constructing the `RuntimeTaskInstance`,
            #   we skip re-parsing (`task_runner.parse()`) and avoid needing to set DAG Bundle config
            #   and run the task in-process.
            start_date = datetime.now(tz=timezone.utc)
            ti_context = supervisor.client.task_instances.start(supervisor.id, supervisor.pid, start_date)

            ti = RuntimeTaskInstance.model_construct(
                **what.model_dump(exclude_unset=True),
                task=task,
                _ti_context_from_server=ti_context,
                max_tries=ti_context.max_tries,
                start_date=start_date,
                state=TaskInstanceState.RUNNING,
            )
            context = ti.get_template_context()
            log = structlog.get_logger(logger_name="task")

            state, msg, error = run(ti, context, log)
            finalize(ti, state, context, log, error)

            # In the normal subprocess model, the task runner calls this before exiting.
            # Since we're running in-process, we manually notify the API server that
            # the task has finished—unless the terminal state was already sent explicitly.
            supervisor.update_task_state_if_needed()

        return TaskRunResult(ti=ti, state=state, msg=msg, error=error)

    @staticmethod
    def _api_client(dag=None):
        """
        Build an API client whose transport talks to an in-process Execution API server.

        :param dag: if given, registered (as a SerializedDAG) in the server's DagBag dependency.
        """
        from airflow.models.dagbag import DagBag
        from airflow.sdk.api.client import Client

        api = in_process_api_server()
        if dag is not None:
            from airflow.api_fastapi.common.deps import _get_dag_bag
            from airflow.serialization.serialized_objects import SerializedDAG

            # This is needed since the Execution API server uses the DagBag in its "state".
            # This `app.state.dag_bag` is used to get some DAG properties like `fail_fast`.
            dag_bag = DagBag(include_examples=False, collect_dags=False, load_op_links=False)

            # Mimic the behavior of the DagBag in the API server by converting the DAG to a SerializedDAG
            dag_bag.dags[dag.dag_id] = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
            api.app.dependency_overrides[_get_dag_bag] = lambda: dag_bag

        client = Client(base_url=None, token="", dry_run=True, transport=api.transport)
        # Mypy is wrong -- the setter accepts a string on the property setter! `URLType = URL | str`
        client.base_url = "http://in-process.invalid./"  # type: ignore[assignment]
        return client

    def send_msg(self, msg: BaseModel, **dump_opts):
        """Queue ``msg`` for the in-process task runner instead of writing it to a socket."""
        self.comms.messages.append(msg)

    @property
    def final_state(self):
        """Return the terminal state recorded for the task (there is no process exit code in-process)."""
        # Since we're running in-process, we don't have a final state until the task has finished.
        # We also don't have a process exit code to determine success/failure.
        return self._terminal_state
+
+
@contextmanager
def set_supervisor_comms(temp_comms):
    """
    Temporarily override `SUPERVISOR_COMMS` in the `task_runner` module.

    This is used to simulate task-runner ↔ supervisor communication in-process,
    by injecting a test Comms implementation (e.g. `InProcessSupervisorComms`)
    in place of the real inter-process communication layer.

    Some parts of the code (e.g. models.Variable.get) check for the presence
    of `task_runner.SUPERVISOR_COMMS` to determine if the code is running in a Task SDK execution context.
    This override ensures those code paths behave correctly during in-process tests.

    On exit the previous state is restored exactly. If the attribute existed, even with the
    value ``None``, it gets that value back. If it did not exist, it is removed again.
    """
    from airflow.sdk.execution_time import task_runner

    # A private sentinel distinguishes "attribute absent" from "attribute set to None".
    missing = object()
    old = getattr(task_runner, "SUPERVISOR_COMMS", missing)
    task_runner.SUPERVISOR_COMMS = temp_comms
    try:
        yield
    finally:
        if old is missing:
            # Tolerate the attribute having already been removed in the meantime.
            with suppress(AttributeError):
                delattr(task_runner, "SUPERVISOR_COMMS")
        else:
            task_runner.SUPERVISOR_COMMS = old
+
+
def run_task_in_process(ti: TaskInstance, task) -> TaskRunResult:
    """
    Execute ``task`` for ``ti`` inside the current process (used by ``dag.test()``).

    Thin convenience wrapper around ``InProcessTestSupervisor.start``.
    """
    result = InProcessTestSupervisor.start(what=ti, task=task)
    return result
+
+
 # Sockets, even the `.makefile()` function don't correctly do line buffering 
on reading. If a chunk is read
 # and it doesn't contain a new line character, `.readline()` will just return 
the chunk as is.
 #
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 9092ee86f0b..66a3d02cd8c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -835,7 +835,7 @@ def run(
                 return state, msg, error
 
             try:
-                result = _execute_task(context, ti, log)
+                result = _execute_task(context=context, ti=ti, log=log)
             except Exception:
                 import jinja2
 
@@ -886,7 +886,7 @@ def run(
         )
         state = TaskInstanceState.FAILED
         error = e
-    except (AirflowTaskTimeout, AirflowException) as e:
+    except (AirflowTaskTimeout, AirflowException, AirflowRuntimeError) as e:
         # We should allow retries if the task has defined it.
         log.exception("Task failed with exception")
         msg, state = _handle_current_task_failed(ti)

Reply via email to