This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-1-test by this push:
     new 3a31d65e86e Find only relevant up/downstream tis when clearing (#57758) (#58987)
3a31d65e86e is described below

commit 3a31d65e86e768dab7f9a18abe70c6788121090f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 3 14:53:26 2025 +0800

    Find only relevant up/downstream tis when clearing (#57758) (#58987)
    
    (cherry picked from commit ad953da43d6c84d90e811a225e71e3ddc8a63dd5)
---
 .../core_api/routes/public/task_instances.py       |  60 +++--
 .../example_dags/example_dynamic_task_mapping.py   |  32 ++-
 airflow-core/src/airflow/models/taskinstance.py    | 246 ++++++++++++++-------
 .../core_api/routes/public/test_task_instances.py  |  21 ++
 .../tests/unit/models/test_taskinstance.py         |  54 +++++
 5 files changed, 308 insertions(+), 105 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 68764b5456a..82e7fad3493 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -729,32 +729,54 @@ def post_clear_task_instances(
     if future:
         body.end_date = None
 
-    task_ids = body.task_ids
-    if task_ids is not None:
-        tasks = set(task_ids)
-        mapped_tasks_tuples = set(t for t in tasks if isinstance(t, tuple))
+    if (task_markers_to_clear := body.task_ids) is not None:
+        mapped_tasks_tuples = {t for t in task_markers_to_clear if isinstance(t, tuple)}
         # Unmapped tasks are expressed in their task_ids (without map_indexes)
-        unmapped_task_ids = set(t for t in tasks if not isinstance(t, tuple))
-
-        if upstream or downstream:
-            mapped_task_ids = set(tid for tid, _ in mapped_tasks_tuples)
-            relatives = dag.partial_subset(
-                task_ids=unmapped_task_ids | mapped_task_ids,
-                include_downstream=downstream,
-                include_upstream=upstream,
-                exclude_original=True,
-            )
-            unmapped_task_ids = unmapped_task_ids | set(relatives.task_dict.keys())
+        normal_task_ids = {t for t in task_markers_to_clear if not isinstance(t, tuple)}
+
+        def _collect_relatives(run_id: str, direction: Literal["upstream", "downstream"]) -> None:
+            from airflow.models.taskinstance import find_relevant_relatives
 
-        mapped_tasks_list = [
-            (tid, map_id) for tid, map_id in mapped_tasks_tuples if tid not in unmapped_task_ids
+            relevant_relatives = find_relevant_relatives(
+                normal_task_ids,
+                mapped_tasks_tuples,
+                dag=dag,
+                run_id=run_id,
+                direction=direction,
+                session=session,
+            )
+            normal_task_ids.update(t for t in relevant_relatives if not isinstance(t, tuple))
+            mapped_tasks_tuples.update(t for t in relevant_relatives if isinstance(t, tuple))
+
+        # We can't easily calculate upstream/downstream map indexes when not
+        # working with a specific dag run. It's possible by looking at the runs
+        # one by one, but that is both resource-consuming and logically complex.
+        # So instead we'll just clear all the tis based on task ID and hope
+        # that's good enough for most cases.
+        if dag_run_id is None:
+            if upstream or downstream:
+                partial_dag = dag.partial_subset(
+                    task_ids=normal_task_ids.union(tid for tid, _ in mapped_tasks_tuples),
+                    include_downstream=downstream,
+                    include_upstream=upstream,
+                    exclude_original=True,
+                )
+                normal_task_ids.update(partial_dag.task_dict)
+        else:
+            if upstream:
+                _collect_relatives(dag_run_id, "upstream")
+            if downstream:
+                _collect_relatives(dag_run_id, "downstream")
+
+        task_markers_to_clear = [
+            *normal_task_ids,
+            *((t, m) for t, m in mapped_tasks_tuples if t not in normal_task_ids),
         ]
-        task_ids = mapped_tasks_list + list(unmapped_task_ids)
 
     # Prepare common parameters
     common_params = {
         "dry_run": True,
-        "task_ids": task_ids,
+        "task_ids": task_markers_to_clear,
         "session": session,
         "run_on_latest_version": body.run_on_latest_version,
         "only_failed": body.only_failed,
diff --git a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
index 750c3da1ec1..c7b3a02301d 100644
--- a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
+++ b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
@@ -22,9 +22,9 @@ from __future__ import annotations
 # [START example_dynamic_task_mapping]
 from datetime import datetime
 
-from airflow.sdk import DAG, task
+from airflow.sdk import DAG, task, task_group
 
-with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)) as dag:
+with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=datetime(2022, 3, 4)):
 
     @task
     def add_one(x: int):
@@ -39,8 +39,11 @@ with DAG(dag_id="example_dynamic_task_mapping", schedule=None, start_date=dateti
     sum_it(added_values)
 
 with DAG(
-    dag_id="example_task_mapping_second_order", schedule=None, catchup=False, start_date=datetime(2022, 3, 4)
-) as dag2:
+    dag_id="example_task_mapping_second_order",
+    schedule=None,
+    catchup=False,
+    start_date=datetime(2022, 3, 4),
+):
 
     @task
     def get_nums():
@@ -58,4 +61,25 @@ with DAG(
     _times_2 = times_2.expand(num=_get_nums)
     add_10.expand(num=_times_2)
 
+with DAG(
+    dag_id="example_task_group_mapping",
+    schedule=None,
+    catchup=False,
+    start_date=datetime(2022, 3, 4),
+):
+
+    @task_group
+    def op(num):
+        @task
+        def add_1(num):
+            return num + 1
+
+        @task
+        def mul_2(num):
+            return num * 2
+
+        return mul_2(add_1(num))
+
+    op.expand(num=[1, 2, 3])
+
 # [END example_dynamic_task_mapping]
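
The new example_task_group_mapping DAG gives the endpoint tests a mapped task
group to clear against. Under the removed partial_subset path, an upstream
relative of a mapped-group ti was added by task ID alone, i.e. every map index
of op.add_1; with this change only the matching expansion is found. A sketch
of the request body involved (mirroring the endpoint test added further down):

    {
        "task_ids": [["op.mul_2", 0]],
        "dag_run_id": "TEST_DAG_RUN_ID",
        "include_upstream": True,
    }
    # Expected to clear exactly 2 tis: ("op.mul_2", 0) and ("op.add_1", 0).
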
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 216bc9b1bca..06597a24fc9 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -28,7 +28,7 @@ from collections import defaultdict
 from collections.abc import Collection, Iterable
 from datetime import timedelta
 from functools import cache
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 from urllib.parse import quote
 
 import attrs
@@ -121,7 +121,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.asset import AssetUniqueKey
     from airflow.sdk.types import RuntimeTaskInstanceProtocol
     from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
-    from airflow.serialization.serialized_objects import SerializedBaseOperator
+    from airflow.serialization.serialized_objects import SerializedBaseOperator, SerializedDAG
     from airflow.utils.context import Context
 
     Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -2037,87 +2037,16 @@ class TaskInstance(Base, LoggingMixin):
         *,
         session: Session,
     ) -> int | range | None:
-        """
-        Infer the map indexes of an upstream "relevant" to this ti.
-
-        The bulk of the logic mainly exists to solve the problem described by
-        the following example, where 'val' must resolve to different values,
-        depending on where the reference is being used::
-
-            @task
-            def this_task(v):  # This is self.task.
-                return v * 2
-
-
-            @task_group
-            def tg1(inp):
-                val = upstream(inp)  # This is the upstream task.
-                this_task(val)  # When inp is 1, val here should resolve to 2.
-                return val
-
-
-            # This val is the same object returned by tg1.
-            val = tg1.expand(inp=[1, 2, 3])
-
-
-            @task_group
-            def tg2(inp):
-                another_task(inp, val)  # val here should resolve to [2, 4, 6].
-
-
-            tg2.expand(inp=["a", "b"])
-
-        The surrounding mapped task groups of ``upstream`` and ``self.task`` are
-        inspected to find a common "ancestor". If such an ancestor is found,
-        we need to return specific map indexes to pull a partial value from
-        upstream XCom.
-
-        :param upstream: The referenced upstream task.
-        :param ti_count: The total count of task instance this task was expanded
-            by the scheduler, i.e. ``expanded_ti_count`` in the template context.
-        :return: Specific map index or map indexes to pull, or ``None`` if we
-            want to "whole" return value (i.e. no mapped task groups involved).
-        """
-        from airflow.models.mappedoperator import get_mapped_ti_count
-
         if TYPE_CHECKING:
-            assert self.task is not None
-
-        # This value should never be None since we already know the current task
-        # is in a mapped task group, and should have been expanded, despite that,
-        # we need to check that it is not None to satisfy Mypy.
-        # But this value can be 0 when we expand an empty list, for that it is
-        # necessary to check that ti_count is not 0 to avoid dividing by 0.
-        if not ti_count:
-            return None
-
-        # Find the innermost common mapped task group between the current task
-        # If the current task and the referenced task does not have a common
-        # mapped task group, the two are in different task mapping contexts
-        # (like another_task above), and we should use the "whole" value.
-        common_ancestor = _find_common_ancestor_mapped_group(self.task, upstream)
-        if common_ancestor is None:
-            return None
-
-        # At this point we know the two tasks share a mapped task group, and we
-        # should use a "partial" value. Let's break down the mapped ti count
-        # between the ancestor and further expansion happened inside it.
-
-        ancestor_ti_count = get_mapped_ti_count(common_ancestor, self.run_id, session=session)
-        ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
-
-        # If the task is NOT further expanded inside the common ancestor, we
-        # only want to reference one single ti. We must walk the actual DAG,
-        # and "ti_count == ancestor_ti_count" does not work, since the further
-        # expansion may be of length 1.
-        if not _is_further_mapped_inside(upstream, common_ancestor):
-            return ancestor_map_index
-
-        # Otherwise we need a partial aggregation for values from selected task
-        # instances in the ancestor's expansion context.
-        further_count = ti_count // ancestor_ti_count
-        map_index_start = ancestor_map_index * further_count
-        return range(map_index_start, map_index_start + further_count)
+            assert self.task
+        return _get_relevant_map_indexes(
+            run_id=self.run_id,
+            map_index=self.map_index,
+            ti_count=ti_count,
+            task=self.task,
+            relative=upstream,
+            session=session,
+        )
 
     def clear_db_references(self, session: Session):
         """
@@ -2207,6 +2136,159 @@ def _is_further_mapped_inside(operator: Operator, container: SerializedTaskGroup
     return False
 
 
+def _get_relevant_map_indexes(
+    *,
+    task: Operator,
+    run_id: str,
+    map_index: int,
+    relative: Operator,
+    ti_count: int | None,
+    session: Session,
+) -> int | range | None:
+    """
+    Infer the map indexes of a relative that's "relevant" to this ti.
+
+    The bulk of the logic mainly exists to solve the problem described by
+    the following example, where 'val' must resolve to different values,
+    depending on where the reference is being used::
+
+        @task
+        def this_task(v):  # This is self.task.
+            return v * 2
+
+
+        @task_group
+        def tg1(inp):
+            val = upstream(inp)  # This is the upstream task.
+            this_task(val)  # When inp is 1, val here should resolve to 2.
+            return val
+
+
+        # This val is the same object returned by tg1.
+        val = tg1.expand(inp=[1, 2, 3])
+
+
+        @task_group
+        def tg2(inp):
+            another_task(inp, val)  # val here should resolve to [2, 4, 6].
+
+
+        tg2.expand(inp=["a", "b"])
+
+    The surrounding mapped task groups of ``upstream`` and ``task`` are
+    inspected to find a common "ancestor". If such an ancestor is found,
+    we need to return specific map indexes to pull a partial value from
+    upstream XCom.
+
+    The same logic applies to finding downstream tasks.
+
+    :param task: Current task being inspected.
+    :param run_id: Current run ID.
+    :param map_index: Map index of the current task instance.
+    :param relative: The relative task to find relevant map indexes for.
+    :param ti_count: The total count of task instances this task was expanded
+        into by the scheduler, i.e. ``expanded_ti_count`` in the template context.
+    :return: Specific map index or map indexes to pull, or ``None`` if we
+        want the "whole" return value (i.e. no mapped task groups involved).
+    """
+    from airflow.models.mappedoperator import get_mapped_ti_count
+
+    # This value should never be None since we already know the current task
+    # is in a mapped task group and should have been expanded. Despite that,
+    # we need to check that it is not None to satisfy Mypy.
+    # The value can also be 0 when we expand an empty list, so we must check
+    # that ti_count is not 0 to avoid dividing by zero.
+    if not ti_count:
+        return None
+
+    # Find the innermost common mapped task group of the current task and
+    # the referenced task. If the two do not share a common mapped task
+    # group, they are in different task mapping contexts (like another_task
+    # above), and we should use the "whole" value.
+    if (common_ancestor := _find_common_ancestor_mapped_group(task, relative)) is None:
+        return None
+
+    # At this point we know the two tasks share a mapped task group, and we
+    # should use a "partial" value. Let's break down the mapped ti count
+    # between the ancestor and the further expansion that happened inside it.
+
+    ancestor_ti_count = get_mapped_ti_count(common_ancestor, run_id, session=session)
+    ancestor_map_index = map_index * ancestor_ti_count // ti_count
+
+    # If the task is NOT further expanded inside the common ancestor, we
+    # only want to reference one single ti. We must walk the actual DAG,
+    # and "ti_count == ancestor_ti_count" does not work, since the further
+    # expansion may be of length 1.
+    if not _is_further_mapped_inside(relative, common_ancestor):
+        return ancestor_map_index
+
+    # Otherwise we need a partial aggregation for values from selected task
+    # instances in the ancestor's expansion context.
+    further_count = ti_count // ancestor_ti_count
+    map_index_start = ancestor_map_index * further_count
+    return range(map_index_start, map_index_start + further_count)
+
+
+def find_relevant_relatives(
+    normal_tasks: Iterable[str],
+    mapped_tasks: Iterable[tuple[str, int]],
+    *,
+    direction: Literal["upstream", "downstream"],
+    dag: SerializedDAG,
+    run_id: str,
+    session: Session,
+) -> Collection[str | tuple[str, int]]:
+    from airflow.models.mappedoperator import get_mapped_ti_count
+
+    visited: set[str | tuple[str, int]] = set()
+
+    def _visit_relevant_relatives_for_normal(task_ids: Iterable[str]) -> None:
+        partial_dag = dag.partial_subset(
+            task_ids=task_ids,
+            include_downstream=direction == "downstream",
+            include_upstream=direction == "upstream",
+            exclude_original=True,
+        )
+        visited.update(partial_dag.task_dict)
+
+    def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str, int]]) -> None:
+        for task_id, map_index in mapped_tasks:
+            task = dag.get_task(task_id)
+            ti_count = get_mapped_ti_count(task, run_id, session=session)
+            # TODO (GH-52141): This should return scheduler operator types, but
+            # currently get_flat_relatives is inherited from SDK DAGNode.
+            relatives = cast("Iterable[Operator]", task.get_flat_relatives(upstream=direction == "upstream"))
+            for relative in relatives:
+                if relative.task_id in visited:
+                    continue
+                relative_map_indexes = _get_relevant_map_indexes(
+                    task=task,
+                    relative=relative,  # type: ignore[arg-type]
+                    run_id=run_id,
+                    map_index=map_index,
+                    ti_count=ti_count,
+                    session=session,
+                )
+                visiting_mapped: set[tuple[str, int]] = set()
+                visiting_normal: set[str] = set()
+                match relative_map_indexes:
+                    case int():
+                        if (item := (relative.task_id, relative_map_indexes)) not in visited:
+                            visiting_mapped.add(item)
+                    case range():
+                        visiting_mapped.update((relative.task_id, i) for i in relative_map_indexes)
+                    case None:
+                        if (task_id := relative.task_id) not in visited:
+                            visiting_normal.add(task_id)
+                _visit_relevant_relatives_for_normal(visiting_normal)
+                _visit_relevant_relatives_for_mapped(visiting_mapped)
+                visited.update(visiting_mapped, visiting_normal)
+
+    _visit_relevant_relatives_for_normal(normal_tasks)
+    _visit_relevant_relatives_for_mapped(mapped_tasks)
+    return visited
+
+
 class TaskInstanceNote(Base):
     """For storage of arbitrary notes concerning the task instance."""
 
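The index arithmetic in _get_relevant_map_indexes can be sanity-checked with
made-up counts (all numbers below are illustrative, not taken from the patch):

    ti_count = 6           # this task expanded into 6 tis overall
    ancestor_ti_count = 3  # the common mapped group expanded 3 times
    map_index = 4          # current ti is the fifth expansion
    ancestor_map_index = map_index * ancestor_ti_count // ti_count  # 4*3//6 == 2
    further_count = ti_count // ancestor_ti_count                   # 6//3 == 2
    map_index_start = ancestor_map_index * further_count            # 2*2 == 4
    # Relative further-mapped inside the group -> range(4, 6);
    # otherwise the single map index 2.
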
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6f743cd7d4b..41b81a3c65d 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -2536,6 +2536,27 @@ class TestPostClearTaskInstances(TestTaskInstanceEndpoint):
                 4,
                 id="clear mapped tasks with and without map index",
             ),
+            pytest.param(
+                "example_task_group_mapping",
+                [
+                    {
+                        "state": State.FAILED,
+                        "map_indexes": (0, 1, 2),
+                    },
+                    {
+                        "state": State.FAILED,
+                        "map_indexes": (0, 1, 2),
+                    },
+                ],
+                "example_task_group_mapping",
+                {
+                    "task_ids": [["op.mul_2", 0]],
+                    "dag_run_id": "TEST_DAG_RUN_ID",
+                    "include_upstream": True,
+                },
+                2,
+                id="clear tasks in mapped task group",
+            ),
         ],
     )
     def test_should_respond_200(
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index 4c7a0d9fd5c..441d81de2b5 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -51,6 +51,7 @@ from airflow.models.taskinstance import (
     TaskInstance,
     TaskInstance as TI,
     TaskInstanceNote,
+    find_relevant_relatives,
 )
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.models.taskmap import TaskMap
@@ -3101,3 +3102,56 @@ def test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio
     session.delete(version)
     with pytest.raises(IntegrityError):
         session.commit()
+
+
[email protected](
+    ("normal_tasks", "mapped_tasks", "expected"),
+    [
+        # 4 is just a regular task so it depends on all its upstreams.
+        pytest.param(["4"], [], {"1", "2", "3"}, id="nonmapped"),
+        # 3 is a mapped task; it depends on all tis of the mapped upstream 2.
+        pytest.param(["3"], [], {"1", "2"}, id="mapped-whole"),
+        # Every ti of a mapped task depends on all tis of the mapped upstream.
+        pytest.param([], [("3", 1)], {"1", "2"}, id="mapped-one"),
+        # Same as the (non-group) unmapped case: d depends on all upstreams.
+        pytest.param(["d"], [], {"a", "b", "c"}, id="group-nonmapped"),
+        # This specifies c tis in ALL mapped task groups, so all b tis are needed.
+        pytest.param(["c"], [], {"a", "b"}, id="group-mapped-whole"),
+        # This only specifies one c ti, so only one b ti from the same mapped instance is returned.
+        pytest.param([], [("c", 1)], {"a", ("b", 1)}, id="group-mapped-one"),
+    ],
+)
+def test_find_relevant_relatives(dag_maker, session, normal_tasks, mapped_tasks, expected):
+    # 1 -> 2[] -> 3[] -> 4
+    #
+    # a -> " b --> c " -> d
+    #      "== g[] =="
+    with dag_maker(session=session) as dag:
+        t1 = EmptyOperator(task_id="1")
+        t2 = MockOperator.partial(task_id="2").expand(arg1=["x", "y"])
+        t3 = MockOperator.partial(task_id="3").expand(arg1=["x", "y"])
+        t4 = EmptyOperator(task_id="4")
+        t1 >> t2 >> t3 >> t4
+
+        ta = EmptyOperator(task_id="a")
+
+        @task_group(prefix_group_id=False)
+        def g(v):
+            tb = MockOperator(task_id="b", arg1=v)
+            tc = MockOperator(task_id="c", arg1=v)
+            tb >> tc
+
+        td = EmptyOperator(task_id="d")
+        ta >> g.expand(v=["x", "y", "z"]) >> td
+
+    dr = dag_maker.create_dagrun(state="success")
+
+    result = find_relevant_relatives(
+        normal_tasks=normal_tasks,
+        mapped_tasks=mapped_tasks,
+        direction="upstream",
+        dag=dag,
+        run_id=dr.run_id,
+        session=session,
+    )
+    assert result == expected
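
Reading the "group-mapped-one" case back through the new helper (this is a
restatement of the test expectation above, not additional behaviour): clearing
only the second expansion of c pulls in the matching b ti plus the non-group
upstream a, leaving ("b", 0) and ("b", 2) untouched:

    result = find_relevant_relatives(
        normal_tasks=[],
        mapped_tasks=[("c", 1)],
        direction="upstream",
        dag=dag,
        run_id=dr.run_id,
        session=session,
    )
    assert result == {"a", ("b", 1)}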
