This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 4a16537364e69c69fbd394294e07b05cb226dafc
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Sep 17 00:19:46 2025 +0800

    Remove no-longer-needed execution interface hacks (#55681)
    
    (cherry picked from commit 8e355843228533f9743590e957efb97902a8d3a8)
---
 airflow-core/src/airflow/api/common/mark_tasks.py  | 14 ++-------
 .../src/airflow/cli/commands/task_command.py       | 35 ++++++++++++----------
 airflow-core/src/airflow/models/taskinstance.py    | 13 +++++---
 .../airflow/serialization/serialized_objects.py    |  7 ++---
 airflow-core/src/airflow/utils/cli.py              |  4 +--
 airflow-core/tests/unit/models/test_cleartasks.py  |  4 +--
 .../tests/unit/models/test_taskinstance.py         |  8 ++---
 task-sdk/src/airflow/sdk/definitions/dag.py        | 24 ++++++++++-----
 task-sdk/src/airflow/sdk/types.py                  |  2 ++
 9 files changed, 59 insertions(+), 52 deletions(-)

diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py
index d424bab603a..5c0ed4b9f5f 100644
--- a/airflow-core/src/airflow/api/common/mark_tasks.py
+++ b/airflow-core/src/airflow/api/common/mark_tasks.py
@@ -20,7 +20,7 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Iterable
-from typing import TYPE_CHECKING, TypeAlias, cast
+from typing import TYPE_CHECKING, TypeAlias
 
 from sqlalchemy import and_, or_, select
 from sqlalchemy.orm import lazyload
@@ -228,9 +228,7 @@ def set_dag_run_state_to_success(
     if not run_id:
         raise ValueError(f"Invalid dag_run_id: {run_id}")
 
-    # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types
-    # instead, but currently it inherits SDK's DAG.
-    tasks = cast("list[Operator]", dag.tasks)
+    tasks = dag.tasks
 
     # Mark all task instances of the dag run to success - except for 
unfinished teardown as they need to complete work.
     teardown_tasks = [task for task in tasks if task.is_teardown]
@@ -312,13 +310,7 @@ def set_dag_run_state_to_failed(
         task.dag = dag
         return task
 
-    # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types
-    # instead, but currently it inherits SDK's DAG.
-    running_tasks = [
-        _set_runing_task(task)
-        for task in cast("list[Operator]", dag.tasks)
-        if task.task_id in task_ids_of_running_tis
-    ]
+    running_tasks = [_set_runing_task(task) for task in dag.tasks if 
task.task_id in task_ids_of_running_tis]
 
     # Mark non-finished tasks as SKIPPED.
     pending_tis: list[TaskInstance] = session.scalars(
diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py
index b19ea6161a2..9b4cd4114f4 100644
--- a/airflow-core/src/airflow/cli/commands/task_command.py
+++ b/airflow-core/src/airflow/cli/commands/task_command.py
@@ -46,6 +46,7 @@ from airflow.utils.cli import (
     get_bagged_dag,
     get_dag_by_file_location,
     get_dags,
+    get_db_dag,
     suppress_logs_and_warning,
 )
 from airflow.utils.helpers import ask_yesno
@@ -82,7 +83,7 @@ def _generate_temporary_run_id() -> str:
 
 def _get_dag_run(
     *,
-    dag: DAG | SerializedDAG,
+    dag: SerializedDAG,
     create_if_necessary: CreateIfNecessary,
     logical_date_or_run_id: str | None = None,
     session: Session | None = None,
@@ -144,9 +145,8 @@ def _get_dag_run(
         )
         return dag_run, True
     if create_if_necessary == "db":
-        scheduler_dag = 
SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag))  # type: 
ignore[arg-type]
         dag_run = get_or_create_dagrun(
-            dag=scheduler_dag,
+            dag=dag,
             run_id=_generate_temporary_run_id(),
             logical_date=dag_run_logical_date,
             data_interval=data_interval,
@@ -246,10 +246,7 @@ def task_failed_deps(args) -> None:
     Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks
     to have succeeded, but found 1 non-success(es).
     """
-    dag = get_bagged_dag(args.bundle_name, args.dag_id)
-    # TODO (GH-52141): get_task in scheduler needs to return scheduler types
-    # instead, but currently it inherits SDK's DAG.
-    task = cast("Operator", dag.get_task(task_id=args.task_id))
+    task = get_db_dag(args.bundle_name, 
args.dag_id).get_task(task_id=args.task_id)
     ti, _ = _get_ti(task, args.map_index, 
logical_date_or_run_id=args.logical_date_or_run_id)
     dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
     failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
@@ -387,29 +384,35 @@ def task_test(args, dag: DAG | None = None) -> None:
         env_vars.update(args.env_vars)
         os.environ.update(env_vars)
 
-    dag = dag or get_bagged_dag(args.bundle_name, args.dag_id)
+    if dag:
+        sdk_dag = dag
+        scheduler_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+    else:
+        sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id)
+        scheduler_dag = get_db_dag(args.bundle_name, args.dag_id)
 
-    # TODO (GH-52141): get_task in scheduler needs to return scheduler types
-    # instead, but currently it inherits SDK's DAG.
-    task = cast("Operator", dag.get_task(task_id=args.task_id))
+    sdk_task = sdk_dag.get_task(args.task_id)
 
     # Add CLI provided task_params to task.params
     if args.task_params:
         passed_in_params = json.loads(args.task_params)
-        task.params.update(passed_in_params)
+        sdk_task.params.update(passed_in_params)
 
-    if task.params and isinstance(task.params, ParamsDict):
-        task.params.validate()
+    if sdk_task.params and isinstance(sdk_task.params, ParamsDict):
+        sdk_task.params.validate()
 
     ti, dr_created = _get_ti(
-        task, args.map_index, 
logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db"
+        scheduler_dag.get_task(args.task_id),
+        args.map_index,
+        logical_date_or_run_id=args.logical_date_or_run_id,
+        create_if_necessary="db",
     )
     try:
         # TODO: move bulk of this logic into the SDK: 
http://github.com/apache/airflow/issues/54658
         from airflow.sdk._shared.secrets_masker import RedactedIO
 
         with redirect_stdout(RedactedIO()):
-            _run_task(ti=ti, task=task, run_triggerer=True)
+            _run_task(ti=ti, task=sdk_task, run_triggerer=True)
         if ti.state == State.FAILED and args.post_mortem:
             debugger = _guess_debugger()
             debugger.set_trace()
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 72a5ab0b24b..5f3eb91f865 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -1075,8 +1075,6 @@ class TaskInstance(Base, LoggingMixin):
 
         ti: TaskInstance = task_instance
         task = task_instance.task
-        if TYPE_CHECKING:
-            assert isinstance(task, Operator)  # TODO (GH-52141): This 
shouldn't be needed.
         ti.refresh_from_task(task, pool_override=pool)
         ti.test_mode = test_mode
         ti.refresh_from_db(session=session, lock_for_update=True)
@@ -1276,9 +1274,16 @@ class TaskInstance(Base, LoggingMixin):
             log.info("[DAG TEST] Marking success for %s ", self.task_id)
             return None
 
-        taskrun_result = _run_task(ti=self, task=self.task)
-        if taskrun_result is not None and taskrun_result.error:
+        # TODO (TaskSDK): This is the old ti execution path. The only usage is
+        # in TI.run(...), someone needs to analyse if it's still actually used
+        # somewhere and fix it, likely by rewriting TI.run(...) to use the same
+        # mechanism as Operator.test().
+        taskrun_result = _run_task(ti=self, task=self.task)  # type: 
ignore[arg-type]
+        if taskrun_result is None:
+            return None
+        if taskrun_result.error:
             raise taskrun_result.error
+        self.task = taskrun_result.ti.task  # type: ignore[assignment]
         return None
 
     @staticmethod
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py
index 37c09b4a925..caa40ce93dd 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -1148,10 +1148,6 @@ class DependencyDetector:
         from airflow.providers.standard.operators.trigger_dagrun import 
TriggerDagRunOperator
         from airflow.providers.standard.sensors.external_task import 
ExternalTaskSensor
 
-        # TODO (GH-52141): Separate MappedOperator implementation to get rid 
of this.
-        if TYPE_CHECKING:
-            assert isinstance(task.operator_class, type)
-
         deps = []
         if isinstance(task, TriggerDagRunOperator):
             deps.append(
@@ -1409,7 +1405,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization):
         link = self.operator_extra_link_dict.get(name) or 
self.global_operator_extra_link_dict.get(name)
         if not link:
             return None
-        return link.get_link(self, ti_key=ti.key)  # type: ignore[arg-type] # 
TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives 
SerializedBaseOperator
+        # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but 
receives SerializedBaseOperator.
+        return link.get_link(self, ti_key=ti.key)  # type: ignore[arg-type]
 
     @property
     def operator_name(self) -> str:
diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py
index 8fef958ac0f..b6423c5af3a 100644
--- a/airflow-core/src/airflow/utils/cli.py
+++ b/airflow-core/src/airflow/utils/cli.py
@@ -300,7 +300,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N
     )
 
 
-def _get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | 
None = None) -> SerializedDAG:
+def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | 
None = None) -> SerializedDAG:
     """
     Return DAG of a given dag_id.
 
@@ -321,7 +321,7 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr
 
     if not use_regex:
         if from_db:
-            return [_get_db_dag(bundle_names=bundle_names, dag_id=dag_id)]
+            return [get_db_dag(bundle_names=bundle_names, dag_id=dag_id)]
         return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)]
 
     def _find_dag(bundle):
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index c44cdf5635e..9a1f37c89ca 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -633,11 +633,11 @@ class TestClearTasks:
             assert ti.max_tries == 1
 
         # test dry_run
-        for i in range(num_of_dags):
+        for i, dag in enumerate(dags):
             ti = _get_ti(tis[i])
             ti.try_number += 1
             session.commit()
-            ti.refresh_from_task(tis[i].task)
+            ti.refresh_from_task(dag.get_task(ti.task_id))
             ti.run(session=session)
             assert ti.state == State.SUCCESS
             assert ti.try_number == 2
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 0df775ca41e..315d3c4cc7a 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -488,12 +488,12 @@ class TestTaskInstance:
             )
 
         def run_with_error(ti):
+            orig_task, ti.task = ti.task, task
             with contextlib.suppress(AirflowException):
                 ti.run()
+            ti.task = orig_task
 
         ti = 
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
-        ti.task = task
-
         with create_session() as session:
             session.get(TaskInstance, ti.id).try_number += 1
 
@@ -539,13 +539,13 @@ class TestTaskInstance:
             )
 
         def run_with_error(ti):
