This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 76b9aa61e2732c55715890fd50f4ccda5a6645e5
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Mon Dec 8 09:22:09 2025 +0100

    Revert "Find only relevant up/downstream tis when clearing (#57758) (#58987)"
    
    This reverts commit 8918f98665f149a454e14c0a8cc7d2b09474d93c.
---
 .../core_api/routes/public/task_instances.py       |  60 ++---
 .../example_dags/example_dynamic_task_mapping.py   |  32 +--
 airflow-core/src/airflow/models/taskinstance.py    | 246 +++++++--------------
 .../core_api/routes/public/test_task_instances.py  |  21 --
 .../tests/unit/models/test_taskinstance.py         |  54 -----
 5 files changed, 105 insertions(+), 308 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 82e7fad3493..68764b5456a 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -729,54 +729,32 @@ def post_clear_task_instances(
     if future:
         body.end_date = None
 
-    if (task_markers_to_clear := body.task_ids) is not None:
-        mapped_tasks_tuples = {t for t in task_markers_to_clear if 
isinstance(t, tuple)}
+    task_ids = body.task_ids
+    if task_ids is not None:
+        tasks = set(task_ids)
+        mapped_tasks_tuples = set(t for t in tasks if isinstance(t, tuple))
         # Unmapped tasks are expressed in their task_ids (without map_indexes)
-        normal_task_ids = {t for t in task_markers_to_clear if not 
isinstance(t, tuple)}
-
-        def _collect_relatives(run_id: str, direction: Literal["upstream", 
"downstream"]) -> None:
-            from airflow.models.taskinstance import find_relevant_relatives
-
-            relevant_relatives = find_relevant_relatives(
-                normal_task_ids,
-                mapped_tasks_tuples,
-                dag=dag,
-                run_id=run_id,
-                direction=direction,
-                session=session,
+        unmapped_task_ids = set(t for t in tasks if not isinstance(t, tuple))
+
+        if upstream or downstream:
+            mapped_task_ids = set(tid for tid, _ in mapped_tasks_tuples)
+            relatives = dag.partial_subset(
+                task_ids=unmapped_task_ids | mapped_task_ids,
+                include_downstream=downstream,
+                include_upstream=upstream,
+                exclude_original=True,
             )
-            normal_task_ids.update(t for t in relevant_relatives if not 
isinstance(t, tuple))
-            mapped_tasks_tuples.update(t for t in relevant_relatives if 
isinstance(t, tuple))
-
-        # We can't easily calculate upstream/downstream map indexes when not
-        # working for a specific dag run. It's possible by looking at the runs
-        # one by one, but that is both resource-consuming and logically 
complex.
-        # So instead we'll just clear all the tis based on task ID and hope
-        # that's good enough for most cases.
-        if dag_run_id is None:
-            if upstream or downstream:
-                partial_dag = dag.partial_subset(
-                    task_ids=normal_task_ids.union(tid for tid, _ in 
mapped_tasks_tuples),
-                    include_downstream=downstream,
-                    include_upstream=upstream,
-                    exclude_original=True,
-                )
-                normal_task_ids.update(partial_dag.task_dict)
-        else:
-            if upstream:
-                _collect_relatives(dag_run_id, "upstream")
-            if downstream:
-                _collect_relatives(dag_run_id, "downstream")
-
-        task_markers_to_clear = [
-            *normal_task_ids,
-            *((t, m) for t, m in mapped_tasks_tuples if t not in 
normal_task_ids),
+            unmapped_task_ids = unmapped_task_ids | 
set(relatives.task_dict.keys())
+
+        mapped_tasks_list = [
+            (tid, map_id) for tid, map_id in mapped_tasks_tuples if tid not in 
unmapped_task_ids
         ]
+        task_ids = mapped_tasks_list + list(unmapped_task_ids)
 
     # Prepare common parameters
     common_params = {
         "dry_run": True,
-        "task_ids": task_markers_to_clear,
+        "task_ids": task_ids,
         "session": session,
         "run_on_latest_version": body.run_on_latest_version,
         "only_failed": body.only_failed,
diff --git a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
index c7b3a02301d..750c3da1ec1 100644
--- a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
+++ b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
@@ -22,9 +22,9 @@ from __future__ import annotations
 # [START example_dynamic_task_mapping]
 from datetime import datetime
 
-from airflow.sdk import DAG, task, task_group
+from airflow.sdk import DAG, task
 
-with DAG(dag_id="example_dynamic_task_mapping", schedule=None, 
start_date=datetime(2022, 3, 4)):
+with DAG(dag_id="example_dynamic_task_mapping", schedule=None, 
start_date=datetime(2022, 3, 4)) as dag:
 
     @task
     def add_one(x: int):
@@ -39,11 +39,8 @@ with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=dateti
     sum_it(added_values)
 
 with DAG(
-    dag_id="example_task_mapping_second_order",
-    schedule=None,
-    catchup=False,
-    start_date=datetime(2022, 3, 4),
-):
+    dag_id="example_task_mapping_second_order", schedule=None, catchup=False, 
start_date=datetime(2022, 3, 4)
+) as dag2:
 
     @task
     def get_nums():
@@ -61,25 +58,4 @@ with DAG(
     _times_2 = times_2.expand(num=_get_nums)
     add_10.expand(num=_times_2)
 
-with DAG(
-    dag_id="example_task_group_mapping",
-    schedule=None,
-    catchup=False,
-    start_date=datetime(2022, 3, 4),
-):
-
-    @task_group
-    def op(num):
-        @task
-        def add_1(num):
-            return num + 1
-
-        @task
-        def mul_2(num):
-            return num * 2
-
-        return mul_2(add_1(num))
-
-    op.expand(num=[1, 2, 3])
-
 # [END example_dynamic_task_mapping]
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 06597a24fc9..216bc9b1bca 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -28,7 +28,7 @@ from collections import defaultdict
 from collections.abc import Collection, Iterable
 from datetime import timedelta
 from functools import cache
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any
 from urllib.parse import quote
 
 import attrs
@@ -121,7 +121,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.asset import AssetUniqueKey
     from airflow.sdk.types import RuntimeTaskInstanceProtocol
     from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-    from airflow.serialization.serialized_objects import 
SerializedBaseOperator, SerializedDAG
+    from airflow.serialization.serialized_objects import SerializedBaseOperator
     from airflow.utils.context import Context
 
     Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -2037,16 +2037,87 @@ class TaskInstance(Base, LoggingMixin):
         *,
         session: Session,
     ) -> int | range | None:
+        """
+        Infer the map indexes of an upstream "relevant" to this ti.
+
+        The bulk of the logic mainly exists to solve the problem described by
+        the following example, where 'val' must resolve to different values,
+        depending on where the reference is being used::
+
+            @task
+            def this_task(v):  # This is self.task.
+                return v * 2
+
+
+            @task_group
+            def tg1(inp):
+                val = upstream(inp)  # This is the upstream task.
+                this_task(val)  # When inp is 1, val here should resolve to 2.
+                return val
+
+
+            # This val is the same object returned by tg1.
+            val = tg1.expand(inp=[1, 2, 3])
+
+
+            @task_group
+            def tg2(inp):
+                another_task(inp, val)  # val here should resolve to [2, 4, 6].
+
+
+            tg2.expand(inp=["a", "b"])
+
+        The surrounding mapped task groups of ``upstream`` and ``self.task`` 
are
+        inspected to find a common "ancestor". If such an ancestor is found,
+        we need to return specific map indexes to pull a partial value from
+        upstream XCom.
+
+        :param upstream: The referenced upstream task.
+        :param ti_count: The total count of task instance this task was 
expanded
+            by the scheduler, i.e. ``expanded_ti_count`` in the template 
context.
+        :return: Specific map index or map indexes to pull, or ``None`` if we
+            want to "whole" return value (i.e. no mapped task groups involved).
+        """
+        from airflow.models.mappedoperator import get_mapped_ti_count
+
         if TYPE_CHECKING:
-            assert self.task
-        return _get_relevant_map_indexes(
-            run_id=self.run_id,
-            map_index=self.map_index,
-            ti_count=ti_count,
-            task=self.task,
-            relative=upstream,
-            session=session,
-        )
+            assert self.task is not None
+
+        # This value should never be None since we already know the current 
task
+        # is in a mapped task group, and should have been expanded, despite 
that,
+        # we need to check that it is not None to satisfy Mypy.
+        # But this value can be 0 when we expand an empty list, for that it is
+        # necessary to check that ti_count is not 0 to avoid dividing by 0.
+        if not ti_count:
+            return None
+
+        # Find the innermost common mapped task group between the current task
+        # If the current task and the referenced task does not have a common
+        # mapped task group, the two are in different task mapping contexts
+        # (like another_task above), and we should use the "whole" value.
+        common_ancestor = _find_common_ancestor_mapped_group(self.task, 
upstream)
+        if common_ancestor is None:
+            return None
+
+        # At this point we know the two tasks share a mapped task group, and we
+        # should use a "partial" value. Let's break down the mapped ti count
+        # between the ancestor and further expansion happened inside it.
+
+        ancestor_ti_count = get_mapped_ti_count(common_ancestor, self.run_id, 
session=session)
+        ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
+
+        # If the task is NOT further expanded inside the common ancestor, we
+        # only want to reference one single ti. We must walk the actual DAG,
+        # and "ti_count == ancestor_ti_count" does not work, since the further
+        # expansion may be of length 1.
+        if not _is_further_mapped_inside(upstream, common_ancestor):
+            return ancestor_map_index
+
+        # Otherwise we need a partial aggregation for values from selected task
+        # instances in the ancestor's expansion context.
+        further_count = ti_count // ancestor_ti_count
+        map_index_start = ancestor_map_index * further_count
+        return range(map_index_start, map_index_start + further_count)
 
     def clear_db_references(self, session: Session):
         """
@@ -2136,159 +2207,6 @@ def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup
     return False
 
 
-def _get_relevant_map_indexes(
-    *,
-    task: Operator,
-    run_id: str,
-    map_index: int,
-    relative: Operator,
-    ti_count: int | None,
-    session: Session,
-) -> int | range | None:
-    """
-    Infer the map indexes of a relative that's "relevant" to this ti.
-
-    The bulk of the logic mainly exists to solve the problem described by
-    the following example, where 'val' must resolve to different values,
-    depending on where the reference is being used::
-
-        @task
-        def this_task(v):  # This is self.task.
-            return v * 2
-
-
-        @task_group
-        def tg1(inp):
-            val = upstream(inp)  # This is the upstream task.
-            this_task(val)  # When inp is 1, val here should resolve to 2.
-            return val
-
-
-        # This val is the same object returned by tg1.
-        val = tg1.expand(inp=[1, 2, 3])
-
-
-        @task_group
-        def tg2(inp):
-            another_task(inp, val)  # val here should resolve to [2, 4, 6].
-
-
-        tg2.expand(inp=["a", "b"])
-
-    The surrounding mapped task groups of ``upstream`` and ``task`` are
-    inspected to find a common "ancestor". If such an ancestor is found,
-    we need to return specific map indexes to pull a partial value from
-    upstream XCom.
-
-    The same logic apply for finding downstream tasks.
-
-    :param task: Current task being inspected.
-    :param run_id: Current run ID.
-    :param map_index: Map index of the current task instance.
-    :param relative: The relative task to find relevant map indexes for.
-    :param ti_count: The total count of task instance this task was expanded
-        by the scheduler, i.e. ``expanded_ti_count`` in the template context.
-    :return: Specific map index or map indexes to pull, or ``None`` if we
-        want to "whole" return value (i.e. no mapped task groups involved).
-    """
-    from airflow.models.mappedoperator import get_mapped_ti_count
-
-    # This value should never be None since we already know the current task
-    # is in a mapped task group, and should have been expanded, despite that,
-    # we need to check that it is not None to satisfy Mypy.
-    # But this value can be 0 when we expand an empty list, for that it is
-    # necessary to check that ti_count is not 0 to avoid dividing by 0.
-    if not ti_count:
-        return None
-
-    # Find the innermost common mapped task group between the current task
-    # If the current task and the referenced task does not have a common
-    # mapped task group, the two are in different task mapping contexts
-    # (like another_task above), and we should use the "whole" value.
-    if (common_ancestor := _find_common_ancestor_mapped_group(task, relative)) 
is None:
-        return None
-
-    # At this point we know the two tasks share a mapped task group, and we
-    # should use a "partial" value. Let's break down the mapped ti count
-    # between the ancestor and further expansion happened inside it.
-
-    ancestor_ti_count = get_mapped_ti_count(common_ancestor, run_id, 
session=session)
-    ancestor_map_index = map_index * ancestor_ti_count // ti_count
-
-    # If the task is NOT further expanded inside the common ancestor, we
-    # only want to reference one single ti. We must walk the actual DAG,
-    # and "ti_count == ancestor_ti_count" does not work, since the further
-    # expansion may be of length 1.
-    if not _is_further_mapped_inside(relative, common_ancestor):
-        return ancestor_map_index
-
-    # Otherwise we need a partial aggregation for values from selected task
-    # instances in the ancestor's expansion context.
-    further_count = ti_count // ancestor_ti_count
-    map_index_start = ancestor_map_index * further_count
-    return range(map_index_start, map_index_start + further_count)
-
-
-def find_relevant_relatives(
-    normal_tasks: Iterable[str],
-    mapped_tasks: Iterable[tuple[str, int]],
-    *,
-    direction: Literal["upstream", "downstream"],
-    dag: SerializedDAG,
-    run_id: str,
-    session: Session,
-) -> Collection[str | tuple[str, int]]:
-    from airflow.models.mappedoperator import get_mapped_ti_count
-
-    visited: set[str | tuple[str, int]] = set()
-
-    def _visit_relevant_relatives_for_normal(task_ids: Iterable[str]) -> None:
-        partial_dag = dag.partial_subset(
-            task_ids=task_ids,
-            include_downstream=direction == "downstream",
-            include_upstream=direction == "upstream",
-            exclude_original=True,
-        )
-        visited.update(partial_dag.task_dict)
-
-    def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, 
int]]) -> None:
-        for task_id, map_index in mapped_tasks:
-            task = dag.get_task(task_id)
-            ti_count = get_mapped_ti_count(task, run_id, session=session)
-            # TODO (GH-52141): This should return scheduler operator types, but
-            # currently get_flat_relatives is inherited from SDK DAGNode.
-            relatives = cast("Iterable[Operator]", 
task.get_flat_relatives(upstream=direction == "upstream"))
-            for relative in relatives:
-                if relative.task_id in visited:
-                    continue
-                relative_map_indexes = _get_relevant_map_indexes(
-                    task=task,
-                    relative=relative,  # type: ignore[arg-type]
-                    run_id=run_id,
-                    map_index=map_index,
-                    ti_count=ti_count,
-                    session=session,
-                )
-                visiting_mapped: set[tuple[str, int]] = set()
-                visiting_normal: set[str] = set()
-                match relative_map_indexes:
-                    case int():
-                        if (item := (relative.task_id, relative_map_indexes)) 
not in visited:
-                            visiting_mapped.add(item)
-                    case range():
-                        visiting_mapped.update((relative.task_id, i) for i in 
relative_map_indexes)
-                    case None:
-                        if (task_id := relative.task_id) not in visited:
-                            visiting_normal.add(task_id)
-                _visit_relevant_relatives_for_normal(visiting_normal)
-                _visit_relevant_relatives_for_mapped(visiting_mapped)
-                visited.update(visiting_mapped, visiting_normal)
-
-    _visit_relevant_relatives_for_normal(normal_tasks)
-    _visit_relevant_relatives_for_mapped(mapped_tasks)
-    return visited
-
-
 class TaskInstanceNote(Base):
     """For storage of arbitrary notes concerning the task instance."""
 
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 41b81a3c65d..6f743cd7d4b 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -2536,27 +2536,6 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
                 4,
                 id="clear mapped tasks with and without map index",
             ),
-            pytest.param(
-                "example_task_group_mapping",
-                [
-                    {
-                        "state": State.FAILED,
-                        "map_indexes": (0, 1, 2),
-                    },
-                    {
-                        "state": State.FAILED,
-                        "map_indexes": (0, 1, 2),
-                    },
-                ],
-                "example_task_group_mapping",
-                {
-                    "task_ids": [["op.mul_2", 0]],
-                    "dag_run_id": "TEST_DAG_RUN_ID",
-                    "include_upstream": True,
-                },
-                2,
-                id="clear tasks in mapped task group",
-            ),
         ],
     )
     def test_should_respond_200(
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 441d81de2b5..4c7a0d9fd5c 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -51,7 +51,6 @@ from airflow.models.taskinstance import (
     TaskInstance,
     TaskInstance as TI,
     TaskInstanceNote,
-    find_relevant_relatives,
 )
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.models.taskmap import TaskMap
@@ -3102,56 +3101,3 @@ def test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio
     session.delete(version)
     with pytest.raises(IntegrityError):
         session.commit()
-
-
-@pytest.mark.parametrize(
-    ("normal_tasks", "mapped_tasks", "expected"),
-    [
-        # 4 is just a regular task so it depends on all its upstreams.
-        pytest.param(["4"], [], {"1", "2", "3"}, id="nonmapped"),
-        # 3 is a mapped; it depends on all tis of the mapped upstream 2.
-        pytest.param(["3"], [], {"1", "2"}, id="mapped-whole"),
-        # Every ti of a mapped task depends on all tis of the mapped upstream.
-        pytest.param([], [("3", 1)], {"1", "2"}, id="mapped-one"),
-        # Same as the (non-group) unmapped case, d depends on all upstreams.
-        pytest.param(["d"], [], {"a", "b", "c"}, id="group-nonmapped"),
-        # This specifies c tis in ALL mapped task groups, so all b tis are 
needed.
-        pytest.param(["c"], [], {"a", "b"}, id="group-mapped-whole"),
-        # This only specifies one c ti, so only one b ti from the same mapped 
instance is returned.
-        pytest.param([], [("c", 1)], {"a", ("b", 1)}, id="group-mapped-one"),
-    ],
-)
-def test_find_relevant_relatives(dag_maker, session, normal_tasks, 
mapped_tasks, expected):
-    # 1 -> 2[] -> 3[] -> 4
-    #
-    # a -> " b --> c " -> d
-    #      "== g[] =="
-    with dag_maker(session=session) as dag:
-        t1 = EmptyOperator(task_id="1")
-        t2 = MockOperator.partial(task_id="2").expand(arg1=["x", "y"])
-        t3 = MockOperator.partial(task_id="3").expand(arg1=["x", "y"])
-        t4 = EmptyOperator(task_id="4")
-        t1 >> t2 >> t3 >> t4
-
-        ta = EmptyOperator(task_id="a")
-
-        @task_group(prefix_group_id=False)
-        def g(v):
-            tb = MockOperator(task_id="b", arg1=v)
-            tc = MockOperator(task_id="c", arg1=v)
-            tb >> tc
-
-        td = EmptyOperator(task_id="d")
-        ta >> g.expand(v=["x", "y", "z"]) >> td
-
-    dr = dag_maker.create_dagrun(state="success")
-
-    result = find_relevant_relatives(
-        normal_tasks=normal_tasks,
-        mapped_tasks=mapped_tasks,
-        direction="upstream",
-        dag=dag,
-        run_id=dr.run_id,
-        session=session,
-    )
-    assert result == expected

Reply via email to