This is an automated email from the ASF dual-hosted git repository. ephraimanierobi pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 76b9aa61e2732c55715890fd50f4ccda5a6645e5 Author: Ephraim Anierobi <[email protected]> AuthorDate: Mon Dec 8 09:22:09 2025 +0100 Revert "Find only relevant up/downstream tis when clearing (#57758) (#58987)" This reverts commit 8918f98665f149a454e14c0a8cc7d2b09474d93c. --- .../core_api/routes/public/task_instances.py | 60 ++--- .../example_dags/example_dynamic_task_mapping.py | 32 +-- airflow-core/src/airflow/models/taskinstance.py | 246 +++++++-------------- .../core_api/routes/public/test_task_instances.py | 21 -- .../tests/unit/models/test_taskinstance.py | 54 ----- 5 files changed, 105 insertions(+), 308 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py index 82e7fad3493..68764b5456a 100644 --- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -729,54 +729,32 @@ def post_clear_task_instances( if future: body.end_date = None - if (task_markers_to_clear := body.task_ids) is not None: - mapped_tasks_tuples = {t for t in task_markers_to_clear if isinstance(t, tuple)} + task_ids = body.task_ids + if task_ids is not None: + tasks = set(task_ids) + mapped_tasks_tuples = set(t for t in tasks if isinstance(t, tuple)) # Unmapped tasks are expressed in their task_ids (without map_indexes) - normal_task_ids = {t for t in task_markers_to_clear if not isinstance(t, tuple)} - - def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]) -> None: - from airflow.models.taskinstance import find_relevant_relatives - - relevant_relatives = find_relevant_relatives( - normal_task_ids, - mapped_tasks_tuples, - dag=dag, - run_id=run_id, - direction=direction, - session=session, + unmapped_task_ids = set(t for t in tasks if not isinstance(t, tuple)) + + if upstream or downstream: + mapped_task_ids = set(tid for 
tid, _ in mapped_tasks_tuples) + relatives = dag.partial_subset( + task_ids=unmapped_task_ids | mapped_task_ids, + include_downstream=downstream, + include_upstream=upstream, + exclude_original=True, ) - normal_task_ids.update(t for t in relevant_relatives if not isinstance(t, tuple)) - mapped_tasks_tuples.update(t for t in relevant_relatives if isinstance(t, tuple)) - - # We can't easily calculate upstream/downstream map indexes when not - # working for a specific dag run. It's possible by looking at the runs - # one by one, but that is both resource-consuming and logically complex. - # So instead we'll just clear all the tis based on task ID and hope - # that's good enough for most cases. - if dag_run_id is None: - if upstream or downstream: - partial_dag = dag.partial_subset( - task_ids=normal_task_ids.union(tid for tid, _ in mapped_tasks_tuples), - include_downstream=downstream, - include_upstream=upstream, - exclude_original=True, - ) - normal_task_ids.update(partial_dag.task_dict) - else: - if upstream: - _collect_relatives(dag_run_id, "upstream") - if downstream: - _collect_relatives(dag_run_id, "downstream") - - task_markers_to_clear = [ - *normal_task_ids, - *((t, m) for t, m in mapped_tasks_tuples if t not in normal_task_ids), + unmapped_task_ids = unmapped_task_ids | set(relatives.task_dict.keys()) + + mapped_tasks_list = [ + (tid, map_id) for tid, map_id in mapped_tasks_tuples if tid not in unmapped_task_ids ] + task_ids = mapped_tasks_list + list(unmapped_task_ids) # Prepare common parameters common_params = { "dry_run": True, - "task_ids": task_markers_to_clear, + "task_ids": task_ids, "session": session, "run_on_latest_version": body.run_on_latest_version, "only_failed": body.only_failed, diff --git a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py index c7b3a02301d..750c3da1ec1 100644 --- 
a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py +++ b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py @@ -22,9 +22,9 @@ from __future__ import annotations # [START example_dynamic_task_mapping] from datetime import datetime -from airflow.sdk import DAG, task, task_group +from airflow.sdk import DAG, task -with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)): +with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)) as dag: @task def add_one(x: int): @@ -39,11 +39,8 @@ with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=dateti sum_it(added_values) with DAG( - dag_id="example_task_mapping_second_order", - schedule=None, - catchup=False, - start_date=datetime(2022, 3, 4), -): + dag_id="example_task_mapping_second_order", schedule=None, catchup=False, start_date=datetime(2022, 3, 4) +) as dag2: @task def get_nums(): @@ -61,25 +58,4 @@ with DAG( _times_2 = times_2.expand(num=_get_nums) add_10.expand(num=_times_2) -with DAG( - dag_id="example_task_group_mapping", - schedule=None, - catchup=False, - start_date=datetime(2022, 3, 4), -): - - @task_group - def op(num): - @task - def add_1(num): - return num + 1 - - @task - def mul_2(num): - return num * 2 - - return mul_2(add_1(num)) - - op.expand(num=[1, 2, 3]) - # [END example_dynamic_task_mapping] diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py index 06597a24fc9..216bc9b1bca 100644 --- a/airflow-core/src/airflow/models/taskinstance.py +++ b/airflow-core/src/airflow/models/taskinstance.py @@ -28,7 +28,7 @@ from collections import defaultdict from collections.abc import Collection, Iterable from datetime import timedelta from functools import cache -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any from urllib.parse import quote import attrs @@ -121,7 +121,7 @@ if TYPE_CHECKING: 
from airflow.sdk.definitions.asset import AssetUniqueKey from airflow.sdk.types import RuntimeTaskInstanceProtocol from airflow.serialization.definitions.taskgroup import SerializedTaskGroup - from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG + from airflow.serialization.serialized_objects import SerializedBaseOperator from airflow.utils.context import Context Operator: TypeAlias = MappedOperator | SerializedBaseOperator @@ -2037,16 +2037,87 @@ class TaskInstance(Base, LoggingMixin): *, session: Session, ) -> int | range | None: + """ + Infer the map indexes of an upstream "relevant" to this ti. + + The bulk of the logic mainly exists to solve the problem described by + the following example, where 'val' must resolve to different values, + depending on where the reference is being used:: + + @task + def this_task(v): # This is self.task. + return v * 2 + + + @task_group + def tg1(inp): + val = upstream(inp) # This is the upstream task. + this_task(val) # When inp is 1, val here should resolve to 2. + return val + + + # This val is the same object returned by tg1. + val = tg1.expand(inp=[1, 2, 3]) + + + @task_group + def tg2(inp): + another_task(inp, val) # val here should resolve to [2, 4, 6]. + + + tg2.expand(inp=["a", "b"]) + + The surrounding mapped task groups of ``upstream`` and ``self.task`` are + inspected to find a common "ancestor". If such an ancestor is found, + we need to return specific map indexes to pull a partial value from + upstream XCom. + + :param upstream: The referenced upstream task. + :param ti_count: The total count of task instance this task was expanded + by the scheduler, i.e. ``expanded_ti_count`` in the template context. + :return: Specific map index or map indexes to pull, or ``None`` if we + want to "whole" return value (i.e. no mapped task groups involved). 
+ """ + from airflow.models.mappedoperator import get_mapped_ti_count + if TYPE_CHECKING: - assert self.task - return _get_relevant_map_indexes( - run_id=self.run_id, - map_index=self.map_index, - ti_count=ti_count, - task=self.task, - relative=upstream, - session=session, - ) + assert self.task is not None + + # This value should never be None since we already know the current task + # is in a mapped task group, and should have been expanded, despite that, + # we need to check that it is not None to satisfy Mypy. + # But this value can be 0 when we expand an empty list, for that it is + # necessary to check that ti_count is not 0 to avoid dividing by 0. + if not ti_count: + return None + + # Find the innermost common mapped task group between the current task + # If the current task and the referenced task does not have a common + # mapped task group, the two are in different task mapping contexts + # (like another_task above), and we should use the "whole" value. + common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream) + if common_ancestor is None: + return None + + # At this point we know the two tasks share a mapped task group, and we + # should use a "partial" value. Let's break down the mapped ti count + # between the ancestor and further expansion happened inside it. + + ancestor_ti_count = get_mapped_ti_count(common_ancestor, self.run_id, session=session) + ancestor_map_index = self.map_index * ancestor_ti_count // ti_count + + # If the task is NOT further expanded inside the common ancestor, we + # only want to reference one single ti. We must walk the actual DAG, + # and "ti_count == ancestor_ti_count" does not work, since the further + # expansion may be of length 1. + if not _is_further_mapped_inside(upstream, common_ancestor): + return ancestor_map_index + + # Otherwise we need a partial aggregation for values from selected task + # instances in the ancestor's expansion context. 
+ further_count = ti_count // ancestor_ti_count + map_index_start = ancestor_map_index * further_count + return range(map_index_start, map_index_start + further_count) def clear_db_references(self, session: Session): """ @@ -2136,159 +2207,6 @@ def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup return False -def _get_relevant_map_indexes( - *, - task: Operator, - run_id: str, - map_index: int, - relative: Operator, - ti_count: int | None, - session: Session, -) -> int | range | None: - """ - Infer the map indexes of a relative that's "relevant" to this ti. - - The bulk of the logic mainly exists to solve the problem described by - the following example, where 'val' must resolve to different values, - depending on where the reference is being used:: - - @task - def this_task(v): # This is self.task. - return v * 2 - - - @task_group - def tg1(inp): - val = upstream(inp) # This is the upstream task. - this_task(val) # When inp is 1, val here should resolve to 2. - return val - - - # This val is the same object returned by tg1. - val = tg1.expand(inp=[1, 2, 3]) - - - @task_group - def tg2(inp): - another_task(inp, val) # val here should resolve to [2, 4, 6]. - - - tg2.expand(inp=["a", "b"]) - - The surrounding mapped task groups of ``upstream`` and ``task`` are - inspected to find a common "ancestor". If such an ancestor is found, - we need to return specific map indexes to pull a partial value from - upstream XCom. - - The same logic apply for finding downstream tasks. - - :param task: Current task being inspected. - :param run_id: Current run ID. - :param map_index: Map index of the current task instance. - :param relative: The relative task to find relevant map indexes for. - :param ti_count: The total count of task instance this task was expanded - by the scheduler, i.e. ``expanded_ti_count`` in the template context. - :return: Specific map index or map indexes to pull, or ``None`` if we - want to "whole" return value (i.e. 
no mapped task groups involved). - """ - from airflow.models.mappedoperator import get_mapped_ti_count - - # This value should never be None since we already know the current task - # is in a mapped task group, and should have been expanded, despite that, - # we need to check that it is not None to satisfy Mypy. - # But this value can be 0 when we expand an empty list, for that it is - # necessary to check that ti_count is not 0 to avoid dividing by 0. - if not ti_count: - return None - - # Find the innermost common mapped task group between the current task - # If the current task and the referenced task does not have a common - # mapped task group, the two are in different task mapping contexts - # (like another_task above), and we should use the "whole" value. - if (common_ancestor := _find_common_ancestor_mapped_group(task, relative)) is None: - return None - - # At this point we know the two tasks share a mapped task group, and we - # should use a "partial" value. Let's break down the mapped ti count - # between the ancestor and further expansion happened inside it. - - ancestor_ti_count = get_mapped_ti_count(common_ancestor, run_id, session=session) - ancestor_map_index = map_index * ancestor_ti_count // ti_count - - # If the task is NOT further expanded inside the common ancestor, we - # only want to reference one single ti. We must walk the actual DAG, - # and "ti_count == ancestor_ti_count" does not work, since the further - # expansion may be of length 1. - if not _is_further_mapped_inside(relative, common_ancestor): - return ancestor_map_index - - # Otherwise we need a partial aggregation for values from selected task - # instances in the ancestor's expansion context. 
- further_count = ti_count // ancestor_ti_count - map_index_start = ancestor_map_index * further_count - return range(map_index_start, map_index_start + further_count) - - -def find_relevant_relatives( - normal_tasks: Iterable[str], - mapped_tasks: Iterable[tuple[str, int]], - *, - direction: Literal["upstream", "downstream"], - dag: SerializedDAG, - run_id: str, - session: Session, -) -> Collection[str | tuple[str, int]]: - from airflow.models.mappedoperator import get_mapped_ti_count - - visited: set[str | tuple[str, int]] = set() - - def _visit_relevant_relatives_for_normal(task_ids: Iterable[str]) -> None: - partial_dag = dag.partial_subset( - task_ids=task_ids, - include_downstream=direction == "downstream", - include_upstream=direction == "upstream", - exclude_original=True, - ) - visited.update(partial_dag.task_dict) - - def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]) -> None: - for task_id, map_index in mapped_tasks: - task = dag.get_task(task_id) - ti_count = get_mapped_ti_count(task, run_id, session=session) - # TODO (GH-52141): This should return scheduler operator types, but - # currently get_flat_relatives is inherited from SDK DAGNode. 
- relatives = cast("Iterable[Operator]", task.get_flat_relatives(upstream=direction == "upstream")) - for relative in relatives: - if relative.task_id in visited: - continue - relative_map_indexes = _get_relevant_map_indexes( - task=task, - relative=relative, # type: ignore[arg-type] - run_id=run_id, - map_index=map_index, - ti_count=ti_count, - session=session, - ) - visiting_mapped: set[tuple[str, int]] = set() - visiting_normal: set[str] = set() - match relative_map_indexes: - case int(): - if (item := (relative.task_id, relative_map_indexes)) not in visited: - visiting_mapped.add(item) - case range(): - visiting_mapped.update((relative.task_id, i) for i in relative_map_indexes) - case None: - if (task_id := relative.task_id) not in visited: - visiting_normal.add(task_id) - _visit_relevant_relatives_for_normal(visiting_normal) - _visit_relevant_relatives_for_mapped(visiting_mapped) - visited.update(visiting_mapped, visiting_normal) - - _visit_relevant_relatives_for_normal(normal_tasks) - _visit_relevant_relatives_for_mapped(mapped_tasks) - return visited - - class TaskInstanceNote(Base): """For storage of arbitrary notes concerning the task instance.""" diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py index 41b81a3c65d..6f743cd7d4b 100644 --- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py @@ -2536,27 +2536,6 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint): 4, id="clear mapped tasks with and without map index", ), - pytest.param( - "example_task_group_mapping", - [ - { - "state": State.FAILED, - "map_indexes": (0, 1, 2), - }, - { - "state": State.FAILED, - "map_indexes": (0, 1, 2), - }, - ], - "example_task_group_mapping", - { - "task_ids": [["op.mul_2", 0]], - "dag_run_id": "TEST_DAG_RUN_ID", - 
"include_upstream": True, - }, - 2, - id="clear tasks in mapped task group", - ), ], ) def test_should_respond_200( diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py index 441d81de2b5..4c7a0d9fd5c 100644 --- a/airflow-core/tests/unit/models/test_taskinstance.py +++ b/airflow-core/tests/unit/models/test_taskinstance.py @@ -51,7 +51,6 @@ from airflow.models.taskinstance import ( TaskInstance, TaskInstance as TI, TaskInstanceNote, - find_relevant_relatives, ) from airflow.models.taskinstancehistory import TaskInstanceHistory from airflow.models.taskmap import TaskMap @@ -3102,56 +3101,3 @@ def test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio session.delete(version) with pytest.raises(IntegrityError): session.commit() - - -@pytest.mark.parametrize( ("normal_tasks", "mapped_tasks", "expected"), [ - # 4 is just a regular task so it depends on all its upstreams. - pytest.param(["4"], [], {"1", "2", "3"}, id="nonmapped"), - # 3 is a mapped; it depends on all tis of the mapped upstream 2. - pytest.param(["3"], [], {"1", "2"}, id="mapped-whole"), - # Every ti of a mapped task depends on all tis of the mapped upstream. - pytest.param([], [("3", 1)], {"1", "2"}, id="mapped-one"), - # Same as the (non-group) unmapped case, d depends on all upstreams. - pytest.param(["d"], [], {"a", "b", "c"}, id="group-nonmapped"), - # This specifies c tis in ALL mapped task groups, so all b tis are needed. - pytest.param(["c"], [], {"a", "b"}, id="group-mapped-whole"), - # This only specifies one c ti, so only one b ti from the same mapped instance is returned.
- pytest.param([], [("c", 1)], {"a", ("b", 1)}, id="group-mapped-one"), - ], -) -def test_find_relevant_relatives(dag_maker, session, normal_tasks, mapped_tasks, expected): - # 1 -> 2[] -> 3[] -> 4 - # - # a -> " b --> c " -> d - # "== g[] ==" - with dag_maker(session=session) as dag: - t1 = EmptyOperator(task_id="1") - t2 = MockOperator.partial(task_id="2").expand(arg1=["x", "y"]) - t3 = MockOperator.partial(task_id="3").expand(arg1=["x", "y"]) - t4 = EmptyOperator(task_id="4") - t1 >> t2 >> t3 >> t4 - - ta = EmptyOperator(task_id="a") - - @task_group(prefix_group_id=False) - def g(v): - tb = MockOperator(task_id="b", arg1=v) - tc = MockOperator(task_id="c", arg1=v) - tb >> tc - - td = EmptyOperator(task_id="d") - ta >> g.expand(v=["x", "y", "z"]) >> td - - dr = dag_maker.create_dagrun(state="success") - - result = find_relevant_relatives( - normal_tasks=normal_tasks, - mapped_tasks=mapped_tasks, - direction="upstream", - dag=dag, - run_id=dr.run_id, - session=session, - ) - assert result == expected
