This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c441e4bf12f Reduce API server memory usage by eliminating `SerializedDAG` loads on task start (#60803)
c441e4bf12f is described below

commit c441e4bf12f78c8846330fd6a37d02641f64e163
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jan 20 13:24:11 2026 +0000

    Reduce API server memory usage by eliminating `SerializedDAG` loads on task start (#60803)
---
 airflow-core/newsfragments/60803.significant.rst   |   1 +
 .../execution_api/datamodels/taskinstance.py       |   2 -
 .../execution_api/routes/task_instances.py         |  58 +------
 .../api_fastapi/execution_api/versions/__init__.py |   7 +-
 .../execution_api/versions/v2026_03_31.py          |  21 ++-
 .../versions/head/test_task_instances.py           |  72 +-------
 .../versions/v2025_04_28/test_task_instances.py    |  67 +-------
 devel-common/src/tests_common/pytest_plugin.py     |   5 -
 .../src/airflow/sdk/api/datamodels/_generated.py   |   3 -
 .../sdk/definitions/_internal/expandinput.py       |   4 +-
 task-sdk/src/airflow/sdk/definitions/xcom_arg.py   |  23 ++-
 .../src/airflow/sdk/execution_time/task_mapping.py | 133 +++++++++++++++
 .../src/airflow/sdk/execution_time/task_runner.py  |  39 ++++-
 .../task_sdk/definitions/test_mappedoperator.py    | 147 +++++++++++-----
 .../task_sdk/execution_time/test_task_mapping.py   | 189 +++++++++++++++++++++
 15 files changed, 521 insertions(+), 250 deletions(-)

diff --git a/airflow-core/newsfragments/60803.significant.rst 
b/airflow-core/newsfragments/60803.significant.rst
new file mode 100644
index 00000000000..1054258ad38
--- /dev/null
+++ b/airflow-core/newsfragments/60803.significant.rst
@@ -0,0 +1 @@
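+Move the ``upstream_map_indexes`` computation from the API server to the Task SDK, reducing API server memory usage on task start by eliminating ``SerializedDAG`` loads. Clients using older Execution API versions now receive ``upstream_map_indexes`` as ``None`` in the task run context.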
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index e7ebee9ebe7..513a99f6dc9 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -324,8 +324,6 @@ class TIRunContext(BaseModel):
     connections: Annotated[list[ConnectionResponse], 
Field(default_factory=list)]
     """Connections that can be accessed by the task instance."""
 
-    upstream_map_indexes: dict[str, int | list[int] | None] | None = None
-
     next_method: str | None = None
     """Method to call. Set when task resumes from a trigger."""
     next_kwargs: dict[str, Any] | str | None = None
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a729a1ee83b..a73145b30ab 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -31,7 +31,7 @@ from cadwyn import VersionedAPIRouter
 from fastapi import Body, HTTPException, Query, status
 from pydantic import JsonValue
 from sqlalchemy import func, or_, tuple_, update
-from sqlalchemy.engine import CursorResult, Row
+from sqlalchemy.engine import CursorResult
 from sqlalchemy.exc import NoResultFound, SQLAlchemyError
 from sqlalchemy.orm import joinedload
 from sqlalchemy.sql import select
@@ -64,14 +64,11 @@ from airflow.exceptions import TaskNotFound
 from airflow.models.asset import AssetActive
 from airflow.models.dag import DagModel
 from airflow.models.dagrun import DagRun as DR
-from airflow.models.expandinput import NotFullyPopulated
 from airflow.models.taskinstance import TaskInstance as TI, 
_stop_remaining_tasks
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
 from airflow.serialization.definitions.assets import SerializedAsset, 
SerializedAssetUniqueKey
-from airflow.serialization.definitions.dag import SerializedDAG
-from airflow.task.trigger_rule import TriggerRule
 from airflow.utils.sqlalchemy import get_dialect_name
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
 
@@ -251,17 +248,6 @@ def ti_run(
             or 0
         )
 
-        if dag := dag_bag.get_dag_for_run(dag_run=dr, session=session):
-            upstream_map_indexes = dict(
-                _get_upstream_map_indexes(
-                    serialized_dag=dag,
-                    ti=ti,
-                    session=session,
-                )
-            )
-        else:
-            upstream_map_indexes = None
-
         context = TIRunContext(
             dag_run=dr,
             task_reschedule_count=task_reschedule_count,
@@ -271,7 +257,6 @@ def ti_run(
             connections=[],
             xcom_keys_to_clear=xcom_keys,
             should_retry=_is_eligible_to_retry(previous_state, ti.try_number, 
ti.max_tries),
-            upstream_map_indexes=upstream_map_indexes,
         )
 
         # Only set if they are non-null
@@ -287,47 +272,6 @@ def ti_run(
         )
 
 
-def _get_upstream_map_indexes(
-    *,
-    serialized_dag: SerializedDAG,
-    ti: TI | Row,
-    session: SessionDep,
-) -> Iterator[tuple[str, int | list[int] | None]]:
-    task = serialized_dag.get_task(ti.task_id)
-    for upstream_task in task.upstream_list:
-        map_indexes: int | list[int] | None
-        if (upstream_mapped_group := 
upstream_task.get_closest_mapped_task_group()) is None:
-            # regular tasks or non-mapped task groups
-            map_indexes = None
-        elif task.get_closest_mapped_task_group() is upstream_mapped_group:
-            # tasks in the same mapped task group hierarchy
-            map_indexes = ti.map_index
-        else:
-            # tasks not in the same mapped task group
-            # the upstream mapped task group should combine the return xcom as 
a list and return it
-            mapped_ti_count: int | None = None
-
-            try:
-                # First try: without resolving XCom
-                mapped_ti_count = 
upstream_mapped_group.get_parse_time_mapped_ti_count()
-            except NotFullyPopulated:
-                # Second try: resolve XCom for correct count
-                try:
-                    expand_input = upstream_mapped_group._expand_input
-                    mapped_ti_count = 
expand_input.get_total_map_length(ti.run_id, session=session)
-                except NotFullyPopulated:
-                    # For these trigger rules, unresolved map indexes are 
acceptable.
-                    # The success of the upstream task is not the main reason 
for triggering the current task.
-                    # Therefore, whether the upstream task is fully populated 
can be ignored.
-                    if task.trigger_rule != TriggerRule.ALL_SUCCESS:
-                        mapped_ti_count = None
-
-            # Compute map indexes if we have a valid count
-            map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is 
not None else None
-
-        yield upstream_task.task_id, map_indexes
-
-
 @ti_id_router.patch(
     "/{task_instance_id}/state",
     status_code=status.HTTP_204_NO_CONTENT,
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 36c1b31b959..30d4159f745 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -33,11 +33,14 @@ from airflow.api_fastapi.execution_api.versions.v2025_12_08 
import (
     AddDagRunDetailEndpoint,
     MovePreviousRunEndpoint,
 )
-from airflow.api_fastapi.execution_api.versions.v2026_03_31 import 
ModifyDeferredTaskKwargsToJsonValue
+from airflow.api_fastapi.execution_api.versions.v2026_03_31 import (
+    ModifyDeferredTaskKwargsToJsonValue,
+    RemoveUpstreamMapIndexesField,
+)
 
 bundle = VersionBundle(
     HeadVersion(),
-    Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue),
+    Version("2026-03-31", ModifyDeferredTaskKwargsToJsonValue, 
RemoveUpstreamMapIndexesField),
     Version("2025-12-08", MovePreviousRunEndpoint, AddDagRunDetailEndpoint),
     Version("2025-11-07", AddPartitionKeyField),
     Version("2025-11-05", AddTriggeringUserNameField),
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
index 48630e0b50c..72e193426da 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_03_31.py
@@ -19,9 +19,9 @@ from __future__ import annotations
 
 from typing import Any
 
-from cadwyn import VersionChange, schema
+from cadwyn import ResponseInfo, VersionChange, 
convert_response_to_previous_version_for, schema
 
-from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIDeferredStatePayload
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import 
TIDeferredStatePayload, TIRunContext
 
 
 class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
@@ -33,3 +33,20 @@ class ModifyDeferredTaskKwargsToJsonValue(VersionChange):
         
schema(TIDeferredStatePayload).field("trigger_kwargs").had(type=dict[str, Any] 
| str),
         schema(TIDeferredStatePayload).field("next_kwargs").had(type=dict[str, 
Any]),
     )
+
+
+class RemoveUpstreamMapIndexesField(VersionChange):
+    """Remove upstream_map_indexes field from TIRunContext - now computed by 
Task SDK."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        schema(TIRunContext)
+        .field("upstream_map_indexes")
+        .existed_as(type=dict[str, int | list[int] | None] | None),
+    )
+
+    @convert_response_to_previous_version_for(TIRunContext)  # type: 
ignore[arg-type]
+    def add_upstream_map_indexes_field(response: ResponseInfo) -> None:  # 
type: ignore[misc]
+        """Add upstream_map_indexes field with None for older API versions."""
+        response.body["upstream_map_indexes"] = None
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index e99b7d9380f..cfee6c9d46c 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -202,7 +202,6 @@ class TestTIRunState:
                 "partition_key": None,
             },
             "task_reschedule_count": 0,
-            "upstream_map_indexes": {},
             "max_tries": max_tries,
             "should_retry": should_retry,
             "variables": [],
@@ -249,10 +248,7 @@ class TestTIRunState:
         assert response.status_code == 409
 
     def test_dynamic_task_mapping_with_parse_time_value(self, client, 
dag_maker):
-        """
-        Test that the Task Instance upstream_map_indexes is correctly fetched 
when to running the Task Instances
-        """
-
+        """Test that dynamic task mapping works correctly with parse-time 
values."""
         with dag_maker("test_dynamic_task_mapping_with_parse_time_value", 
serialized=True):
 
             @task_group
@@ -278,23 +274,6 @@ class TestTIRunState:
             ti.set_state(State.QUEUED)
         dag_maker.session.flush()
 
-        # key: (task_id, map_index)
-        # value: result upstream_map_indexes ({task_id: map_indexes})
-        expected_upstream_map_indexes = {
-            # no upstream task for task_group_1.group_task_1
-            ("task_group_1.group1_task_1", 0): {},
-            ("task_group_1.group1_task_1", 1): {},
-            # the upstream task for task_group_1.group_task_2 is 
task_group_1.group_task_2
-            # since they are in the same task group, the upstream map index 
should be the same as the task
-            ("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 
0},
-            ("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 
1},
-            # the upstream task for task2 is the last tasks of task_group_1, 
which is
-            # task_group_1.group_task_2
-            # since they are not in the same task group, the upstream map 
index should include all the
-            # expanded tasks
-            ("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
-        }
-
         for ti in dr.get_task_instances():
             response = client.patch(
                 f"/execution/task-instances/{ti.id}/run",
@@ -308,13 +287,9 @@ class TestTIRunState:
             )
 
             assert response.status_code == 200
-            upstream_map_indexes = response.json()["upstream_map_indexes"]
-            assert upstream_map_indexes == 
expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
 
-    def test_nested_mapped_task_group_upstream_indexes(self, client, 
dag_maker):
-        """
-        Test that upstream_map_indexes are correctly computed for tasks in 
nested mapped task groups.
-        """
+    def test_nested_mapped_task_group(self, client, dag_maker):
+        """Test that nested mapped task groups work correctly."""
         with dag_maker("test_nested_mapped_tg", serialized=True):
 
             @task
@@ -346,25 +321,11 @@ class TestTIRunState:
                 ti.set_state(State.QUEUED)
         dag_maker.session.flush()
 
-        # Expected upstream_map_indexes for each print_task instance
-        expected_upstream_map_indexes = {
-            ("expandable_task_group.inner_task_group.print_task", 0): {
-                "expandable_task_group.inner_task_group.alter_input": 0
-            },
-            ("expandable_task_group.inner_task_group.print_task", 1): {
-                "expandable_task_group.inner_task_group.alter_input": 1
-            },
-            ("expandable_task_group.inner_task_group.print_task", 2): {
-                "expandable_task_group.inner_task_group.alter_input": 2
-            },
-        }
-
         # Get only the expanded print_task instances (not the template)
         print_task_tis = [
             ti for ti in dr.get_task_instances() if "print_task" in ti.task_id 
and ti.map_index >= 0
         ]
 
-        # Test each print_task instance
         for ti in print_task_tis:
             response = client.patch(
                 f"/execution/task-instances/{ti.id}/run",
@@ -378,18 +339,9 @@ class TestTIRunState:
             )
 
             assert response.status_code == 200
-            upstream_map_indexes = response.json()["upstream_map_indexes"]
-            expected = expected_upstream_map_indexes[(ti.task_id, 
ti.map_index)]
-
-            assert upstream_map_indexes == expected, (
-                f"Task {ti.task_id}[{ti.map_index}] should have 
upstream_map_indexes {expected}, "
-                f"but got {upstream_map_indexes}"
-            )
 
     def test_dynamic_task_mapping_with_xcom(self, client: Client, dag_maker: 
DagMaker, session: Session):
-        """
-        Test that the Task Instance upstream_map_indexes is correctly fetched 
when to running the Task Instances with xcom
-        """
+        """Test that dynamic task mapping works correctly with XCom values."""
         from airflow.models.taskmap import TaskMap
 
         with dag_maker(session=session, serialized=True):
@@ -442,13 +394,10 @@ class TestTIRunState:
                 "start_date": "2024-09-30T12:00:00Z",
             },
         )
-        assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 
2, 3, 4, 5]}
+        assert response.status_code == 200
 
     def test_dynamic_task_mapping_with_all_success_trigger_rule(self, 
dag_maker: DagMaker, session: Session):
-        """
-        Test that the Task Instance upstream_map_indexes is not populuated but
-        the downstream task should not be run.
-        """
+        """Test that with ALL_SUCCESS trigger rule and skipped upstream, 
downstream should not run."""
 
         with dag_maker(session=session, serialized=True):
 
@@ -504,10 +453,7 @@ class TestTIRunState:
     def test_dynamic_task_mapping_with_non_all_success_trigger_rule(
         self, client: Client, dag_maker: DagMaker, session: Session, 
trigger_rule: TriggerRule
     ):
-        """
-        Test that the Task Instance upstream_map_indexes is not populuated but
-        the downstream task should still be run due to trigger rule.
-        """
+        """Test that with non-ALL_SUCCESS trigger rule, downstream task should 
still run."""
 
         with dag_maker(session=session, serialized=True):
 
@@ -564,7 +510,7 @@ class TestTIRunState:
                 "start_date": "2024-09-30T12:00:00Z",
             },
         )
-        assert response.json()["upstream_map_indexes"] == {"tg.task_2": None}
+        assert response.status_code == 200
 
     def test_next_kwargs_still_encoded(self, client, session, 
create_task_instance, time_machine):
         instant_str = "2024-09-30T12:00:00Z"
@@ -615,7 +561,6 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -687,7 +632,6 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
-            "upstream_map_indexes": {},
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
index 16371b67d04..efa29338d92 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_task_instances.py
@@ -17,8 +17,6 @@
 
 from __future__ import annotations
 
-from unittest.mock import patch
-
 import pytest
 
 from airflow._shared.timezones import timezone
@@ -50,52 +48,19 @@ class TestTIUpdateState:
         clear_db_assets()
         clear_db_runs()
 
-    @pytest.mark.parametrize(
-        ("mock_indexes", "expected_response_indexes"),
-        [
-            pytest.param(
-                [("task_a", 5), ("task_b", 10)],
-                {"task_a": 5, "task_b": 10},
-                id="plain ints",
-            ),
-            pytest.param(
-                [("task_a", [3, 4]), ("task_b", [9])],
-                {"task_a": 3, "task_b": 9},
-                id="list of ints",
-            ),
-            pytest.param(
-                [
-                    ("task_a", None),
-                ],
-                {"task_a": None},
-                id="task has no upstreams",
-            ),
-            pytest.param(
-                [("task_a", None), ("task_b", [6, 7]), ("task_c", 2)],
-                {"task_a": None, "task_b": 6, "task_c": 2},
-                id="mixed types",
-            ),
-        ],
-    )
-    
@patch("airflow.api_fastapi.execution_api.routes.task_instances._get_upstream_map_indexes")
     def test_ti_run(
         self,
-        mock_get_upstream_map_indexes,
         ver_client,
         session,
         create_task_instance,
         time_machine,
-        mock_indexes,
-        expected_response_indexes,
         get_execution_app,
     ):
         """
         Test that this version of the endpoint works.
 
-        Later versions modified the type of upstream_map_indexes.
+        upstream_map_indexes is now always None as it's computed by the Task 
SDK.
         """
-        mock_get_upstream_map_indexes.return_value = mock_indexes
-
         instant_str = "2024-09-30T12:00:00Z"
         instant = timezone.parse(instant_str)
         time_machine.move_to(instant, tick=False)
@@ -124,26 +89,10 @@ class TestTIUpdateState:
         )
 
         assert response.status_code == 200
-        assert response.json() == {
-            "dag_run": {
-                "dag_id": "dag",
-                "run_id": "test",
-                "clear_number": 0,
-                "logical_date": instant_str,
-                "data_interval_start": 
instant.subtract(days=1).to_iso8601_string(),
-                "data_interval_end": instant_str,
-                "run_after": instant_str,
-                "start_date": instant_str,
-                "end_date": None,
-                "run_type": "manual",
-                "conf": {},
-                "consumed_asset_events": [],
-            },
-            "task_reschedule_count": 0,
-            "upstream_map_indexes": expected_response_indexes,
-            "max_tries": 0,
-            "should_retry": False,
-            "variables": [],
-            "connections": [],
-            "xcom_keys_to_clear": [],
-        }
+        result = response.json()
+        # upstream_map_indexes is now computed by SDK, server returns None
+        assert result["upstream_map_indexes"] is None
+        assert result["dag_run"]["dag_id"] == "dag"
+        assert result["task_reschedule_count"] == 0
+        assert result["max_tries"] == 0
+        assert result["should_retry"] is False
diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index bc18c94f17b..258d3345706 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2457,7 +2457,6 @@ def create_runtime_ti(mocked_parse):
         run_type: str = "manual",
         try_number: int = 1,
         map_index: int | None = -1,
-        upstream_map_indexes: dict[str, int | list[int] | None] | None = None,
         task_reschedule_count: int = 0,
         ti_id: UUID | None = None,
         conf: dict[str, Any] | None = None,
@@ -2532,12 +2531,8 @@ def create_runtime_ti(mocked_parse):
             task_reschedule_count=task_reschedule_count,
             max_tries=task_retries if max_tries is None else max_tries,
             should_retry=should_retry if should_retry is not None else 
try_number <= task_retries,
-            upstream_map_indexes=upstream_map_indexes,
         )
 
-        if upstream_map_indexes is not None:
-            ti_context.upstream_map_indexes = upstream_map_indexes
-
         compat_fields = {
             "requests_fd": 0,
             "sentry_integration": "",
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 6a3a07f5e8a..149aa5d7dcb 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -631,9 +631,6 @@ class TIRunContext(BaseModel):
     max_tries: Annotated[int, Field(title="Max Tries")]
     variables: Annotated[list[VariableResponse] | None, 
Field(title="Variables")] = None
     connections: Annotated[list[ConnectionResponse] | None, 
Field(title="Connections")] = None
-    upstream_map_indexes: Annotated[
-        dict[str, int | list[int] | None] | None, Field(title="Upstream Map 
Indexes")
-    ] = None
     next_method: Annotated[str | None, Field(title="Next Method")] = None
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next 
Kwargs")] = None
     xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To 
Clear")] = None
diff --git a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py 
b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
index fa56ae84650..b6ffbd22142 100644
--- a/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
+++ b/task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py
@@ -189,7 +189,9 @@ class DictOfListsExpandInput(ResolveMixin):
         if map_index is None or map_index < 0:
             raise RuntimeError("can't resolve task-mapping argument without 
expanding")
 
-        upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", 
{})
+        # Get pre-computed upstream_map_indexes if available, otherwise 
default to empty dict.
+        # When empty, individual XComArgs will compute their map_indexes 
lazily in xcom_arg.py.
+        upstream_map_indexes = getattr(context["ti"], "_upstream_map_indexes", 
None) or {}
 
         # TODO: This initiates one API call for each XComArg. Would it be
         # more efficient to do one single call and unpack the value here?
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py 
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index ef316574df5..88db9009242 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -30,7 +30,7 @@ from airflow.sdk import TriggerRule
 from airflow.sdk.definitions._internal.abstractoperator import AbstractOperator
 from airflow.sdk.definitions._internal.mixins import DependencyMixin, 
ResolveMixin
 from airflow.sdk.definitions._internal.setup_teardown import 
SetupTeardownContext
-from airflow.sdk.definitions._internal.types import NOTSET, is_arg_set
+from airflow.sdk.definitions._internal.types import NOTSET, ArgNotSet, 
is_arg_set
 from airflow.sdk.exceptions import AirflowException, XComNotFound
 from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
 from airflow.sdk.execution_time.xcom import BaseXCom
@@ -337,10 +337,25 @@ class PlainXComArg(XComArg):
             return LazyXComSequence(xcom_arg=self, ti=ti)
         tg = self.operator.get_closest_mapped_task_group()
         if tg is None:
-            map_indexes = None
+            # No mapped task group - pull from unmapped instance
+            map_indexes: int | range | None | ArgNotSet = None
         else:
-            upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})
-            map_indexes = upstream_map_indexes.get(task_id, None)
+            # Check for pre-computed value from server (backward compatibility)
+            upstream_map_indexes = getattr(ti, "_upstream_map_indexes", None)
+            if upstream_map_indexes is not None:
+                # Use None as default to match original behavior (filter for 
unmapped XCom)
+                map_indexes = upstream_map_indexes.get(task_id, None)
+            else:
+                # Compute lazily - ti_count will be queried if needed
+                cached_context = getattr(ti, "_cached_template_context", None)
+                ti_count = cached_context.get("expanded_ti_count") if 
cached_context else None
+                computed = ti.get_relevant_upstream_map_indexes(
+                    upstream=self.operator,
+                    ti_count=ti_count,
+                    session=None,  # Not used in SDK implementation
+                )
+                # None means "no filtering needed" -> use NOTSET to pull all 
values
+                map_indexes = NOTSET if computed is None else computed
         result = ti.xcom_pull(
             task_ids=task_id,
             key=self.key,
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_mapping.py 
b/task-sdk/src/airflow/sdk/execution_time/task_mapping.py
new file mode 100644
index 00000000000..fc54bb794d9
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/execution_time/task_mapping.py
@@ -0,0 +1,133 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utility functions for computing upstream map indexes in the Task SDK."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.sdk.execution_time.comms import GetTICount, TICount
+
+if TYPE_CHECKING:
+    from airflow.sdk import BaseOperator
+    from airflow.sdk.definitions.taskgroup import MappedTaskGroup, TaskGroup
+
+
+def _find_common_ancestor_mapped_group(node1: BaseOperator, node2: 
BaseOperator) -> MappedTaskGroup | None:
+    """
+    Given two operators, find their innermost common mapped task group.
+
+    :param node1: First operator
+    :param node2: Second operator
+    :return: The common mapped task group, or None if they don't share one
+    """
+    try:
+        dag1 = node1.dag
+        dag2 = node2.dag
+    except RuntimeError:
+        # Operator not assigned to a DAG
+        return None
+
+    if dag1 is None or dag2 is None or node1.dag_id != node2.dag_id:
+        return None
+    parent_group_ids = {g.group_id for g in node1.iter_mapped_task_groups()}
+    common_groups = (g for g in node2.iter_mapped_task_groups() if g.group_id 
in parent_group_ids)
+    return next(common_groups, None)
+
+
+def _is_further_mapped_inside(operator: BaseOperator, container: TaskGroup) -> 
bool:
+    """
+    Whether given operator is *further* mapped inside a task group.
+
+    :param operator: The operator to check
+    :param container: The container task group
+    :return: True if the operator is further mapped inside the container
+    """
+    # Use getattr for compatibility with both SDK and serialized operators
+    if getattr(operator, "is_mapped", False):
+        return True
+    task_group = operator.task_group
+    while task_group is not None and task_group.group_id != container.group_id:
+        if getattr(task_group, "is_mapped", False):
+            return True
+        task_group = task_group.parent_group
+    return False
+
+
+def get_ti_count_for_task(task_id: str, dag_id: str, run_id: str) -> int:
+    """
+    Query TI count for a specific task.
+
+    :param task_id: The task ID
+    :param dag_id: The DAG ID
+    :param run_id: The run ID
+    :return: The count of task instances for the task
+    """
+    # Import here because SUPERVISOR_COMMS is set at runtime, not import time
+    from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+    response = SUPERVISOR_COMMS.send(GetTICount(dag_id=dag_id, 
task_ids=[task_id], run_ids=[run_id]))
+    if not isinstance(response, TICount):
+        raise RuntimeError(f"Unexpected response type: {type(response)}")
+    return response.count
+
+
+def get_relevant_map_indexes(
+    task: BaseOperator,
+    run_id: str,
+    map_index: int,
+    ti_count: int,
+    relative: BaseOperator,
+    dag_id: str,
+) -> int | range | None:
+    """
+    Determine map indexes for XCom aggregation.
+
+    This is used to figure out which specific map indexes of an upstream task
+    are relevant when resolving XCom values for a task in a mapped task group.
+
+    :param task: The current task
+    :param run_id: The current run ID
+    :param map_index: The map index of the current task instance
+    :param ti_count: The total count of task instances for the current task
+    :param relative: The upstream/downstream task to find relevant map indexes 
for
+    :param dag_id: The DAG ID
+    :return: None (use entire value), int (single index), or range (subset of 
indexes)
+    """
+    if not ti_count:
+        return None
+
+    common_ancestor = _find_common_ancestor_mapped_group(task, relative)
+    if common_ancestor is None or common_ancestor.group_id is None:
+        return None  # Different mapping contexts → use whole value
+
+    # Query TI count using the current task, which is in the mapped task group.
+    # This gives us the number of expansion iterations, not total TIs in the 
group.
+    ancestor_ti_count = get_ti_count_for_task(task.task_id, dag_id, run_id)
+    if not ancestor_ti_count:
+        return None
+
+    ancestor_map_index = map_index * ancestor_ti_count // ti_count
+
+    if not _is_further_mapped_inside(relative, common_ancestor):
+        return ancestor_map_index  # Single index
+
+    # Partial aggregation for selected TIs
+    further_count = ti_count // ancestor_ti_count
+    map_index_start = ancestor_map_index * further_count
+    return range(map_index_start, map_index_start + further_count)
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 60577066998..1cfc2c82d45 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -268,10 +268,10 @@ class RuntimeTaskInstance(TaskInstance):
                     }
                 )
 
-            if from_server.upstream_map_indexes is not None:
-                # We stash this in here for later use, but we purposefully 
don't want to document it's
-                # existence. Should this be a private attribute on RuntimeTI 
instead perhaps?
-                setattr(self, "_upstream_map_indexes", 
from_server.upstream_map_indexes)
+            # Backward compatibility: old servers may still send 
upstream_map_indexes
+            upstream_map_indexes = getattr(from_server, 
"upstream_map_indexes", None)
+            if upstream_map_indexes is not None:
+                setattr(self, "_upstream_map_indexes", upstream_map_indexes)
 
         return self._cached_template_context
 
@@ -436,8 +436,35 @@ class RuntimeTaskInstance(TaskInstance):
     def get_relevant_upstream_map_indexes(
         self, upstream: BaseOperator, ti_count: int | None, session: Any
     ) -> int | range | None:
-        # TODO: Implement this method
-        return None
+        """
+        Compute the relevant upstream map indexes for XCom resolution.
+
+        This is computed in the Task SDK rather than received from the API server
+        as ``upstream_map_indexes``, so the server does not need to load the serialized DAG.
+
+        :param upstream: The upstream operator whose XCom value is being resolved
+        :param ti_count: The total count of task instances for this task's expansion;
+            when None it is requested from the supervisor
+        :param session: Not used (kept for API compatibility)
+        :return: None (use entire value), int (single index), or range (subset of indexes)
+        """
+        from airflow.sdk.execution_time.task_mapping import 
get_relevant_map_indexes, get_ti_count_for_task
+
+        # A TI without a non-negative map_index has no index to select on, so the
+        # caller falls back to using the upstream value as a whole (None).
+        map_index = self.map_index
+        if map_index is None or map_index < 0:
+            return None
+
+        # If ti_count not provided, we need to query it from the supervisor
+        if ti_count is None:
+            ti_count = get_ti_count_for_task(self.task_id, self.dag_id, 
self.run_id)
+
+        # A zero count leaves nothing to narrow down; use the whole value.
+        if not ti_count:
+            return None
+
+        # All mapping-context resolution (common mapped task group, further
+        # mapping of the upstream) happens locally in the task process.
+        return get_relevant_map_indexes(
+            task=self.task,
+            run_id=self.run_id,
+            map_index=map_index,
+            ti_count=ti_count,
+            relative=upstream,
+            dag_id=self.dag_id,
+        )
 
     def get_first_reschedule_date(self, context: Context) -> AwareDatetime | 
None:
         """Get the first reschedule date for the task instance if found, none 
otherwise."""
diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py 
b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
index 4d0e52e8f2f..85b6e9c0b27 100644
--- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
+++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
@@ -31,7 +31,15 @@ from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions.dag import DAG
 from airflow.sdk.definitions.mappedoperator import MappedOperator
 from airflow.sdk.definitions.xcom_arg import XComArg
-from airflow.sdk.execution_time.comms import GetXCom, SetXCom, XComResult
+from airflow.sdk.execution_time.comms import (
+    GetTICount,
+    GetXCom,
+    GetXComSequenceSlice,
+    SetXCom,
+    TICount,
+    XComResult,
+    XComSequenceSliceResult,
+)
 
 from tests_common.test_utils.mapping import expand_mapped_task  # noqa: F401
 from tests_common.test_utils.mock_operators import (
@@ -252,9 +260,18 @@ def test_mapped_render_template_fields_validating_operator(
         )
         mapped = callable(mapped, task1.output)
 
-    mock_supervisor_comms.send.return_value = 
XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=["{{ ds }}"])
+    def mock_comms(msg):
+        if isinstance(msg, GetXCom):
+            return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=["{{ ds }}"])
+        if isinstance(msg, GetXComSequenceSlice):
+            return XComSequenceSliceResult(root=["{{ ds }}"])
+        if isinstance(msg, GetTICount):
+            return TICount(count=1)
+        return mock.DEFAULT
+
+    mock_supervisor_comms.send.side_effect = mock_comms
 
-    mapped_ti = create_runtime_ti(task=mapped, map_index=0, 
upstream_map_indexes={task1.task_id: 1})
+    mapped_ti = create_runtime_ti(task=mapped, map_index=0)
 
     assert isinstance(mapped_ti.task, MappedOperator)
     
mapped_ti.task.render_template_fields(context=mapped_ti.get_template_context())
@@ -273,13 +290,13 @@ def 
test_mapped_render_nested_template_fields(create_runtime_ti, mock_supervisor
             task_id="t", arg2=NestedFields(field_1="{{ ti.task_id }}", 
field_2="value_2")
         ).expand(arg1=["{{ ti.task_id }}", ["s", "{{ ti.task_id }}"]])
 
-    ti = create_runtime_ti(task=mapped, map_index=0, upstream_map_indexes={})
+    ti = create_runtime_ti(task=mapped, map_index=0)
     ti.task.render_template_fields(context=ti.get_template_context())
     assert ti.task.arg1 == "t"
     assert ti.task.arg2.field_1 == "t"
     assert ti.task.arg2.field_2 == "value_2"
 
-    ti = create_runtime_ti(task=mapped, map_index=1, upstream_map_indexes={})
+    ti = create_runtime_ti(task=mapped, map_index=1)
     ti.task.render_template_fields(context=ti.get_template_context())
     assert ti.task.arg1 == ["s", "t"]
     assert ti.task.arg2.field_1 == "t"
@@ -300,11 +317,20 @@ def 
test_expand_kwargs_render_template_fields_validating_operator(
         task1 = BaseOperator(task_id="op1")
         mapped = MockOperator.partial(task_id="a", arg2="{{ ti.task_id 
}}").expand_kwargs(task1.output)
 
-    mock_supervisor_comms.send.return_value = XComResult(
-        key=BaseXCom.XCOM_RETURN_KEY, value=[{"arg1": "{{ ds }}"}, {"arg1": 2}]
-    )
+    xcom_values = [{"arg1": "{{ ds }}"}, {"arg1": 2}]
+
+    def mock_comms(msg):
+        if isinstance(msg, GetXCom):
+            return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=xcom_values)
+        if isinstance(msg, GetXComSequenceSlice):
+            return XComSequenceSliceResult(root=xcom_values)
+        if isinstance(msg, GetTICount):
+            return TICount(count=2)
+        return mock.DEFAULT
+
+    mock_supervisor_comms.send.side_effect = mock_comms
 
-    ti = create_runtime_ti(task=mapped, map_index=map_index, 
upstream_map_indexes={})
+    ti = create_runtime_ti(task=mapped, map_index=map_index)
     assert isinstance(ti.task, MappedOperator)
     ti.task.render_template_fields(context=ti.get_template_context())
     assert isinstance(ti.task, MockOperator)
@@ -428,14 +454,29 @@ def test_map_cross_product(run_ti: RunTI, 
mock_supervisor_comms):
 
         show.expand(number=emit_numbers(), letter=emit_letters())
 
-    def xcom_get(msg):
-        if not isinstance(msg, GetXCom):
-            return mock.DEFAULT
-        task = dag.get_task(msg.task_id)
-        value = task.python_callable()
-        return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+    numbers = [1, 2]
+    letters = {"a": "x", "b": "y", "c": "z"}
+
+    def mock_comms(msg):
+        if isinstance(msg, GetXCom):
+            if msg.task_id == "emit_numbers":
+                return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=numbers)
+            if msg.task_id == "emit_letters":
+                return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=letters)
+        elif isinstance(msg, GetXComSequenceSlice):
+            if msg.task_id == "emit_numbers":
+                return XComSequenceSliceResult(root=numbers)
+            if msg.task_id == "emit_letters":
+                # Convert dict items to list for XComSequenceSliceResult
+                return XComSequenceSliceResult(root=list(letters.items()))
+        elif isinstance(msg, GetTICount):
+            # show is mapped by 6 (2 numbers * 3 letters)
+            if msg.task_ids and msg.task_ids[0] == "show":
+                return TICount(count=6)
+            return TICount(count=1)
+        return mock.DEFAULT
 
-    mock_supervisor_comms.send.side_effect = xcom_get
+    mock_supervisor_comms.send.side_effect = mock_comms
 
     states = [run_ti(dag, "show", map_index) for map_index in range(6)]
     assert states == [TaskInstanceState.SUCCESS] * 6
@@ -466,14 +507,23 @@ def test_map_product_same(run_ti: RunTI, 
mock_supervisor_comms):
         emit_task = emit_numbers()
         show.expand(a=emit_task, b=emit_task)
 
-    def xcom_get(msg):
-        if not isinstance(msg, GetXCom):
-            return mock.DEFAULT
-        task = dag.get_task(msg.task_id)
-        value = task.python_callable()
-        return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+    numbers = [1, 2]
+
+    def mock_comms(msg):
+        if isinstance(msg, GetXCom):
+            if msg.task_id == "emit_numbers":
+                return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=numbers)
+        elif isinstance(msg, GetXComSequenceSlice):
+            if msg.task_id == "emit_numbers":
+                return XComSequenceSliceResult(root=numbers)
+        elif isinstance(msg, GetTICount):
+            # show is mapped by 4 (2 * 2 cross product)
+            if msg.task_ids and msg.task_ids[0] == "show":
+                return TICount(count=4)
+            return TICount(count=1)
+        return mock.DEFAULT
 
-    mock_supervisor_comms.send.side_effect = xcom_get
+    mock_supervisor_comms.send.side_effect = mock_comms
 
     states = [run_ti(dag, "show", map_index) for map_index in range(4)]
     assert states == [TaskInstanceState.SUCCESS] * 4
@@ -591,20 +641,37 @@ def 
test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super
         # Aggregates results from task group.
         t.override(task_id="t3")(tg1)
 
-    def xcom_get(msg):
-        if not isinstance(msg, GetXCom):
-            return mock.DEFAULT
-        key = (msg.task_id, msg.map_index)
-        if key in expected_values:
-            value = expected_values[key]
-            return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
-        if msg.map_index is None:
-            # Get all mapped XComValues for this ti
-            value = [v for k, v in expected_values.items() if k[0] == 
msg.task_id]
-            return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+    # Map task group IDs to their expansion counts
+    task_group_expansion = {"tg": 3}
+
+    def mock_comms_response(msg):
+        if isinstance(msg, GetXCom):
+            key = (msg.task_id, msg.map_index)
+            if key in expected_values:
+                value = expected_values[key]
+                return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+            if msg.map_index is None:
+                # Get all mapped XComValues for this ti
+                value = [v for k, v in expected_values.items() if k[0] == 
msg.task_id]
+                return XComResult(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+        elif isinstance(msg, GetXComSequenceSlice):
+            # Handle sequence slicing for pulling all XCom values from mapped 
tasks
+            task_id = msg.task_id
+            values = [v for k, v in expected_values.items() if k[0] == task_id 
and k[1] is not None]
+            return XComSequenceSliceResult(root=values)
+        elif isinstance(msg, GetTICount):
+            # Handle TI count queries for upstream_map_indexes computation
+            if msg.task_ids:
+                task_id = msg.task_ids[0]
+                if task_id in expansion_per_task_id:
+                    return 
TICount(count=len(list(expansion_per_task_id[task_id])))
+                return TICount(count=1)
+            if msg.task_group_id:
+                return 
TICount(count=task_group_expansion.get(msg.task_group_id, 0))
+            return TICount(count=0)
         return mock.DEFAULT
 
-    mock_supervisor_comms.send.side_effect = xcom_get
+    mock_supervisor_comms.send.side_effect = mock_comms_response
 
     expected_values = {
         ("tg.t1", 0): ["a", "b"],
@@ -622,21 +689,11 @@ def 
test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super
         "tg.t2": range(3),
         "t3": [None],
     }
-    upstream_map_indexes_per_task_id = {
-        ("tg.t1", 0): {},
-        ("tg.t1", 1): {},
-        ("tg.t1", 2): {},
-        ("tg.t2", 0): {"tg.t1": 0},
-        ("tg.t2", 1): {"tg.t1": 1},
-        ("tg.t2", 2): {"tg.t1": 2},
-        ("t3", None): {"tg.t2": [0, 1, 2]},
-    }
     for task in dag.tasks:
         for map_index in expansion_per_task_id[task.task_id]:
             mapped_ti = create_runtime_ti(
                 task=task.prepare_for_execution(),
                 map_index=map_index,
-                
upstream_map_indexes=upstream_map_indexes_per_task_id[(task.task_id, 
map_index)],
             )
             context = mapped_ti.get_template_context()
             mapped_ti.task.render_template_fields(context)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py
new file mode 100644
index 00000000000..5f6bf044785
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_mapping.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+from airflow.sdk import DAG, BaseOperator
+from airflow.sdk.definitions.taskgroup import TaskGroup
+from airflow.sdk.execution_time.comms import GetTICount, TICount
+from airflow.sdk.execution_time.task_mapping import (
+    _find_common_ancestor_mapped_group,
+    _is_further_mapped_inside,
+    get_relevant_map_indexes,
+    get_ti_count_for_task,
+)
+
+
+class TestFindCommonAncestorMappedGroup:
+    """Cases in which _find_common_ancestor_mapped_group finds no shared mapped group."""
+
+    def test_no_common_group_different_dags(self):
+        """Operators that belong to two different DAGs yield None."""
+        with DAG("dag1"):
+            first = BaseOperator(task_id="op1")
+        with DAG("dag2"):
+            second = BaseOperator(task_id="op2")
+
+        assert _find_common_ancestor_mapped_group(first, second) is None
+
+    def test_no_common_group_no_mapped_groups(self):
+        """Operators that sit outside any mapped task group yield None."""
+        with DAG("dag1"):
+            first, second = BaseOperator(task_id="op1"), BaseOperator(task_id="op2")
+
+        assert _find_common_ancestor_mapped_group(first, second) is None
+
+    def test_no_dag_returns_none(self):
+        """Operators never attached to a DAG are handled gracefully and yield None."""
+        first, second = BaseOperator(task_id="op1"), BaseOperator(task_id="op2")
+
+        assert _find_common_ancestor_mapped_group(first, second) is None
+
+class TestIsFurtherMappedInside:
+    """Exercise _is_further_mapped_inside for mapped and plain operators."""
+
+    def test_mapped_operator_returns_true(self):
+        """An operator flagged as mapped counts as further mapped inside its group."""
+        with DAG("dag1"), TaskGroup("tg") as group:
+            task = BaseOperator(task_id="op")
+
+        # Flip the private flag so the operator looks like a mapped one.
+        task._is_mapped = True
+
+        assert _is_further_mapped_inside(task, group) is True
+
+    def test_non_mapped_operator_returns_false(self):
+        """A plain operator with no mapped parent groups is not further mapped."""
+        with DAG("dag1"), TaskGroup("tg") as group:
+            task = BaseOperator(task_id="op")
+
+        assert _is_further_mapped_inside(task, group) is False
+
+
+class TestGetTiCountForTask:
+    """Tests for get_ti_count_for_task function."""
+
+    def test_queries_supervisor(self, mock_supervisor_comms):
+        """Should send GetTICount message to supervisor with task_ids."""
+        mock_supervisor_comms.send.return_value = TICount(count=3)
+
+        result = get_ti_count_for_task("task_id", "dag_id", "run_id")
+
+        assert result == 3
+        mock_supervisor_comms.send.assert_called_once()
+        call_args = mock_supervisor_comms.send.call_args[0][0]
+        # Pin the message type the docstring promises, not just its fields.
+        assert isinstance(call_args, GetTICount)
+        assert call_args.dag_id == "dag_id"
+        assert call_args.task_ids == ["task_id"]
+        assert call_args.run_ids == ["run_id"]
+
+
+class TestGetRelevantMapIndexes:
+    """Tests for get_relevant_map_indexes function."""
+
+    def test_returns_none_when_no_ti_count(self):
+        """Should return None when ti_count is 0 or None."""
+        with DAG("dag1"):
+            op1 = BaseOperator(task_id="op1")
+            op2 = BaseOperator(task_id="op2")
+
+        # Both falsy counts must short-circuit to "use the whole value".
+        for ti_count in (0, None):
+            result = get_relevant_map_indexes(
+                task=op1,
+                run_id="run_id",
+                map_index=0,
+                ti_count=ti_count,
+                relative=op2,
+                dag_id="dag1",
+            )
+            assert result is None, f"expected None for ti_count={ti_count!r}"
+
+    def test_returns_none_when_no_common_ancestor(self):
+        """Should return None when tasks have no common mapped ancestor."""
+        with DAG("dag1"):
+            op1 = BaseOperator(task_id="op1")
+            op2 = BaseOperator(task_id="op2")
+
+        result = get_relevant_map_indexes(
+            task=op1,
+            run_id="run_id",
+            map_index=0,
+            ti_count=3,
+            relative=op2,
+            dag_id="dag1",
+        )
+        assert result is None
+
+    def test_same_mapped_group_returns_single_index(self, mock_supervisor_comms):
+        """Tasks in same mapped group should get single index matching their map_index."""
+        with DAG("dag1"):
+            with TaskGroup("tg"):
+                op1 = BaseOperator(task_id="op1")
+                op2 = BaseOperator(task_id="op2")
+                op1 >> op2
+
+        # Simulate "tg" being a mapped task group shared by both operators. A side_effect
+        # (rather than one shared return_value iterator) hands out a fresh iterator on every
+        # call, so the mock keeps working if the group lookup is performed more than once.
+        mock_mapped_tg = MagicMock(spec=TaskGroup)
+        mock_mapped_tg.group_id = "tg"
+
+        def _mapped_groups(*_args, **_kwargs):
+            return iter([mock_mapped_tg])
+
+        op1.iter_mapped_task_groups = MagicMock(side_effect=_mapped_groups)
+        op2.iter_mapped_task_groups = MagicMock(side_effect=_mapped_groups)
+
+        # Mock: op2 has 3 TIs (mapped by 3)
+        mock_supervisor_comms.send.return_value = TICount(count=3)
+
+        # For map_index=1 with ti_count=3, should return 1 (same index)
+        result = get_relevant_map_indexes(
+            task=op2,
+            run_id="run_id",
+            map_index=1,
+            ti_count=3,
+            relative=op1,
+            dag_id="dag1",
+        )
+        assert result == 1
+
+        # The expansion count is requested for the current task (op2) in this run.
+        mock_supervisor_comms.send.assert_called_once()
+        sent = mock_supervisor_comms.send.call_args[0][0]
+        assert isinstance(sent, GetTICount)
+        assert sent.dag_id == "dag1"
+        assert sent.task_ids == [op2.task_id]
+        assert sent.run_ids == ["run_id"]
+
+    def test_unmapped_task_pulling_from_mapped_returns_none(self):
+        """Unmapped task pulling from mapped upstream should return None (pull all)."""
+        with DAG("dag1"):
+            op1 = BaseOperator(task_id="op1")
+            op2 = BaseOperator(task_id="op2")
+            op1 >> op2
+
+        # op2 is not in a mapped group, so there's no common ancestor
+        result = get_relevant_map_indexes(
+            task=op2,
+            run_id="run_id",
+            map_index=0,
+            ti_count=1,
+            relative=op1,
+            dag_id="dag1",
+        )
+        assert result is None

Reply via email to