This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c441e4bf12f Reduce API server memory usage by eliminating `SerializedDAG` loads on task start (#60803)
c441e4bf12f is described below
commit c441e4bf12f78c8846330fd6a37d02641f64e163
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 20 13:24:11 2026 +0000
Reduce API server memory usage by eliminating `SerializedDAG` loads on task start (#60803)
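In practical terms, `upstream_map_indexes` is removed from `TIRunContext` and the Task SDK derives the same information on demand from task-instance counts. A hedged before/after sketch of the run-context payload (the dict contents below are illustrative, not actual responses):

    # Before: the API server loaded the SerializedDAG on every task start and shipped a
    # precomputed mapping in the run context, for example:
    old_context = {"task_reschedule_count": 0, "upstream_map_indexes": {"tg.task_2": [0, 1, 2]}}
    # After: the field is gone; the SDK asks the supervisor for TI counts (GetTICount)
    # only when an XComArg actually needs to resolve upstream values.
    new_context = {"task_reschedule_count": 0}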
---
airflow-core/newsfragments/60803.significant.rst | 1 +
.../execution_api/datamodels/taskinstance.py | 2 -
.../execution_api/routes/task_instances.py | 58 +------
.../api_fastapi/execution_api/versions/__init__.py | 7 +-
.../execution_api/versions/v2026_03_31.py | 21 ++-
.../versions/head/test_task_instances.py | 72 +-------
.../versions/v2025_04_28/test_task_instances.py | 67 +-------
devel-common/src/tests_common/pytest_plugin.py | 5 -
.../src/airflow/sdk/api/datamodels/_generated.py | 3 -
.../sdk/definitions/_internal/expandinput.py | 4 +-
task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 23 ++-
.../src/airflow/sdk/execution_time/task_mapping.py | 133 +++++++++++++++
.../src/airflow/sdk/execution_time/task_runner.py | 39 ++++-
.../task_sdk/definitions/test_mappedoperator.py | 147 +++++++++++-----
.../task_sdk/execution_time/test_task_mapping.py | 189 +++++++++++++++++++++
15 files changed, 521 insertions(+), 250 deletions(-)
diff --git a/airflow-core/newsfragments/60803.significant.rst b/airflow-core/newsfragments/60803.significant.rst
new file mode 100644
index 00000000000..1054258ad38
--- /dev/null
+++ b/airflow-core/newsfragments/60803.significant.rst
@@ -0,0 +1 @@
+Move ``upstream_map_indexes`` computation from API server to Task SDK, reducing memory usage on task start by eliminating ``SerializedDAG`` loads.
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index e7ebee9ebe7..513a99f6dc9 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -324,8 +324,6 @@ class TIRunContext(BaseModel):
connections: Annotated[list[ConnectionResponse], Field(default_factory=list)]
"""Connections that can be accessed by the task instance."""
- upstream_map_indexes: dict[str, int | list[int] | None] | None = None
-
next_method: str | None = None
"""Method to call. Set when task resumes from a trigger."""
next_kwargs: dict[str, Any] | str | None = None
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a729a1ee83b..a73145b30ab 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -31,7 +31,7 @@ from cadwyn import VersionedAPIRouter
from fastapi import Body, HTTPException, Query, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
-from sqlalchemy.engine import CursorResult, Row
+from sqlalchemy.engine import CursorResult
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select
@@ -64,14 +64,11 @@ from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
from airflow.models.dagrun import DagRun as DR
-from airflow.models.expandinput import NotFullyPopulated
from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_tasks
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
from airflow.serialization.definitions.assets import SerializedAsset, SerializedAssetUniqueKey
-from airflow.serialization.definitions.dag import SerializedDAG
-from airflow.task.trigger_rule import TriggerRule
from airflow.utils.sqlalchemy import get_dialect_name
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -251,17 +248,6 @@ def ti_run(
or 0
)
- if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
- upstream_map_indexes = dict(
- _get_upstream_map_indexes(
- serialized_dag=dag,
- ti=ti,
- session=session,
- )
- )
- else:
- upstream_map_indexes = None
-
context = TIRunContext(
dag_run=dr,
task_reschedule_count=task_reschedule_count,
@@ -271,7 +257,6 @@ def ti_run(
connections=[],
xcom_keys_to_clear=xcom_keys,
should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries),
- upstream_map_indexes=upstream_map_indexes,
)
# Only set if they are non-null
@@ -287,47 +272,6 @@ def ti_run(
)
-def _get_upstream_map_indexes(
- *,
- serialized_dag: SerializedDAG,
- ti: TI | Row,
- session: SessionDep,
-) -> Iterator[tuple[str, int | list[int] | None]]:
- task = serialized_dag.get_task(ti.task_id)
- for upstream_task in task.upstream_list:
- map_indexes: int | list[int] | None
- if (upstream_mapped_group := upstream_task.get_closest_mapped_task_group()) is None:
- # regular tasks or non-mapped task groups
- map_indexes = None
- elif task.get_closest_mapped_task_group() is upstream_mapped_group:
- # tasks in the same mapped task group hierarchy
- map_indexes = ti.map_index
- else:
- # tasks not in the same mapped task group
- # the upstream mapped task group should combine the return xcom as a list and return it
- mapped_ti_count: int | None = None
-
- try:
- # First try: without resolving XCom
- mapped_ti_count = upstream_mapped_group.get_parse_time_mapped_ti_count()
- except NotFullyPopulated:
- # Second try: resolve XCom for correct count
- try:
- expand_input = upstream_mapped_group._expand_input
- mapped_ti_count = expand_input.get_total_map_length(ti.run_id, session=session)
- except NotFullyPopulated:
- # For these trigger rules, unresolved map indexes are acceptable.
- # The success of the upstream task is not the main reason for triggering the current task.
- # Therefore, whether the upstream task is fully populated can be ignored.
- if task.trigger_rule != TriggerRule.ALL_SUCCESS:
- mapped_ti_count = None
-
- # Compute map indexes if we have a valid count
- map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None
-
- yield upstream_task.task_id, map_indexes
-
-
@ti_id_router.patch(
"/{task_instance_id}/state",
status_code=status.HTTP_204_NO_CONTENT,
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 36c1b31b959..30d4159f745 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -33,11 +33,14 @@ from airflow.api_fastapi.execution_api.versions.v2025_12_08 import (
AddDagRunDetailEndpoint,
MovePreviousRunEndpoint,
)
-from airflow.api_fastapi.execution_api.versions.v2026_03_31 import ModifyDeferredTaskKwargsToJsonValue
+from airflow.api_fastapi.execution_api.versions.v2026_03_31 import (
+ ModifyDeferredTaskKwargsToJsonValue,
+ RemoveUpstreamMapIndexesField,
+)
bundle = VersionBundle(
HeadVersion(),
- Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue),
+ Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue,
RemoveUpstreamMapIndexesField),
Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
Version("2025-11-07", AddPartitionKeyField),
Version("2025-11-05", AddTriggeringUserNameField),
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
index 48630e0b50c..72e193426da 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
@@ -19,9 +19,9 @@ from __future__ import annotations
from typing import Any
-from cadwyn import VersionChange, schema
+from cadwyn import ResponseInfo, VersionChange, convert_response_to_previous_version_for, schema
-from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIDeferredStatePayload
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import TIDeferredStatePayload, TIRunContext
class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
@@ -33,3 +33,20 @@ class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
schema(TIDeferredStatePayload).field("trigger_kwargs").had(type=dict[str, Any]
| str),
schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str,
Any]),
)
+
+
+class RemoveUpstreamMapIndexesField(VersionChange):
+ """Remove upstream_map_indexes field from TIRunContext - now computed by
Task SDK."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version = (
+ schema(TIRunContext)
+ .field("upstream_map_indexes")
+ .existed_as(type=dict[str, int | list[int] | None] | None),
+ )
+
+ @convert_response_to_previous_version_for(TIRunContext) # type: ignore[arg-type]
+ def add_upstream_map_indexes_field(response: ResponseInfo) -> None: # type: ignore[misc]
+ """Add upstream_map_indexes field with None for older API versions."""
+ response.body["upstream_map_indexes"] = None
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index e99b7d9380f..cfee6c9d46c 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -202,7 +202,6 @@ class TestTIRunState:
"partition_key": None,
},
"task_reschedule_count": 0,
- "upstream_map_indexes": {},
"max_tries": max_tries,
"should_retry": should_retry,
"variables": [],
@@ -249,10 +248,7 @@ class TestTIRunState:
assert response.status_code == 409
def test_dynamic_task_mapping_with_parse_time_value(self, client, dag_maker):
- """
- Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances
- """
-
+ """Test that dynamic task mapping works correctly with parse-time
values."""
with dag_maker("test_dynamic_task_mapping_with_parse_time_value",
serialized=True):
@task_group
@@ -278,23 +274,6 @@ class TestTIRunState:
ti.set_state(State.QUEUED)
dag_maker.session.flush()
- # key: (task_id, map_index)
- # value: result upstream_map_indexes ({task_id: map_indexes})
- expected_upstream_map_indexes = {
- # no upstream task for task_group_1.group_task_1
- ("task_group_1.group1_task_1", 0): {},
- ("task_group_1.group1_task_1", 1): {},
- # the upstream task for task_group_1.group_task_2 is task_group_1.group_task_2
- # since they are in the same task group, the upstream map index should be the same as the task
- ("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 0},
- ("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 1},
- # the upstream task for task2 is the last tasks of task_group_1, which is
- # task_group_1.group_task_2
- # since they are not in the same task group, the upstream map index should include all the
- # expanded tasks
- ("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
- }
-
for ti in dr.get_task_instances():
response = client.patch(
f"/execution/task-instances/{ti.id}/run",
@@ -308,13 +287,9 @@ class TestTIRunState:
)
assert response.status_code == 200
- upstream_map_indexes = response.json()["upstream_map_indexes"]
- assert upstream_map_indexes == expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
- def test_nested_mapped_task_group_upstream_indexes(self, client, dag_maker):
- """
- Test that upstream_map_indexes are correctly computed for tasks in nested mapped task groups.
- """
+ def test_nested_mapped_task_group(self, client, dag_maker):
+ """Test that nested mapped task groups work correctly."""
with dag_maker("test_nested_mapped_tg", serialized=True):
@task
@@ -346,25 +321,11 @@ class TestTIRunState:
ti.set_state(State.QUEUED)
dag_maker.session.flush()
- # Expected upstream_map_indexes for each print_task instance
- expected_upstream_map_indexes = {
- ("expandable_task_group.inner_task_group.print_task", 0): {
- "expandable_task_group.inner_task_group.alter_input": 0
- },
- ("expandable_task_group.inner_task_group.print_task", 1): {
- "expandable_task_group.inner_task_group.alter_input": 1
- },
- ("expandable_task_group.inner_task_group.print_task", 2): {
- "expandable_task_group.inner_task_group.alter_input": 2
- },
- }
-
# Get only the expanded print_task instances (not the template)
print_task_tis = [
ti for ti in dr.get_task_instances() if "print_task" in ti.task_id and ti.map_index >= 0
]
- # Test each print_task instance
for ti in print_task_tis:
response = client.patch(
f"/execution/task-instances/{ti.id}/run",
@@ -378,18 +339,9 @@ class TestTIRunState:
)
assert response.status_code == 200
- upstream_map_indexes = response.json()["upstream_map_indexes"]
- expected = expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
-
- assert upstream_map_indexes == expected, (
- f"Task {ti.task_id}[{ti.map_index}] should have
upstream_map_indexes {expected}, "
- f"but got {upstream_map_indexes}"
- )
def test_dynamic_task_mapping_with_xcom(self, client: Client, dag_maker: DagMaker, session: Session):
- """
- Test that the Task Instance upstream_map_indexes is correctly fetched when to running the Task Instances with xcom
- """
+ """Test that dynamic task mapping works correctly with XCom values."""
from airflow.models.taskmap import TaskMap
with dag_maker(session=session, serialized=True):
@@ -442,13 +394,10 @@ class TestTIRunState:
"start_date": "2024-09-30T12:00:00Z",
},
)
- assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1,
2, 3, 4, 5]}
+ assert response.status_code == 200
def test_dynamic_task_mapping_with_all_success_trigger_rule(self, dag_maker: DagMaker, session: Session):
- """
- Test that the Task Instance upstream_map_indexes is not populuated but
- the downstream task should not be run.
- """
+ """Test that with ALL_SUCCESS trigger rule and skipped upstream,
downstream should not run."""
with dag_maker(session=session, serialized=True):
@@ -504,10 +453,7 @@ class TestTIRunState:
def test_dynamic_task_mapping_with_non_all_success_trigger_rule(
self, client: Client, dag_maker: DagMaker, session: Session, trigger_rule: TriggerRule
):
- """
- Test that the Task Instance upstream_map_indexes is not populuated but
- the downstream task should still be run due to trigger rule.
- """
+ """Test that with non-ALL_SUCCESS trigger rule, downstream task should
still run."""
with dag_maker(session=session, serialized=True):
@@ -564,7 +510,7 @@ class TestTIRunState:
"start_date": "2024-09-30T12:00:00Z",
},
)
- assert response.json()["upstream_map_indexes"] == {"tg.task_2": None}
+ assert response.status_code == 200
def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine):
instant_str = "2024-09-30T12:00:00Z"
@@ -615,7 +561,6 @@ class TestTIRunState:
assert response.json() == {
"dag_run": mock.ANY,
"task_reschedule_count": 0,
- "upstream_map_indexes": {},
"max_tries": 0,
"should_retry": False,
"variables": [],
@@ -687,7 +632,6 @@ class TestTIRunState:
assert response.json() == {
"dag_run": mock.ANY,
"task_reschedule_count": 0,
- "upstream_map_indexes": {},
"max_tries": 0,
"should_retry": False,
"variables": [],
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
index 16371b67d04..efa29338d92 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
@@ -17,8 +17,6 @@
from __future__ import annotations
-from unittest.mock import patch
-
import pytest
from airflow._shared.timezones import timezone
@@ -50,52 +48,19 @@ class TestTIUpdateState:
clear_db_assets()
clear_db_runs()
- @pytest.mark.parametrize(
- ("mock_indexes", "expected_response_indexes"),
- [
- pytest.param(
- [("task_a", 5), ("task_b", 10)],
- {"task_a": 5, "task_b": 10},
- id="plain ints",
- ),
- pytest.param(
- [("task_a", [3, 4]), ("task_b", [9])],
- {"task_a": 3, "task_b": 9},
- id="list of ints",
- ),
- pytest.param(
- [
- ("task_a", None),
- ],
- {"task_a": None},
- id="task has no upstreams",
- ),
- pytest.param(
- [("task_a", None), ("task_b", [6, 7]), ("task_c", 2)],
- {"task_a": None, "task_b": 6, "task_c": 2},
- id="mixed types",
- ),
- ],
- )
- @patch("airflow.api_fastapi.execution_api.routes.task_instances._get_upstream_map_indexes")
def test_ti_run(
self,
- mock_get_upstream_map_indexes,
ver_client,
session,
create_task_instance,
time_machine,
- mock_indexes,
- expected_response_indexes,
get_execution_app,
):
"""
Test that this version of the endpoint works.
- Later versions modified the type of upstream_map_indexes.
+ upstream_map_indexes is now always None as it's computed by the Task SDK.
"""
- mock_get_upstream_map_indexes.return_value = mock_indexes
-
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)
time_machine.move_to(instant, tick=False)
@@ -124,26 +89,10 @@ class TestTIUpdateState:
)
assert response.status_code == 200
- assert response.json() == {
- "dag_run": {
- "dag_id": "dag",
- "run_id": "test",
- "clear_number": 0,
- "logical_date": instant_str,
- "data_interval_start":
instant.subtract(days=1).to_iso8601_string(),
- "data_interval_end": instant_str,
- "run_after": instant_str,
- "start_date": instant_str,
- "end_date": None,
- "run_type": "manual",
- "conf": {},
- "consumed_asset_events": [],
- },
- "task_reschedule_count": 0,
- "upstream_map_indexes": expected_response_indexes,
- "max_tries": 0,
- "should_retry": False,
- "variables": [],
- "connections": [],
- "xcom_keys_to_clear": [],
- }
+ result = response.json()
+ # upstream_map_indexes is now computed by SDK, server returns None
+ assert result["upstream_map_indexes"] is None
+ assert result["dag_run"]["dag_id"] == "dag"
+ assert result["task_reschedule_count"] == 0
+ assert result["max_tries"] == 0
+ assert result["should_retry"] is False
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index bc18c94f17b..258d3345706 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2457,7 +2457,6 @@ def create_runtime_ti(mocked_parse):
run_type: str = "manual",
try_number: int = 1,
map_index: int | None = -1,
- upstream_map_indexes: dict[str, int | list[int] | None] | None = None,
task_reschedule_count: int = 0,
ti_id: UUID | None = None,
conf: dict[str, Any] | None = None,
@@ -2532,12 +2531,8 @@ def create_runtime_ti(mocked_parse):
task_reschedule_count=task_reschedule_count,
max_tries=task_retries if max_tries is None else max_tries,
should_retry=should_retry if should_retry is not None else try_number <= task_retries,
- upstream_map_indexes=upstream_map_indexes,
)
- if upstream_map_indexes is not None:
- ti_context.upstream_map_indexes = upstream_map_indexes
-
compat_fields = {
"requests_fd": 0,
"sentry_integration": "",
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 6a3a07f5e8a..149aa5d7dcb 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -631,9 +631,6 @@ class TIRunContext(BaseModel):
max_tries: Annotated[int, Field(title="Max Tries")]
variables: Annotated[list[VariableResponse] | None, Field(title="Variables")] = None
connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None
- upstream_map_indexes: Annotated[
- dict[str, int | list[int] | None] | None, Field(title="Upstream Map Indexes")
- ] = None
next_method: Annotated[str | None, Field(title="Next Method")] = None
next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None
xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
index fa56ae84650..b6ffbd22142 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
@@ -189,7 +189,9 @@ class DictOfListsExpandInput(ResolveMixin):
if map_index is None or map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without
expanding")
- upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes",
{})
+ # Get pre-computed upstream_map_indexes if available, otherwise
default to empty dict.
+ # When empty, individual XComArgs will compute their map_indexes
lazily in xcom_arg.py.
+ upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes",
None) or {}
# TODO: This initiates one API call for each XComArg. Would it be
# more efficient to do one single call and unpack the value here?
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index ef316574df5..88db9009242 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -30,7 +30,7 @@ from airflow.sdk import TriggerRule
from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
from airflow.sdk.definitions._internal.mixins import DependencyMixin, ResolveMixin
from airflow.sdk.definitions._internal.setup_teardown import SetupTeardownContext
-from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, is_arg_set
from airflow.sdk.exceptions import AirflowException, XComNotFound
from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
from airflow.sdk.execution_time.xcom import BaseXCom
@@ -337,10 +337,25 @@ class PlainXComArg(XComArg):
return LazyXComSequence(xcom_arg=self, ti=ti)
tg = self.operator.get_closest_mapped_task_group()
if tg is None:
- map_indexes = None
+ # No mapped task group - pull from unmapped instance
+ map_indexes: int | range | None | ArgNotSet = None
else:
- upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})
- map_indexes = upstream_map_indexes.get(task_id, None)
+ # Check for pre-computed value from server (backward compatibility)
+ upstream_map_indexes = getattr(ti, "_upstream_map_indexes", None)
+ if upstream_map_indexes is not None:
+ # Use None as default to match original behavior (filter for unmapped XCom)
+ map_indexes = upstream_map_indexes.get(task_id, None)
+ else:
+ # Compute lazily - ti_count will be queried if needed
+ cached_context = getattr(ti, "_cached_template_context", None)
+ ti_count = cached_context.get("expanded_ti_count") if
cached_context else None
+ computed = ti.get_relevant_upstream_map_indexes(
+ upstream=self.operator,
+ ti_count=ti_count,
+ session=None, # Not used in SDK implementation
+ )
+ # None means "no filtering needed" -> use NOTSET to pull all
values
+ map_indexes = NOTSET if computed is None else computed
result = ti.xcom_pull(
task_ids=task_id,
key=self.key,
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_mapping.py b/task-sdk/src/airflow/sdk/execution_time/task_mapping.py
new file mode 100644
index 00000000000..fc54bb794d9
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/task_mapping.py
@@ -0,0 +1,133 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utility functions for computing upstream map indexes in the Task SDK."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.sdk.execution_time.comms import GetTICount, TICount
+
+if TYPE_CHECKING:
+ from airflow.sdk import BaseOperator
+ from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+
+
+def _find_common_ancestor_mapped_group(node1: BaseOperator, node2: BaseOperator) -> MappedTaskGroup | None:
+ """
+ Given two operators, find their innermost common mapped task group.
+
+ :param node1: First operator
+ :param node2: Second operator
+ :return: The common mapped task group, or None if they don't share one
+ """
+ try:
+ dag1 = node1.dag
+ dag2 = node2.dag
+ except RuntimeError:
+ # Operator not assigned to a DAG
+ return None
+
+ if dag1 is None or dag2 is None or node1.dag_id != node2.dag_id:
+ return None
+ parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
+ common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id in parent_group_ids)
+ return next(common_groups, None)
+
+
+def _is_further_mapped_inside(operator: BaseOperator, container: TaskGroup) -> bool:
+ """
+ Whether given operator is *further* mapped inside a task group.
+
+ :param operator: The operator to check
+ :param container: The container task group
+ :return: True if the operator is further mapped inside the container
+ """
+ # Use getattr for compatibility with both SDK and serialized operators
+ if getattr(operator, "is_mapped", False):
+ return True
+ task_group = operator.task_group
+ while task_group is not None and task_group.group_id != container.group_id:
+ if getattr(task_group, "is_mapped", False):
+ return True
+ task_group = task_group.parent_group
+ return False
+
+
+def get_ti_count_for_task(task_id: str, dag_id: str, run_id: str) -> int:
+ """
+ Query TI count for a specific task.
+
+ :param task_id: The task ID
+ :param dag_id: The DAG ID
+ :param run_id: The run ID
+ :return: The count of task instances for the task
+ """
+ # Import here because SUPERVISOR_COMMS is set at runtime, not import time
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ response = SUPERVISOR_COMMS.send(GetTICount(dag_id=dag_id, task_ids=[task_id], run_ids=[run_id]))
+ if not isinstance(response, TICount):
+ raise RuntimeError(f"Unexpected response type: {type(response)}")
+ return response.count
+
+
+def get_relevant_map_indexes(
+ task: BaseOperator,
+ run_id: str,
+ map_index: int,
+ ti_count: int,
+ relative: BaseOperator,
+ dag_id: str,
+) -> int | range | None:
+ """
+ Determine map indexes for XCom aggregation.
+
+ This is used to figure out which specific map indexes of an upstream task
+ are relevant when resolving XCom values for a task in a mapped task group.
+
+ :param task: The current task
+ :param run_id: The current run ID
+ :param map_index: The map index of the current task instance
+ :param ti_count: The total count of task instances for the current task
+ :param relative: The upstream/downstream task to find relevant map indexes for
+ :param dag_id: The DAG ID
+ :return: None (use entire value), int (single index), or range (subset of indexes)
+ """
+ if not ti_count:
+ return None
+
+ common_ancestor = _find_common_ancestor_mapped_group(task, relative)
+ if common_ancestor is None or common_ancestor.group_id is None:
+ return None # Different mapping contexts → use whole value
+
+ # Query TI count using the current task, which is in the mapped task group.
+ # This gives us the number of expansion iterations, not total TIs in the group.
+ ancestor_ti_count = get_ti_count_for_task(task.task_id, dag_id, run_id)
+ if not ancestor_ti_count:
+ return None
+
+ ancestor_map_index = map_index * ancestor_ti_count // ti_count
+
+ if not _is_further_mapped_inside(relative, common_ancestor):
+ return ancestor_map_index # Single index
+
+ # Partial aggregation for selected TIs
+ further_count = ti_count // ancestor_ti_count
+ map_index_start = ancestor_map_index * further_count
+ return range(map_index_start, map_index_start + further_count)
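The index arithmetic at the end of get_relevant_map_indexes is easier to follow with concrete numbers; a self-contained illustration (the counts are chosen arbitrarily):

    # Assume the common ancestor group expands 3 times and the current task has 6 TIs in total,
    # i.e. it is further mapped twice per expansion.
    ti_count, ancestor_ti_count = 6, 3
    for map_index in range(ti_count):
        ancestor_map_index = map_index * ancestor_ti_count // ti_count  # single-index case
        further_count = ti_count // ancestor_ti_count                   # further-mapped case
        start = ancestor_map_index * further_count
        print(map_index, ancestor_map_index, range(start, start + further_count))
    # map_index 0 and 1 resolve to ancestor index 0, 2 and 3 to index 1, 4 and 5 to index 2.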
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 60577066998..1cfc2c82d45 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -268,10 +268,10 @@ class RuntimeTaskInstance(TaskInstance):
}
)
- if from_server.upstream_map_indexes is not None:
- # We stash this in here for later use, but we purposefully don't want to document it's
- # existence. Should this be a private attribute on RuntimeTI instead perhaps?
- setattr(self, "_upstream_map_indexes", from_server.upstream_map_indexes)
+ # Backward compatibility: old servers may still send upstream_map_indexes
+ upstream_map_indexes = getattr(from_server, "upstream_map_indexes", None)
+ if upstream_map_indexes is not None:
+ setattr(self, "_upstream_map_indexes", upstream_map_indexes)
return self._cached_template_context
@@ -436,8 +436,35 @@ class RuntimeTaskInstance(TaskInstance):
def get_relevant_upstream_map_indexes(
self, upstream: BaseOperator, ti_count: int | None, session: Any
) -> int | range | None:
- # TODO: Implement this method
- return None
+ """
+ Compute the relevant upstream map indexes for XCom resolution.
+
+ :param upstream: The upstream operator
+ :param ti_count: The total count of task instances for this task's expansion
+ :param session: Not used (kept for API compatibility)
+ :return: None (use entire value), int (single index), or range (subset of indexes)
+ """
+ from airflow.sdk.execution_time.task_mapping import get_relevant_map_indexes, get_ti_count_for_task
+
+ map_index = self.map_index
+ if map_index is None or map_index < 0:
+ return None
+
+ # If ti_count not provided, we need to query it
+ if ti_count is None:
+ ti_count = get_ti_count_for_task(self.task_id, self.dag_id, self.run_id)
+
+ if not ti_count:
+ return None
+
+ return get_relevant_map_indexes(
+ task=self.task,
+ run_id=self.run_id,
+ map_index=map_index,
+ ti_count=ti_count,
+ relative=upstream,
+ dag_id=self.dag_id,
+ )
def get_first_reschedule_date(self, context: Context) -> AwareDatetime | None:
"""Get the first reschedule date for the task instance if found, none otherwise."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
index 4d0e52e8f2f..85b6e9c0b27 100644
--- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
+++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
@@ -31,7 +31,15 @@ from airflow.sdk.bases.xcom import BaseXCom
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.mappedoperator import MappedOperator
from airflow.sdk.definitions.xcom_arg import XComArg
-from airflow.sdk.execution_time.comms import GetXCom, SetXCom, XComResult
+from airflow.sdk.execution_time.comms import (
+ GetTICount,
+ GetXCom,
+ GetXComSequenceSlice,
+ SetXCom,
+ TICount,
+ XComResult,
+ XComSequenceSliceResult,
+)
from tests_common.test_utils.mapping import expand_mapped_task # noqa: F401
from tests_common.test_utils.mock_operators import (
@@ -252,9 +260,18 @@ def test_mapped_render_template_fields_validating_operator(
)
mapped = callable(mapped, task1.output)
- mock_supervisor_comms.send.return_value = XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=["{{ ds }}"])
+ def mock_comms(msg):
+ if isinstance(msg, GetXCom):
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=["{{ ds }}"])
+ if isinstance(msg, GetXComSequenceSlice):
+ return XComSequenceSliceResult(root=["{{ ds }}"])
+ if isinstance(msg, GetTICount):
+ return TICount(count=1)
+ return mock.DEFAULT
+
+ mock_supervisor_comms.send.side_effect = mock_comms
- mapped_ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={task1.task_id: 1})
+ mapped_ti = create_runtime_ti(task=mapped, map_index=0)
assert isinstance(mapped_ti.task, MappedOperator)
mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context())
@@ -273,13 +290,13 @@ def test_mapped_render_nested_template_fields(create_runtime_ti, mock_supervisor
task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}",
field_2="value_2")
).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]])
- ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={})
+ ti = create_runtime_ti(task=mapped, map_index=0)
ti.task.render_template_fields(context=ti.get_template_context())
assert ti.task.arg1 == "t"
assert ti.task.arg2.field_1 == "t"
assert ti.task.arg2.field_2 == "value_2"
- ti = create_runtime_ti(task=mapped, map_index=1, upstream_map_indexes={})
+ ti = create_runtime_ti(task=mapped, map_index=1)
ti.task.render_template_fields(context=ti.get_template_context())
assert ti.task.arg1 == ["s", "t"]
assert ti.task.arg2.field_1 == "t"
@@ -300,11 +317,20 @@ def test_expand_kwargs_render_template_fields_validating_operator(
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id
}}").expand_kwargs(task1.output)
- mock_supervisor_comms.send.return_value = XComResult(
- key=BaseXCom.XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}]
- )
+ xcom_values = [{"arg1": "{{ ds }}"}, {"arg1": 2}]
+
+ def mock_comms(msg):
+ if isinstance(msg, GetXCom):
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=xcom_values)
+ if isinstance(msg, GetXComSequenceSlice):
+ return XComSequenceSliceResult(root=xcom_values)
+ if isinstance(msg, GetTICount):
+ return TICount(count=2)
+ return mock.DEFAULT
+
+ mock_supervisor_comms.send.side_effect = mock_comms
- ti = create_runtime_ti(task=mapped, map_index=map_index, upstream_map_indexes={})
+ ti = create_runtime_ti(task=mapped, map_index=map_index)
assert isinstance(ti.task, MappedOperator)
ti.task.render_template_fields(context=ti.get_template_context())
assert isinstance(ti.task, MockOperator)
@@ -428,14 +454,29 @@ def test_map_cross_product(run_ti: RunTI, mock_supervisor_comms):
show.expand(number=emit_numbers(), letter=emit_letters())
- def xcom_get(msg):
- if not isinstance(msg, GetXCom):
- return mock.DEFAULT
- task = dag.get_task(msg.task_id)
- value = task.python_callable()
- return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+ numbers = [1, 2]
+ letters = {"a": "x", "b": "y", "c": "z"}
+
+ def mock_comms(msg):
+ if isinstance(msg, GetXCom):
+ if msg.task_id == "emit_numbers":
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=numbers)
+ if msg.task_id == "emit_letters":
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=letters)
+ elif isinstance(msg, GetXComSequenceSlice):
+ if msg.task_id == "emit_numbers":
+ return XComSequenceSliceResult(root=numbers)
+ if msg.task_id == "emit_letters":
+ # Convert dict items to list for XComSequenceSliceResult
+ return XComSequenceSliceResult(root=list(letters.items()))
+ elif isinstance(msg, GetTICount):
+ # show is mapped by 6 (2 numbers * 3 letters)
+ if msg.task_ids and msg.task_ids[0] == "show":
+ return TICount(count=6)
+ return TICount(count=1)
+ return mock.DEFAULT
- mock_supervisor_comms.send.side_effect = xcom_get
+ mock_supervisor_comms.send.side_effect = mock_comms
states = [run_ti(dag, "show", map_index) for map_index in range(6)]
assert states == [TaskInstanceState.SUCCESS] * 6
@@ -466,14 +507,23 @@ def test_map_product_same(run_ti: RunTI, mock_supervisor_comms):
emit_task = emit_numbers()
show.expand(a=emit_task, b=emit_task)
- def xcom_get(msg):
- if not isinstance(msg, GetXCom):
- return mock.DEFAULT
- task = dag.get_task(msg.task_id)
- value = task.python_callable()
- return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+ numbers = [1, 2]
+
+ def mock_comms(msg):
+ if isinstance(msg, GetXCom):
+ if msg.task_id == "emit_numbers":
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=numbers)
+ elif isinstance(msg, GetXComSequenceSlice):
+ if msg.task_id == "emit_numbers":
+ return XComSequenceSliceResult(root=numbers)
+ elif isinstance(msg, GetTICount):
+ # show is mapped by 4 (2 * 2 cross product)
+ if msg.task_ids and msg.task_ids[0] == "show":
+ return TICount(count=4)
+ return TICount(count=1)
+ return mock.DEFAULT
- mock_supervisor_comms.send.side_effect = xcom_get
+ mock_supervisor_comms.send.side_effect = mock_comms
states = [run_ti(dag, "show", map_index) for map_index in range(4)]
assert states == [TaskInstanceState.SUCCESS] * 4
@@ -591,20 +641,37 @@ def test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super
# Aggregates results from task group.
t.override(task_id="t3")(tg1)
- def xcom_get(msg):
- if not isinstance(msg, GetXCom):
- return mock.DEFAULT
- key = (msg.task_id, msg.map_index)
- if key in expected_values:
- value = expected_values[key]
- return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
- if msg.map_index is None:
- # Get all mapped XComValues for this ti
- value = [v for k, v in expected_values.items() if k[0] == msg.task_id]
- return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+ # Map task group IDs to their expansion counts
+ task_group_expansion = {"tg": 3}
+
+ def mock_comms_response(msg):
+ if isinstance(msg, GetXCom):
+ key = (msg.task_id, msg.map_index)
+ if key in expected_values:
+ value = expected_values[key]
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+ if msg.map_index is None:
+ # Get all mapped XComValues for this ti
+ value = [v for k, v in expected_values.items() if k[0] == msg.task_id]
+ return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+ elif isinstance(msg, GetXComSequenceSlice):
+ # Handle sequence slicing for pulling all XCom values from mapped tasks
+ task_id = msg.task_id
+ values = [v for k, v in expected_values.items() if k[0] == task_id and k[1] is not None]
+ return XComSequenceSliceResult(root=values)
+ elif isinstance(msg, GetTICount):
+ # Handle TI count queries for upstream_map_indexes computation
+ if msg.task_ids:
+ task_id = msg.task_ids[0]
+ if task_id in expansion_per_task_id:
+ return TICount(count=len(list(expansion_per_task_id[task_id])))
+ return TICount(count=1)
+ if msg.task_group_id:
+ return TICount(count=task_group_expansion.get(msg.task_group_id, 0))
+ return TICount(count=0)
return mock.DEFAULT
- mock_supervisor_comms.send.side_effect = xcom_get
+ mock_supervisor_comms.send.side_effect = mock_comms_response
expected_values = {
("tg.t1", 0): ["a", "b"],
@@ -622,21 +689,11 @@ def test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super
"tg.t2": range(3),
"t3": [None],
}
- upstream_map_indexes_per_task_id = {
- ("tg.t1", 0): {},
- ("tg.t1", 1): {},
- ("tg.t1", 2): {},
- ("tg.t2", 0): {"tg.t1": 0},
- ("tg.t2", 1): {"tg.t1": 1},
- ("tg.t2", 2): {"tg.t1": 2},
- ("t3", None): {"tg.t2": [0, 1, 2]},
- }
for task in dag.tasks:
for map_index in expansion_per_task_id[task.task_id]:
mapped_ti = create_runtime_ti(
task=task.prepare_for_execution(),
map_index=map_index,
- upstream_map_indexes=upstream_map_indexes_per_task_id[(task.task_id, map_index)],
)
context = mapped_ti.get_template_context()
mapped_ti.task.render_template_fields(context)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py b/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py
new file mode 100644
index 00000000000..5f6bf044785
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from airflow.sdk import DAG, BaseOperator
+from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk.execution_time.comms import TICount
+from airflow.sdk.execution_time.task_mapping import (
+ _find_common_ancestor_mapped_group,
+ _is_further_mapped_inside,
+ get_relevant_map_indexes,
+ get_ti_count_for_task,
+)
+
+
+class TestFindCommonAncestorMappedGroup:
+ """Tests for _find_common_ancestor_mapped_group function."""
+
+ def test_no_common_group_different_dags(self):
+ """Tasks in different DAGs should return None."""
+ with DAG("dag1"):
+ op1 = BaseOperator(task_id="op1")
+
+ with DAG("dag2"):
+ op2 = BaseOperator(task_id="op2")
+
+ result = _find_common_ancestor_mapped_group(op1, op2)
+ assert result is None
+
+ def test_no_common_group_no_mapped_groups(self):
+ """Tasks not in any mapped group should return None."""
+ with DAG("dag1"):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+
+ result = _find_common_ancestor_mapped_group(op1, op2)
+ assert result is None
+
+ def test_no_dag_returns_none(self):
+ """Tasks without DAG should return None."""
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+
+ # Function should handle operators not assigned to a DAG gracefully
+ result = _find_common_ancestor_mapped_group(op1, op2)
+ assert result is None
+
+
+class TestIsFurtherMappedInside:
+ """Tests for _is_further_mapped_inside function."""
+
+ def test_mapped_operator_returns_true(self):
+ """A mapped operator should return True."""
+ with DAG("dag1"):
+ with TaskGroup("tg") as tg:
+ op = BaseOperator(task_id="op")
+
+ # Simulate a mapped operator
+ op._is_mapped = True
+
+ result = _is_further_mapped_inside(op, tg)
+ assert result is True
+
+ def test_non_mapped_operator_returns_false(self):
+ """A non-mapped operator with no mapped parent groups should return
False."""
+ with DAG("dag1"):
+ with TaskGroup("tg") as tg:
+ op = BaseOperator(task_id="op")
+
+ result = _is_further_mapped_inside(op, tg)
+ assert result is False
+
+
+class TestGetTiCountForTask:
+ """Tests for get_ti_count_for_task function."""
+
+ def test_queries_supervisor(self, mock_supervisor_comms):
+ """Should send GetTICount message to supervisor with task_ids."""
+ from airflow.sdk.execution_time.comms import TICount
+
+ mock_supervisor_comms.send.return_value = TICount(count=3)
+
+ result = get_ti_count_for_task("task_id", "dag_id", "run_id")
+
+ assert result == 3
+ mock_supervisor_comms.send.assert_called_once()
+ call_args = mock_supervisor_comms.send.call_args[0][0]
+ assert call_args.dag_id == "dag_id"
+ assert call_args.task_ids == ["task_id"]
+ assert call_args.run_ids == ["run_id"]
+
+
+class TestGetRelevantMapIndexes:
+ """Tests for get_relevant_map_indexes function."""
+
+ def test_returns_none_when_no_ti_count(self):
+ """Should return None when ti_count is 0 or None."""
+ with DAG("dag1"):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+
+ result = get_relevant_map_indexes(
+ task=op1,
+ run_id="run_id",
+ map_index=0,
+ ti_count=0,
+ relative=op2,
+ dag_id="dag1",
+ )
+ assert result is None
+
+ def test_returns_none_when_no_common_ancestor(self):
+ """Should return None when tasks have no common mapped ancestor."""
+ with DAG("dag1"):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+
+ result = get_relevant_map_indexes(
+ task=op1,
+ run_id="run_id",
+ map_index=0,
+ ti_count=3,
+ relative=op2,
+ dag_id="dag1",
+ )
+ assert result is None
+
+ def test_same_mapped_group_returns_single_index(self, mock_supervisor_comms):
+ """Tasks in same mapped group should get single index matching their map_index."""
+ with DAG("dag1"):
+ with TaskGroup("tg"):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+ op1 >> op2
+
+ # Mock iter_mapped_task_groups to simulate a mapped task group
+ mock_mapped_tg = MagicMock(spec=TaskGroup)
+ mock_mapped_tg.group_id = "tg"
+ op1.iter_mapped_task_groups = MagicMock(spec=TaskGroup, return_value=iter([mock_mapped_tg]))
+ op2.iter_mapped_task_groups = MagicMock(spec=TaskGroup, return_value=iter([mock_mapped_tg]))
+
+ # Mock: op2 has 3 TIs (mapped by 3)
+ mock_supervisor_comms.send.return_value = TICount(count=3)
+
+ # For map_index=1 with ti_count=3, should return 1 (same index)
+ result = get_relevant_map_indexes(
+ task=op2,
+ run_id="run_id",
+ map_index=1,
+ ti_count=3,
+ relative=op1,
+ dag_id="dag1",
+ )
+ assert result == 1
+
+ def test_unmapped_task_pulling_from_mapped_returns_none(self):
+ """Unmapped task pulling from mapped upstream should return None (pull
all)."""
+ with DAG("dag1"):
+ op1 = BaseOperator(task_id="op1")
+ op2 = BaseOperator(task_id="op2")
+ op1 >> op2
+
+ # op2 is not in a mapped group, so there's no common ancestor
+ result = get_relevant_map_indexes(
+ task=op2,
+ run_id="run_id",
+ map_index=0,
+ ti_count=1,
+ relative=op1,
+ dag_id="dag1",
+ )
+ assert result is None