+            orig_task, ti.task = ti.task, task
             with contextlib.suppress(AirflowException):
                 ti.run()
+            ti.task = orig_task
 
         ti = 
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
-        ti.task = task
         assert ti.try_number == 0
-
         session.get(TaskInstance, ti.id).try_number += 1
         session.commit()
 
diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py
index 7448b321389..6523dead760 100644
--- a/task-sdk/src/airflow/sdk/definitions/dag.py
+++ b/task-sdk/src/airflow/sdk/definitions/dag.py
@@ -66,10 +66,12 @@ if TYPE_CHECKING:
 
     from pendulum.tz.timezone import FixedTimezone, Timezone
 
+    from airflow.models.taskinstance import TaskInstance as 
SchedulerTaskInstance
     from airflow.sdk.definitions.decorators import TaskDecoratorCollection
     from airflow.sdk.definitions.edges import EdgeInfoType
     from airflow.sdk.definitions.mappedoperator import MappedOperator
     from airflow.sdk.definitions.taskgroup import TaskGroup
+    from airflow.sdk.execution_time.supervisor import TaskRunResult
     from airflow.typing_compat import Self
 
     Operator: TypeAlias = BaseOperator | MappedOperator
@@ -1304,7 +1306,12 @@ class DAG:
         return dr
 
 
