This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 4a16537364e69c69fbd394294e07b05cb226dafc Author: Tzu-ping Chung <[email protected]> AuthorDate: Wed Sep 17 00:19:46 2025 +0800 Remove no-longer-needed execution interface hacks (#55681) (cherry picked from commit 8e355843228533f9743590e957efb97902a8d3a8) --- airflow-core/src/airflow/api/common/mark_tasks.py | 14 ++------- .../src/airflow/cli/commands/task_command.py | 35 ++++++++++++---------- airflow-core/src/airflow/models/taskinstance.py | 13 +++++--- .../airflow/serialization/serialized_objects.py | 7 ++--- airflow-core/src/airflow/utils/cli.py | 4 +-- airflow-core/tests/unit/models/test_cleartasks.py | 4 +-- .../tests/unit/models/test_taskinstance.py | 8 ++--- task-sdk/src/airflow/sdk/definitions/dag.py | 24 ++++++++++----- task-sdk/src/airflow/sdk/types.py | 2 ++ 9 files changed, 59 insertions(+), 52 deletions(-) diff --git a/airflow-core/src/airflow/api/common/mark_tasks.py b/airflow-core/src/airflow/api/common/mark_tasks.py index d424bab603a..5c0ed4b9f5f 100644 --- a/airflow-core/src/airflow/api/common/mark_tasks.py +++ b/airflow-core/src/airflow/api/common/mark_tasks.py @@ -20,7 +20,7 @@ from __future__ import annotations from collections.abc import Collection, Iterable -from typing import TYPE_CHECKING, TypeAlias, cast +from typing import TYPE_CHECKING, TypeAlias from sqlalchemy import and_, or_, select from sqlalchemy.orm import lazyload @@ -228,9 +228,7 @@ def set_dag_run_state_to_success( if not run_id: raise ValueError(f"Invalid dag_run_id: {run_id}") - # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. - tasks = cast("list[Operator]", dag.tasks) + tasks = dag.tasks # Mark all task instances of the dag run to success - except for unfinished teardown as they need to complete work. teardown_tasks = [task for task in tasks if task.is_teardown] @@ -312,13 +310,7 @@ def set_dag_run_state_to_failed( task.dag = dag return task - # TODO (GH-52141): 'tasks' in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. - running_tasks = [ - _set_runing_task(task) - for task in cast("list[Operator]", dag.tasks) - if task.task_id in task_ids_of_running_tis - ] + running_tasks = [_set_runing_task(task) for task in dag.tasks if task.task_id in task_ids_of_running_tis] # Mark non-finished tasks as SKIPPED. pending_tis: list[TaskInstance] = session.scalars( diff --git a/airflow-core/src/airflow/cli/commands/task_command.py b/airflow-core/src/airflow/cli/commands/task_command.py index b19ea6161a2..9b4cd4114f4 100644 --- a/airflow-core/src/airflow/cli/commands/task_command.py +++ b/airflow-core/src/airflow/cli/commands/task_command.py @@ -46,6 +46,7 @@ from airflow.utils.cli import ( get_bagged_dag, get_dag_by_file_location, get_dags, + get_db_dag, suppress_logs_and_warning, ) from airflow.utils.helpers import ask_yesno @@ -82,7 +83,7 @@ def _generate_temporary_run_id() -> str: def _get_dag_run( *, - dag: DAG | SerializedDAG, + dag: SerializedDAG, create_if_necessary: CreateIfNecessary, logical_date_or_run_id: str | None = None, session: Session | None = None, @@ -144,9 +145,8 @@ def _get_dag_run( ) return dag_run, True if create_if_necessary == "db": - scheduler_dag = SerializedDAG.deserialize_dag(SerializedDAG.serialize_dag(dag)) # type: ignore[arg-type] dag_run = get_or_create_dagrun( - dag=scheduler_dag, + dag=dag, run_id=_generate_temporary_run_id(), logical_date=dag_run_logical_date, data_interval=data_interval, @@ -246,10 +246,7 @@ def task_failed_deps(args) -> None: Trigger Rule: Task's trigger rule 'all_success' requires all upstream tasks to have succeeded, but found 1 non-success(es). """ - dag = get_bagged_dag(args.bundle_name, args.dag_id) - # TODO (GH-52141): get_task in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. - task = cast("Operator", dag.get_task(task_id=args.task_id)) + task = get_db_dag(args.bundle_name, args.dag_id).get_task(task_id=args.task_id) ti, _ = _get_ti(task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id) dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS) failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) @@ -387,29 +384,35 @@ def task_test(args, dag: DAG | None = None) -> None: env_vars.update(args.env_vars) os.environ.update(env_vars) - dag = dag or get_bagged_dag(args.bundle_name, args.dag_id) + if dag: + sdk_dag = dag + scheduler_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag)) + else: + sdk_dag = get_bagged_dag(args.bundle_name, args.dag_id) + scheduler_dag = get_db_dag(args.bundle_name, args.dag_id) - # TODO (GH-52141): get_task in scheduler needs to return scheduler types - # instead, but currently it inherits SDK's DAG. - task = cast("Operator", dag.get_task(task_id=args.task_id)) + sdk_task = sdk_dag.get_task(args.task_id) # Add CLI provided task_params to task.params if args.task_params: passed_in_params = json.loads(args.task_params) - task.params.update(passed_in_params) + sdk_task.params.update(passed_in_params) - if task.params and isinstance(task.params, ParamsDict): - task.params.validate() + if sdk_task.params and isinstance(sdk_task.params, ParamsDict): + sdk_task.params.validate() ti, dr_created = _get_ti( - task, args.map_index, logical_date_or_run_id=args.logical_date_or_run_id, create_if_necessary="db" + scheduler_dag.get_task(args.task_id), + args.map_index, + logical_date_or_run_id=args.logical_date_or_run_id, + create_if_necessary="db", ) try: # TODO: move bulk of this logic into the SDK: http://github.com/apache/airflow/issues/54658 from airflow.sdk._shared.secrets_masker import RedactedIO with redirect_stdout(RedactedIO()): - _run_task(ti=ti, task=task, run_triggerer=True) + _run_task(ti=ti, task=sdk_task, run_triggerer=True) if ti.state == State.FAILED and args.post_mortem: debugger = _guess_debugger() debugger.set_trace() diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 72a5ab0b24b..5f3eb91f865 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -1075,8 +1075,6 @@ class TaskInstance(Base, LoggingMixin): ti: TaskInstance = task_instance task = task_instance.task - if TYPE_CHECKING: - assert isinstance(task, Operator) # TODO (GH-52141): This shouldn't be needed. ti.refresh_from_task(task, pool_override=pool) ti.test_mode = test_mode ti.refresh_from_db(session=session, lock_for_update=True) @@ -1276,9 +1274,16 @@ class TaskInstance(Base, LoggingMixin): log.info("[DAG TEST] Marking success for %s ", self.task_id) return None - taskrun_result = _run_task(ti=self, task=self.task) - if taskrun_result is not None and taskrun_result.error: + # TODO (TaskSDK): This is the old ti execution path. The only usage is + # in TI.run(...), someone needs to analyse if it's still actually used + # somewhere and fix it, likely by rewriting TI.run(...) to use the same + # mechanism as Operator.test(). + taskrun_result = _run_task(ti=self, task=self.task) # type: ignore[arg-type] + if taskrun_result is None: + return None + if taskrun_result.error: raise taskrun_result.error + self.task = taskrun_result.ti.task # type: ignore[assignment] return None @staticmethod diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py b/airflow-core/src/airflow/serialization/serialized_objects.py index 37c09b4a925..caa40ce93dd 100644 --- a/airflow-core/src/airflow/serialization/serialized_objects.py +++ b/airflow-core/src/airflow/serialization/serialized_objects.py @@ -1148,10 +1148,6 @@ class DependencyDetector: from airflow.providers.standard.operators.trigger_dagrun import TriggerDagRunOperator from airflow.providers.standard.sensors.external_task import ExternalTaskSensor - # TODO (GH-52141): Separate MappedOperator implementation to get rid of this. - if TYPE_CHECKING: - assert isinstance(task.operator_class, type) - deps = [] if isinstance(task, TriggerDagRunOperator): deps.append( @@ -1409,7 +1405,8 @@ class SerializedBaseOperator(DAGNode, BaseSerialization): link = self.operator_extra_link_dict.get(name) or self.global_operator_extra_link_dict.get(name) if not link: return None - return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator + # TODO: GH-52141 - BaseOperatorLink.get_link expects BaseOperator but receives SerializedBaseOperator. + return link.get_link(self, ti_key=ti.key) # type: ignore[arg-type] @property def operator_name(self) -> str: diff --git a/airflow-core/src/airflow/utils/cli.py b/airflow-core/src/airflow/utils/cli.py index 8fef958ac0f..b6423c5af3a 100644 --- a/airflow-core/src/airflow/utils/cli.py +++ b/airflow-core/src/airflow/utils/cli.py @@ -300,7 +300,7 @@ def get_bagged_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | N ) -def _get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> SerializedDAG: +def get_db_dag(bundle_names: list | None, dag_id: str, dagfile_path: str | None = None) -> SerializedDAG: """ Return DAG of a given dag_id. @@ -321,7 +321,7 @@ def get_dags(bundle_names: list | None, dag_id: str, use_regex: bool = False, fr if not use_regex: if from_db: - return [_get_db_dag(bundle_names=bundle_names, dag_id=dag_id)] + return [get_db_dag(bundle_names=bundle_names, dag_id=dag_id)] return [get_bagged_dag(bundle_names=bundle_names, dag_id=dag_id)] def _find_dag(bundle): diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py index c44cdf5635e..9a1f37c89ca 100644 --- a/airflow-core/tests/unit/models/test_cleartasks.py +++ b/airflow-core/tests/unit/models/test_cleartasks.py @@ -633,11 +633,11 @@ class TestClearTasks: assert ti.max_tries == 1 # test dry_run - for i in range(num_of_dags): + for i, dag in enumerate(dags): ti = _get_ti(tis[i]) ti.try_number += 1 session.commit() - ti.refresh_from_task(tis[i].task) + ti.refresh_from_task(dag.get_task(ti.task_id)) ti.run(session=session) assert ti.state == State.SUCCESS assert ti.try_number == 2 diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 0df775ca41e..315d3c4cc7a 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -488,12 +488,12 @@ class TestTaskInstance: ) def run_with_error(ti): + orig_task, ti.task = ti.task, task with contextlib.suppress(AirflowException): ti.run() + ti.task = orig_task ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0] - ti.task = task - with create_session() as session: session.get(TaskInstance, ti.id).try_number += 1 @@ -539,13 +539,13 @@ class TestTaskInstance: ) def run_with_error(ti): + orig_task, ti.task = ti.task, task with contextlib.suppress(AirflowException): ti.run() + ti.task = orig_task ti = dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0] - ti.task = task assert ti.try_number == 0 - session.get(TaskInstance, ti.id).try_number += 1 session.commit() diff --git a/task-sdk/src/airflow/sdk/definitions/dag.py b/task-sdk/src/airflow/sdk/definitions/dag.py index 7448b321389..6523dead760 100644 --- a/task-sdk/src/airflow/sdk/definitions/dag.py +++ b/task-sdk/src/airflow/sdk/definitions/dag.py @@ -66,10 +66,12 @@ if TYPE_CHECKING: from pendulum.tz.timezone import FixedTimezone, Timezone + from airflow.models.taskinstance import TaskInstance as SchedulerTaskInstance from airflow.sdk.definitions.decorators import TaskDecoratorCollection from airflow.sdk.definitions.edges import EdgeInfoType from airflow.sdk.definitions.mappedoperator import MappedOperator from airflow.sdk.definitions.taskgroup import TaskGroup + from airflow.sdk.execution_time.supervisor import TaskRunResult from airflow.typing_compat import Self Operator: TypeAlias = BaseOperator | MappedOperator @@ -1304,7 +1306,12 @@ class DAG: return dr -def _run_task(*, ti, task, run_triggerer=False): +def _run_task( + *, + ti: SchedulerTaskInstance, + task: Operator, + run_triggerer: bool = False, +) -> TaskRunResult | None: """ Run a single task instance, and push result to Xcom for downstream tasks. @@ -1314,6 +1321,7 @@ def _run_task(*, ti, task, run_triggerer=False): from airflow.sdk.module_loading import import_string from airflow.utils.state import State + taskrun_result: TaskRunResult | None log.info("[DAG TEST] starting task_id=%s map_index=%s", ti.task_id, ti.map_index) while True: try: @@ -1322,6 +1330,7 @@ def _run_task(*, ti, task, run_triggerer=False): from airflow.sdk.api.datamodels._generated import TaskInstance as TaskInstanceSDK from airflow.sdk.execution_time.comms import DeferTask from airflow.sdk.execution_time.supervisor import run_task_in_process + from airflow.serialization.serialized_objects import create_scheduler_operator # The API Server expects the task instance to be in QUEUED state before # it is run. @@ -1336,14 +1345,10 @@ def _run_task(*, ti, task, run_triggerer=False): dag_version_id=ti.dag_version_id, ) - taskrun_result = run_task_in_process( - ti=task_sdk_ti, - task=task, - ) - + taskrun_result = run_task_in_process(ti=task_sdk_ti, task=task) msg = taskrun_result.msg ti.set_state(taskrun_result.ti.state) - ti.task = taskrun_result.ti.task + ti.task = create_scheduler_operator(taskrun_result.ti.task) if ti.state == State.DEFERRED and isinstance(msg, DeferTask) and run_triggerer: from airflow.utils.session import create_session @@ -1363,16 +1368,19 @@ def _run_task(*, ti, task, run_triggerer=False): with create_session() as session: ti.state = State.SCHEDULED session.add(ti) + continue - return taskrun_result + break except Exception: log.exception("[DAG TEST] Error running task %s", ti) if ti.state not in State.finished: ti.set_state(State.FAILED) + taskrun_result = None break raise log.info("[DAG TEST] end task task_id=%s map_index=%s", ti.task_id, ti.map_index) + return taskrun_result def _run_inline_trigger(trigger, task_sdk_ti): diff --git a/task-sdk/src/airflow/sdk/types.py b/task-sdk/src/airflow/sdk/types.py index cfbcafe4201..c7084629085 100644 --- a/task-sdk/src/airflow/sdk/types.py +++ b/task-sdk/src/airflow/sdk/types.py @@ -30,6 +30,7 @@ if TYPE_CHECKING: from pydantic import AwareDatetime from airflow.sdk._shared.logging.types import Logger as Logger + from airflow.sdk.api.datamodels._generated import TaskInstanceState from airflow.sdk.bases.operator import BaseOperator from airflow.sdk.definitions.asset import Asset, AssetAlias, AssetAliasEvent, AssetRef, BaseAssetUniqueKey from airflow.sdk.definitions.context import Context @@ -68,6 +69,7 @@ class RuntimeTaskInstanceProtocol(Protocol): hostname: str | None = None start_date: AwareDatetime end_date: AwareDatetime | None = None + state: TaskInstanceState | None = None def xcom_pull( self,
