This is an automated email from the ASF dual-hosted git repository.
ephraimanierobi pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-1-test by this push:
new 3a31d65e86e Find only relevant up/downstream tis when clearing
(#57758) (#58987)
3a31d65e86e is described below
commit 3a31d65e86e768dab7f9a18abe70c6788121090f
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Dec 3 14:53:26 2025 +0800
Find only relevant up/downstream tis when clearing (#57758) (#58987)
(cherry picked from commit ad953da43d6c84d90e811a225e71e3ddc8a63dd5)
---
.../core_api/routes/public/task_instances.py | 60 +++--
.../example_dags/example_dynamic_task_mapping.py | 32 ++-
airflow-core/src/airflow/models/taskinstance.py | 246 ++++++++++++++-------
.../core_api/routes/public/test_task_instances.py | 21 ++
.../tests/unit/models/test_taskinstance.py | 54 +++++
5 files changed, 308 insertions(+), 105 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 68764b5456a..82e7fad3493 100644
---
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -729,32 +729,54 @@ def post_clear_task_instances(
if future:
body.end_date = None
- task_ids = body.task_ids
- if task_ids is not None:
- tasks = set(task_ids)
- mapped_tasks_tuples = set(t for t in tasks if isinstance(t, tuple))
+ if (task_markers_to_clear := body.task_ids) is not None:
+ mapped_tasks_tuples = {t for t in task_markers_to_clear if
isinstance(t, tuple)}
# Unmapped tasks are expressed in their task_ids (without map_indexes)
- unmapped_task_ids = set(t for t in tasks if not isinstance(t, tuple))
-
- if upstream or downstream:
- mapped_task_ids = set(tid for tid, _ in mapped_tasks_tuples)
- relatives = dag.partial_subset(
- task_ids=unmapped_task_ids | mapped_task_ids,
- include_downstream=downstream,
- include_upstream=upstream,
- exclude_original=True,
- )
- unmapped_task_ids = unmapped_task_ids |
set(relatives.task_dict.keys())
+ normal_task_ids = {t for t in task_markers_to_clear if not
isinstance(t, tuple)}
+
+ def _collect_relatives(run_id: str, direction: Literal["upstream",
"downstream"]) -> None:
+ from airflow.models.taskinstance import find_relevant_relatives
- mapped_tasks_list = [
- (tid, map_id) for tid, map_id in mapped_tasks_tuples if tid not in
unmapped_task_ids
+ relevant_relatives = find_relevant_relatives(
+ normal_task_ids,
+ mapped_tasks_tuples,
+ dag=dag,
+ run_id=run_id,
+ direction=direction,
+ session=session,
+ )
+ normal_task_ids.update(t for t in relevant_relatives if not
isinstance(t, tuple))
+ mapped_tasks_tuples.update(t for t in relevant_relatives if
isinstance(t, tuple))
+
+ # We can't easily calculate upstream/downstream map indexes when not
+ # working for a specific dag run. It's possible by looking at the runs
+ # one by one, but that is both resource-consuming and logically
complex.
+ # So instead we'll just clear all the tis based on task ID and hope
+ # that's good enough for most cases.
+ if dag_run_id is None:
+ if upstream or downstream:
+ partial_dag = dag.partial_subset(
+ task_ids=normal_task_ids.union(tid for tid, _ in
mapped_tasks_tuples),
+ include_downstream=downstream,
+ include_upstream=upstream,
+ exclude_original=True,
+ )
+ normal_task_ids.update(partial_dag.task_dict)
+ else:
+ if upstream:
+ _collect_relatives(dag_run_id, "upstream")
+ if downstream:
+ _collect_relatives(dag_run_id, "downstream")
+
+ task_markers_to_clear = [
+ *normal_task_ids,
+ *((t, m) for t, m in mapped_tasks_tuples if t not in
normal_task_ids),
]
- task_ids = mapped_tasks_list + list(unmapped_task_ids)
# Prepare common parameters
common_params = {
"dry_run": True,
- "task_ids": task_ids,
+ "task_ids": task_markers_to_clear,
"session": session,
"run_on_latest_version": body.run_on_latest_version,
"only_failed": body.only_failed,
diff --git
a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
index 750c3da1ec1..c7b3a02301d 100644
--- a/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
+++ b/airflow-core/src/airflow/example_dags/example_dynamic_task_mapping.py
@@ -22,9 +22,9 @@ from __future__ import annotations
# [START example_dynamic_task_mapping]
from datetime import datetime
-from airflow.sdk import DAG, task
+from airflow.sdk import DAG, task, task_group
-with DAG(dag_id="example_dynamic_task_mapping", schedule=None,
start_date=datetime(2022, 3, 4)) as dag:
+with DAG(dag_id="example_dynamic_task_mapping", schedule=None,
start_date=datetime(2022, 3, 4)):
@task
def add_one(x: int):
@@ -39,8 +39,11 @@ with DAG(dag_id="example_dynamic_task_mapping",
schedule=None, start_date=dateti
sum_it(added_values)
with DAG(
- dag_id="example_task_mapping_second_order", schedule=None, catchup=False,
start_date=datetime(2022, 3, 4)
-) as dag2:
+ dag_id="example_task_mapping_second_order",
+ schedule=None,
+ catchup=False,
+ start_date=datetime(2022, 3, 4),
+):
@task
def get_nums():
@@ -58,4 +61,25 @@ with DAG(
_times_2 = times_2.expand(num=_get_nums)
add_10.expand(num=_times_2)
+with DAG(
+ dag_id="example_task_group_mapping",
+ schedule=None,
+ catchup=False,
+ start_date=datetime(2022, 3, 4),
+):
+
+ @task_group
+ def op(num):
+ @task
+ def add_1(num):
+ return num + 1
+
+ @task
+ def mul_2(num):
+ return num * 2
+
+ return mul_2(add_1(num))
+
+ op.expand(num=[1, 2, 3])
+
# [END example_dynamic_task_mapping]
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index 216bc9b1bca..06597a24fc9 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -28,7 +28,7 @@ from collections import defaultdict
from collections.abc import Collection, Iterable
from datetime import timedelta
from functools import cache
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
from urllib.parse import quote
import attrs
@@ -121,7 +121,7 @@ if TYPE_CHECKING:
from airflow.sdk.definitions.asset import AssetUniqueKey
from airflow.sdk.types import RuntimeTaskInstanceProtocol
from airflow.serialization.definitions.taskgroup import SerializedTaskGroup
- from airflow.serialization.serialized_objects import SerializedBaseOperator
+ from airflow.serialization.serialized_objects import
SerializedBaseOperator, SerializedDAG
from airflow.utils.context import Context
Operator: TypeAlias = MappedOperator | SerializedBaseOperator
@@ -2037,87 +2037,16 @@ class TaskInstance(Base, LoggingMixin):
*,
session: Session,
) -> int | range | None:
- """
- Infer the map indexes of an upstream "relevant" to this ti.
-
- The bulk of the logic mainly exists to solve the problem described by
- the following example, where 'val' must resolve to different values,
- depending on where the reference is being used::
-
- @task
- def this_task(v): # This is self.task.
- return v * 2
-
-
- @task_group
- def tg1(inp):
- val = upstream(inp) # This is the upstream task.
- this_task(val) # When inp is 1, val here should resolve to 2.
- return val
-
-
- # This val is the same object returned by tg1.
- val = tg1.expand(inp=[1, 2, 3])
-
-
- @task_group
- def tg2(inp):
- another_task(inp, val) # val here should resolve to [2, 4, 6].
-
-
- tg2.expand(inp=["a", "b"])
-
- The surrounding mapped task groups of ``upstream`` and ``self.task``
are
- inspected to find a common "ancestor". If such an ancestor is found,
- we need to return specific map indexes to pull a partial value from
- upstream XCom.
-
- :param upstream: The referenced upstream task.
- :param ti_count: The total count of task instance this task was
expanded
- by the scheduler, i.e. ``expanded_ti_count`` in the template
context.
- :return: Specific map index or map indexes to pull, or ``None`` if we
- want to "whole" return value (i.e. no mapped task groups involved).
- """
- from airflow.models.mappedoperator import get_mapped_ti_count
-
if TYPE_CHECKING:
- assert self.task is not None
-
- # This value should never be None since we already know the current
task
- # is in a mapped task group, and should have been expanded, despite
that,
- # we need to check that it is not None to satisfy Mypy.
- # But this value can be 0 when we expand an empty list, for that it is
- # necessary to check that ti_count is not 0 to avoid dividing by 0.
- if not ti_count:
- return None
-
- # Find the innermost common mapped task group between the current task
- # If the current task and the referenced task does not have a common
- # mapped task group, the two are in different task mapping contexts
- # (like another_task above), and we should use the "whole" value.
- common_ancestor = _find_common_ancestor_mapped_group(self.task,
upstream)
- if common_ancestor is None:
- return None
-
- # At this point we know the two tasks share a mapped task group, and we
- # should use a "partial" value. Let's break down the mapped ti count
- # between the ancestor and further expansion happened inside it.
-
- ancestor_ti_count = get_mapped_ti_count(common_ancestor, self.run_id,
session=session)
- ancestor_map_index = self.map_index * ancestor_ti_count // ti_count
-
- # If the task is NOT further expanded inside the common ancestor, we
- # only want to reference one single ti. We must walk the actual DAG,
- # and "ti_count == ancestor_ti_count" does not work, since the further
- # expansion may be of length 1.
- if not _is_further_mapped_inside(upstream, common_ancestor):
- return ancestor_map_index
-
- # Otherwise we need a partial aggregation for values from selected task
- # instances in the ancestor's expansion context.
- further_count = ti_count // ancestor_ti_count
- map_index_start = ancestor_map_index * further_count
- return range(map_index_start, map_index_start + further_count)
+ assert self.task
+ return _get_relevant_map_indexes(
+ run_id=self.run_id,
+ map_index=self.map_index,
+ ti_count=ti_count,
+ task=self.task,
+ relative=upstream,
+ session=session,
+ )
def clear_db_references(self, session: Session):
"""
@@ -2207,6 +2136,159 @@ def _is_further_mapped_inside(operator: Operator,
container: SerializedTaskGroup
return False
+def _get_relevant_map_indexes(
+ *,
+ task: Operator,
+ run_id: str,
+ map_index: int,
+ relative: Operator,
+ ti_count: int | None,
+ session: Session,
+) -> int | range | None:
+ """
+ Infer the map indexes of a relative that's "relevant" to this ti.
+
+ The bulk of the logic mainly exists to solve the problem described by
+ the following example, where 'val' must resolve to different values,
+ depending on where the reference is being used::
+
+ @task
+ def this_task(v): # This is ``task``.
+ return v * 2
+
+
+ @task_group
+ def tg1(inp):
+ val = upstream(inp) # This is the upstream task.
+ this_task(val) # When inp is 1, val here should resolve to 2.
+ return val
+
+
+ # This val is the same object returned by tg1.
+ val = tg1.expand(inp=[1, 2, 3])
+
+
+ @task_group
+ def tg2(inp):
+ another_task(inp, val) # val here should resolve to [2, 4, 6].
+
+
+ tg2.expand(inp=["a", "b"])
+
+ The surrounding mapped task groups of ``relative`` and ``task`` are
+ inspected to find a common "ancestor". If such an ancestor is found,
+ we need to return specific map indexes to pull a partial value from
+ upstream XCom.
+
+ The same logic applies for finding downstream tasks.
+
+ :param task: Current task being inspected.
+ :param run_id: Current run ID.
+ :param map_index: Map index of the current task instance.
+ :param relative: The relative task to find relevant map indexes for.
+ :param ti_count: The total count of task instances this task was expanded into
+ by the scheduler, i.e. ``expanded_ti_count`` in the template context.
+ :return: Specific map index or map indexes to pull, or ``None`` if we
+ want the "whole" return value (i.e. no mapped task groups involved).
+ """
+ from airflow.models.mappedoperator import get_mapped_ti_count
+
+ # This value should never be None since we already know the current task
+ # is in a mapped task group, and should have been expanded, despite that,
+ # we need to check that it is not None to satisfy Mypy.
+ # But this value can be 0 when we expand an empty list, for that it is
+ # necessary to check that ti_count is not 0 to avoid dividing by 0.
+ if not ti_count:
+ return None
+
+ # Find the innermost common mapped task group of the task and relative.
+ # If the current task and the relative task do not have a common
+ # mapped task group, the two are in different task mapping contexts
+ # (like another_task above), and we should use the "whole" value.
+ if (common_ancestor := _find_common_ancestor_mapped_group(task, relative))
is None:
+ return None
+
+ # At this point we know the two tasks share a mapped task group, and we
+ # should use a "partial" value. Let's break down the mapped ti count
+ # between the ancestor and further expansion happened inside it.
+
+ ancestor_ti_count = get_mapped_ti_count(common_ancestor, run_id,
session=session)
+ ancestor_map_index = map_index * ancestor_ti_count // ti_count
+
+ # If the relative is NOT further expanded inside the common ancestor, we
+ # only want to reference one single ti. We must walk the actual DAG,
+ # and "ti_count == ancestor_ti_count" does not work, since the further
+ # expansion may be of length 1.
+ if not _is_further_mapped_inside(relative, common_ancestor):
+ return ancestor_map_index
+
+ # Otherwise we need a partial aggregation for values from selected task
+ # instances in the ancestor's expansion context.
+ further_count = ti_count // ancestor_ti_count
+ map_index_start = ancestor_map_index * further_count
+ return range(map_index_start, map_index_start + further_count)
+
+
+def find_relevant_relatives(
+ normal_tasks: Iterable[str],
+ mapped_tasks: Iterable[tuple[str, int]],
+ *,
+ direction: Literal["upstream", "downstream"],
+ dag: SerializedDAG,
+ run_id: str,
+ session: Session,
+) -> Collection[str | tuple[str, int]]:
+ from airflow.models.mappedoperator import get_mapped_ti_count
+
+ visited: set[str | tuple[str, int]] = set()
+
+ def _visit_relevant_relatives_for_normal(task_ids: Iterable[str]) -> None:
+ partial_dag = dag.partial_subset(
+ task_ids=task_ids,
+ include_downstream=direction == "downstream",
+ include_upstream=direction == "upstream",
+ exclude_original=True,
+ )
+ visited.update(partial_dag.task_dict)
+
+ def _visit_relevant_relatives_for_mapped(mapped_tasks: Iterable[tuple[str,
int]]) -> None:
+ for task_id, map_index in mapped_tasks:
+ task = dag.get_task(task_id)
+ ti_count = get_mapped_ti_count(task, run_id, session=session)
+ # TODO (GH-52141): This should return scheduler operator types, but
+ # currently get_flat_relatives is inherited from SDK DAGNode.
+ relatives = cast("Iterable[Operator]",
task.get_flat_relatives(upstream=direction == "upstream"))
+ for relative in relatives:
+ if relative.task_id in visited:
+ continue
+ relative_map_indexes = _get_relevant_map_indexes(
+ task=task,
+ relative=relative, # type: ignore[arg-type]
+ run_id=run_id,
+ map_index=map_index,
+ ti_count=ti_count,
+ session=session,
+ )
+ visiting_mapped: set[tuple[str, int]] = set()
+ visiting_normal: set[str] = set()
+ match relative_map_indexes:
+ case int():
+ if (item := (relative.task_id, relative_map_indexes))
not in visited:
+ visiting_mapped.add(item)
+ case range():
+ visiting_mapped.update((relative.task_id, i) for i in
relative_map_indexes)
+ case None:
+ if (task_id := relative.task_id) not in visited:
+ visiting_normal.add(task_id)
+ _visit_relevant_relatives_for_normal(visiting_normal)
+ _visit_relevant_relatives_for_mapped(visiting_mapped)
+ visited.update(visiting_mapped, visiting_normal)
+
+ _visit_relevant_relatives_for_normal(normal_tasks)
+ _visit_relevant_relatives_for_mapped(mapped_tasks)
+ return visited
+
+
class TaskInstanceNote(Base):
"""For storage of arbitrary notes concerning the task instance."""
diff --git
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 6f743cd7d4b..41b81a3c65d 100644
---
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -2536,6 +2536,27 @@ class
TestPostClearTaskInstances(TestTaskInstanceEndpoint):
4,
id="clear mapped tasks with and without map index",
),
+ pytest.param(
+ "example_task_group_mapping",
+ [
+ {
+ "state": State.FAILED,
+ "map_indexes": (0, 1, 2),
+ },
+ {
+ "state": State.FAILED,
+ "map_indexes": (0, 1, 2),
+ },
+ ],
+ "example_task_group_mapping",
+ {
+ "task_ids": [["op.mul_2", 0]],
+ "dag_run_id": "TEST_DAG_RUN_ID",
+ "include_upstream": True,
+ },
+ 2,
+ id="clear tasks in mapped task group",
+ ),
],
)
def test_should_respond_200(
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 4c7a0d9fd5c..441d81de2b5 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -51,6 +51,7 @@ from airflow.models.taskinstance import (
TaskInstance,
TaskInstance as TI,
TaskInstanceNote,
+ find_relevant_relatives,
)
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.models.taskmap import TaskMap
@@ -3101,3 +3102,56 @@ def
test_delete_dagversion_restricted_when_taskinstance_exists(dag_maker, sessio
session.delete(version)
with pytest.raises(IntegrityError):
session.commit()
+
+
+@pytest.mark.parametrize(
+ ("normal_tasks", "mapped_tasks", "expected"),
+ [
+ # 4 is just a regular task so it depends on all its upstreams.
+ pytest.param(["4"], [], {"1", "2", "3"}, id="nonmapped"),
+ # 3 is mapped; it depends on all tis of the mapped upstream 2.
+ pytest.param(["3"], [], {"1", "2"}, id="mapped-whole"),
+ # Every ti of a mapped task depends on all tis of the mapped upstream.
+ pytest.param([], [("3", 1)], {"1", "2"}, id="mapped-one"),
+ # Same as the (non-group) unmapped case, d depends on all upstreams.
+ pytest.param(["d"], [], {"a", "b", "c"}, id="group-nonmapped"),
+ # This specifies c tis in ALL mapped task groups, so all b tis are
needed.
+ pytest.param(["c"], [], {"a", "b"}, id="group-mapped-whole"),
+ # This only specifies one c ti, so only one b ti from the same mapped
instance is returned.
+ pytest.param([], [("c", 1)], {"a", ("b", 1)}, id="group-mapped-one"),
+ ],
+)
+def test_find_relevant_relatives(dag_maker, session, normal_tasks,
mapped_tasks, expected):
+ # 1 -> 2[] -> 3[] -> 4
+ #
+ # a -> " b --> c " -> d
+ # "== g[] =="
+ with dag_maker(session=session) as dag:
+ t1 = EmptyOperator(task_id="1")
+ t2 = MockOperator.partial(task_id="2").expand(arg1=["x", "y"])
+ t3 = MockOperator.partial(task_id="3").expand(arg1=["x", "y"])
+ t4 = EmptyOperator(task_id="4")
+ t1 >> t2 >> t3 >> t4
+
+ ta = EmptyOperator(task_id="a")
+
+ @task_group(prefix_group_id=False)
+ def g(v):
+ tb = MockOperator(task_id="b", arg1=v)
+ tc = MockOperator(task_id="c", arg1=v)
+ tb >> tc
+
+ td = EmptyOperator(task_id="d")
+ ta >> g.expand(v=["x", "y", "z"]) >> td
+
+ dr = dag_maker.create_dagrun(state="success")
+
+ result = find_relevant_relatives(
+ normal_tasks=normal_tasks,
+ mapped_tasks=mapped_tasks,
+ direction="upstream",
+ dag=dag,
+ run_id=dr.run_id,
+ session=session,
+ )
+ assert result == expected