-def _run_task(*, ti, task, run_triggerer=False):
+def _run_task(
+    *,
+    ti: SchedulerTaskInstance,
+    task: Operator,
+    run_triggerer: bool = False,
+) -> TaskRunResult | None:
     """
     Run a single task instance, and push result to Xcom for downstream tasks.
 
@@ -1314,6 +1321,7 @@ def _run_task(*, ti, task, run_triggerer=False):
     from airflow.sdk.module_loading import import_string
     from airflow.utils.state import State
 
+    taskrun_result: TaskRunResult | None
     log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
     while True:
         try:
@@ -1322,6 +1330,7 @@ def _run_task(*, ti, task, run_triggerer=False):
             from airflow.sdk.api.datamodels._generated import TaskInstance as 
TaskInstanceSDK
             from airflow.sdk.execution_time.comms import DeferTask
             from airflow.sdk.execution_time.supervisor import 
run_task_in_process
+            from airflow.serialization.serialized_objects import 
create_scheduler_operator
 
             # The API Server expects the task instance to be in QUEUED state 
before
             # it is run.
@@ -1336,14 +1345,10 @@ def _run_task(*, ti, task, run_triggerer=False):
                 dag_version_id=ti.dag_version_id,
             )
 
-            taskrun_result = run_task_in_process(
-                ti=task_sdk_ti,
-                task=task,
-            )
-
+            taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task)
             msg = taskrun_result.msg
             ti.set_state(taskrun_result.ti.state)
-            ti.task = taskrun_result.ti.task
+            ti.task = create_scheduler_operator(taskrun_result.ti.task)
 
             if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and 
run_triggerer:
                 from airflow.utils.session import create_session
@@ -1363,16 +1368,19 @@ def _run_task(*, ti, task, run_triggerer=False):
                 with create_session() as session:
                     ti.state = State.SCHEDULED
                     session.add(ti)
+                continue
 
-            return taskrun_result
+            break
         except Exception:
             log.exception("[DAG TEST] Error running task %s", ti)
             if ti.state not in State.finished:
                 ti.set_state(State.FAILED)
+                taskrun_result = None
                 break
             raise
 
     log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, 
ti.map_index)
+    return taskrun_result
 
 
 def _run_inline_trigger(trigger, task_sdk_ti):
diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py
index cfbcafe4201..c7084629085 100644
--- a/task-sdk/src/airflow/sdk/types.py
+++ b/task-sdk/src/airflow/sdk/types.py
@@ -30,6 +30,7 @@ if TYPE_CHECKING:
     from pydantic import AwareDatetime
 
     from airflow.sdk._shared.logging.types import Logger as Logger
+    from airflow.sdk.api.datamodels._generated import TaskInstanceState
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.definitions.asset import Asset, AssetAlias, 
AssetAliasEvent, AssetRef, BaseAssetUniqueKey
     from airflow.sdk.definitions.context import Context
@@ -68,6 +69,7 @@ class RuntimeTaskInstanceProtocol(Protocol):
     hostname: str | None = None
     start_date: AwareDatetime
     end_date: AwareDatetime | None = None
+    state: TaskInstanceState | None = None
 
     def xcom_pull(
         self,

Reply via email to