This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f8abdcd7353 Implement slice on LazyXComSequence (#50117)
f8abdcd7353 is described below

commit f8abdcd735368f830ec037ad10e06675807f9485
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jun 2 16:13:01 2025 +0800

    Implement slice on LazyXComSequence (#50117)
---
 .../api_fastapi/execution_api/datamodels/xcom.py   |  14 ++-
 .../api_fastapi/execution_api/routes/xcoms.py      | 132 ++++++++++++++++++++-
 .../execution_api/versions/head/test_xcoms.py      |  78 ++++++++++--
 .../versions/v2025_04_28/test_xcom.py              | 107 +++++++++++++++++
 task-sdk/src/airflow/sdk/api/client.py             |  29 ++++-
 .../src/airflow/sdk/api/datamodels/_generated.py   |  26 +++-
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  36 +++++-
 .../airflow/sdk/execution_time/lazy_sequence.py    |  94 ++++++++-------
 .../src/airflow/sdk/execution_time/supervisor.py   |  14 ++-
 .../task_sdk/execution_time/test_lazy_sequence.py  |  33 +++++-
 .../task_sdk/execution_time/test_supervisor.py     |  25 +++-
 11 files changed, 516 insertions(+), 72 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
index ae7ddd26761..4df3e3f74f0 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import sys
 from typing import Any
 
-from pydantic import JsonValue
+from pydantic import JsonValue, RootModel
 
 from airflow.api_fastapi.core_api.base import BaseModel
 
@@ -36,3 +36,15 @@ class XComResponse(BaseModel):
     key: str
     value: JsonValue
     """The returned XCom value in a JSON-compatible format."""
+
+
+class XComSequenceIndexResponse(RootModel):
+    """XCom schema with minimal structure for index-based access."""
+
+    root: JsonValue
+
+
+class XComSequenceSliceResponse(RootModel):
+    """XCom schema with minimal structure for slice-based access."""
+
+    root: list[JsonValue]
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index d4c8c5160cd..a9ed4a5b48d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -27,7 +27,11 @@ from sqlalchemy import delete
 from sqlalchemy.sql.selectable import Select
 
 from airflow.api_fastapi.common.db.common import SessionDep
-from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
+from airflow.api_fastapi.execution_api.datamodels.xcom import (
+    XComResponse,
+    XComSequenceIndexResponse,
+    XComSequenceSliceResponse,
+)
 from airflow.api_fastapi.execution_api.deps import JWTBearerDep
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XComModel
@@ -184,6 +188,132 @@ def get_xcom(
     return XComResponse(key=key, value=result.value)
 
 
+@router.get(
+    "/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
+    description="Get a single XCom value from a mapped task by sequence index",
+)
+def get_mapped_xcom_by_index(
+    dag_id: str,
+    run_id: str,
+    task_id: str,
+    key: str,
+    offset: int,
+    session: SessionDep,
+) -> XComSequenceIndexResponse:
+    xcom_query = XComModel.get_many(
+        run_id=run_id,
+        key=key,
+        task_ids=task_id,
+        dag_ids=dag_id,
+        session=session,
+    )
+    xcom_query = xcom_query.order_by(None)  # drop any ORDER BY already on the query
+    if offset >= 0:  # count forward from the lowest map_index
+        xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset)
+    else:  # negative index: count back from the highest map_index (-1 -> OFFSET 0)
+        xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
+
+    if (result := xcom_query.limit(1).first()) is None:
+        message = (
+            f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+        )
+        raise HTTPException(
+            status_code=status.HTTP_404_NOT_FOUND,
+            detail={"reason": "not_found", "message": message},
+        )
+    return XComSequenceIndexResponse(result.value)
+
+
+class GetXComSliceFilterParams(BaseModel):
+    """Class to house slice params."""
+
+    start: int | None = None
+    stop: int | None = None
+    step: int | None = None
+
+
+@router.get(
+    "/{dag_id}/{run_id}/{task_id}/{key}/slice",
+    description="Get XCom values from a mapped task by sequence slice",
+)
+def get_mapped_xcom_by_slice(
+    dag_id: str,
+    run_id: str,
+    task_id: str,
+    key: str,
+    params: Annotated[GetXComSliceFilterParams, Query()],
+    session: SessionDep,
+) -> XComSequenceSliceResponse:
+    query = XComModel.get_many(
+        run_id=run_id,
+        key=key,
+        task_ids=task_id,
+        dag_ids=dag_id,
+        session=session,
+    )
+    query = query.order_by(None)  # drop any ORDER BY already on the query
+
+    step = params.step or 1
+
+    # We want to optimize negative slicing (e.g. seq[-10:]) by not doing an
+    # additional COUNT query if possible. This is possible unless both start and
+    # stop are explicitly given and have different signs.
+    if (start := params.start) is None:
+        if (stop := params.stop) is None:
+            if step >= 0:
+                query = query.order_by(XComModel.map_index.asc())
+            else:
+                query = query.order_by(XComModel.map_index.desc())
+                step = -step
+        elif stop >= 0:
+            query = query.order_by(XComModel.map_index.asc())
+            if step >= 0:
+                query = query.limit(stop)
+            else:
+                query = query.offset(stop + 1)
+        else:
+            query = query.order_by(XComModel.map_index.desc())
+            step = -step
+            if step > 0:
+                query = query.limit(-stop - 1)
+            else:
+                query = query.offset(-stop)
+    elif start >= 0:
+        query = query.order_by(XComModel.map_index.asc())
+        if (stop := params.stop) is None:
+            if step >= 0:
+                query = query.offset(start)
+            else:
+                query = query.limit(start + 1)
+        else:
+            if stop < 0:
+                stop += get_query_count(query, session=session)
+            if step >= 0:
+                query = query.slice(start, max(start, stop))  # no negative LIMIT
+            else:
+                query = query.slice(max(stop + 1, 0), max(stop, start) + 1)  # no negative OFFSET/LIMIT
+    else:
+        query = query.order_by(XComModel.map_index.desc())
+        step = -step
+        if (stop := params.stop) is None:
+            if step > 0:
+                query = query.offset(-start - 1)
+            else:
+                query = query.limit(-start)
+        else:
+            if stop >= 0:
+                stop -= get_query_count(query, session=session)
+            if step > 0:
+                query = query.slice(-1 - start, -1 - min(start, stop))  # no negative LIMIT
+            else:
+                query = query.slice(max(-stop, 0), max(-stop, -start))  # no negative OFFSET/LIMIT
+
+    values = [row.value for row in query.with_entities(XComModel.value)]
+    if step != 1:
+        values = values[::step]
+    return XComSequenceSliceResponse(values)
+
+
 if sys.version_info < (3, 12):
     # zmievsa/cadwyn#262
     # Setting this to "Any" doesn't have any impact on the API as it has to be 
parsed as valid JSON regardless
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
index 951fbf5cbae..1b10e81cd23 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
@@ -1,5 +1,4 @@
 # Licensed to the Apache Software Foundation (ASF) under one
-# Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
 # regarding copyright ownership.  The ASF licenses this file
@@ -20,6 +19,7 @@ from __future__ import annotations
 
 import contextlib
 import logging
+import urllib.parse
 
 import httpx
 import pytest
@@ -148,12 +148,12 @@ class TestXComsGetEndpoint:
                 },
                 id="-4",
             ),
-            pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
-            pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
-            pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
-            pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
-            pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
-            pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+            pytest.param(-3, 200, "f", id="-3"),
+            pytest.param(-2, 200, "o", id="-2"),
+            pytest.param(-1, 200, "b", id="-1"),
+            pytest.param(0, 200, "f", id="0"),
+            pytest.param(1, 200, "o", id="1"),
+            pytest.param(2, 200, "b", id="2"),
             pytest.param(
                 3,
                 404,
@@ -207,10 +207,72 @@ class TestXComsGetEndpoint:
             session.add(x)
         session.commit()
 
-        response = 
client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+        response = 
client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}")
         assert response.status_code == expected_status
         assert response.json() == expected_json
 
+    @pytest.mark.parametrize(
+        "key",
+        [
+            pytest.param(slice(None, None, None), id=":"),
+            pytest.param(slice(None, None, -2), id="::-2"),
+            pytest.param(slice(None, 2, None), id=":2"),
+            pytest.param(slice(None, 2, -1), id=":2:-1"),
+            pytest.param(slice(None, -2, None), id=":-2"),
+            pytest.param(slice(None, -2, -1), id=":-2:-1"),
+            pytest.param(slice(1, None, None), id="1:"),
+            pytest.param(slice(2, None, -1), id="2::-1"),
+            pytest.param(slice(1, 2, None), id="1:2"),
+            pytest.param(slice(2, 1, -1), id="2:1:-1"),
+            pytest.param(slice(1, -1, None), id="1:-1"),
+            pytest.param(slice(2, -2, -1), id="2:-2:-1"),
+            pytest.param(slice(-2, None, None), id="-2:"),
+            pytest.param(slice(-1, None, -1), id="-1::-1"),
+            pytest.param(slice(-2, -1, None), id="-2:-1"),
+            pytest.param(slice(-1, -3, -1), id="-1:-3:-1"),
+        ],
+    )
+    def test_xcom_get_with_slice(self, client, dag_maker, session, key):
+        xcom_values = ["f", None, "o", "b"]
+
+        class MyOperator(EmptyOperator):
+            def __init__(self, *, x, **kwargs):
+                super().__init__(**kwargs)
+                self.x = x
+
+        with dag_maker(dag_id="dag"):
+            MyOperator.partial(task_id="task").expand(x=xcom_values)
+        dag_run = dag_maker.create_dagrun(run_id="runid")
+        tis = {ti.map_index: ti for ti in dag_run.task_instances}
+
+        for map_index, db_value in enumerate(xcom_values):
+            if db_value is None:  # We don't put None to XCom.
+                continue
+            ti = tis[map_index]
+            x = XComModel(
+                key="xcom_1",
+                value=db_value,
+                dag_run_id=ti.dag_run.id,
+                run_id=ti.run_id,
+                task_id=ti.task_id,
+                dag_id=ti.dag_id,
+                map_index=map_index,
+            )
+            session.add(x)
+        session.commit()
+
+        qs = {}
+        if key.start is not None:
+            qs["start"] = key.start
+        if key.stop is not None:
+            qs["stop"] = key.stop
+        if key.step is not None:
+            qs["step"] = key.step
+
+        response = 
client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}")
+        assert response.status_code == 200
+        assert response.json() == ["f", "o", "b"][key]
+
 
 class TestXComsSetEndpoint:
     @pytest.mark.parametrize(
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py
new file mode 100644
index 00000000000..1de65d493d1
--- /dev/null
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.xcom import XComModel
+from airflow.providers.standard.operators.empty import EmptyOperator
+
+pytestmark = pytest.mark.db_test
+
+
+class TestXComsGetEndpoint:
+    @pytest.mark.parametrize(
+        "offset, expected_status, expected_json",
+        [
+            pytest.param(
+                -4,
+                404,
+                {
+                    "detail": {
+                        "reason": "not_found",
+                        "message": (
+                            "XCom with key='xcom_1' offset=-4 not found "
+                            "for task 'task' in DAG run 'runid' of 'dag'"
+                        ),
+                    },
+                },
+                id="-4",
+            ),
+            pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
+            pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
+            pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
+            pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
+            pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
+            pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+            pytest.param(
+                3,
+                404,
+                {
+                    "detail": {
+                        "reason": "not_found",
+                        "message": (
+                            "XCom with key='xcom_1' offset=3 not found "
+                            "for task 'task' in DAG run 'runid' of 'dag'"
+                        ),
+                    },
+                },
+                id="3",
+            ),
+        ],
+    )
+    def test_xcom_get_with_offset(
+        self,
+        client,
+        dag_maker,
+        session,
+        offset,
+        expected_status,
+        expected_json,
+    ):
+        xcom_values = ["f", None, "o", "b"]
+
+        class MyOperator(EmptyOperator):
+            def __init__(self, *, x, **kwargs):
+                super().__init__(**kwargs)
+                self.x = x
+
+        with dag_maker(dag_id="dag"):
+            MyOperator.partial(task_id="task").expand(x=xcom_values)
+
+        dag_run = dag_maker.create_dagrun(run_id="runid")
+        tis = {ti.map_index: ti for ti in dag_run.task_instances}
+        for map_index, db_value in enumerate(xcom_values):
+            if db_value is None:  # We don't put None to XCom.
+                continue
+            ti = tis[map_index]
+            x = XComModel(
+                key="xcom_1",
+                value=db_value,
+                dag_run_id=ti.dag_run.id,
+                run_id=ti.run_id,
+                task_id=ti.task_id,
+                dag_id=ti.dag_id,
+                map_index=map_index,
+            )
+            session.add(x)
+        session.commit()
+
+        response = 
client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+        assert response.status_code == expected_status
+        assert response.json() == expected_json
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 1fd548319e7..b9d0a4511ea 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -58,6 +58,8 @@ from airflow.sdk.api.datamodels._generated import (
     VariablePostBody,
     VariableResponse,
     XComResponse,
+    XComSequenceIndexResponse,
+    XComSequenceSliceResponse,
 )
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import (
@@ -442,10 +444,9 @@ class XComOperations:
         task_id: str,
         key: str,
         offset: int,
-    ) -> XComResponse | ErrorResponse:
-        params = {"offset": offset}
+    ) -> XComSequenceIndexResponse | ErrorResponse:
         try:
-            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", 
params=params)
+            resp = 
self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
         except ServerResponseError as e:
             if e.response.status_code == HTTPStatus.NOT_FOUND:
                 log.error(
@@ -469,7 +470,27 @@ class XComOperations:
                     },
                 )
             raise
-        return XComResponse.model_validate_json(resp.read())
+        return XComSequenceIndexResponse.model_validate_json(resp.read())
+
+    def get_sequence_slice(
+        self,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        key: str,
+        start: int | None,
+        stop: int | None,
+        step: int | None,
+    ) -> XComSequenceSliceResponse:
+        params = {}
+        if start is not None:
+            params["start"] = start
+        if stop is not None:
+            params["stop"] = stop
+        if step is not None:
+            params["step"] = step
+        resp = 
self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
+        return XComSequenceSliceResponse.model_validate_json(resp.read())
 
 
 class AssetOperations:
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 3efae80e5b6..f6b1c907ef5 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -25,7 +25,7 @@ from enum import Enum
 from typing import Annotated, Any, Final, Literal
 from uuid import UUID
 
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, 
RootModel
 
 API_VERSION: Final[str] = "2025-05-20"
 
@@ -356,6 +356,30 @@ class XComResponse(BaseModel):
     value: JsonValue
 
 
+class XComSequenceIndexResponse(RootModel[JsonValue]):
+    root: Annotated[
+        JsonValue,
+        Field(
+            description="XCom schema with minimal structure for index-based 
access.",
+            title="XComSequenceIndexResponse",
+        ),
+    ]
+
+
+class XComSequenceSliceResponse(RootModel[list[JsonValue]]):
+    """
+    XCom schema with minimal structure for slice-based access.
+    """
+
+    root: Annotated[
+        list[JsonValue],
+        Field(
+            description="XCom schema with minimal structure for slice-based 
access.",
+            title="XComSequenceSliceResponse",
+        ),
+    ]
+
+
 class TaskInstance(BaseModel):
     """
     Schema for TaskInstance model with minimal required fields needed for 
Runtime.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a25ba574582..ecc34852252 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -74,6 +74,8 @@ from airflow.sdk.api.datamodels._generated import (
     TriggerDAGRunPayload,
     VariableResponse,
     XComResponse,
+    XComSequenceIndexResponse,
+    XComSequenceSliceResponse,
 )
 from airflow.sdk.exceptions import ErrorType
 
@@ -227,6 +229,24 @@ class XComCountResponse(BaseModel):
     type: Literal["XComLengthResponse"] = "XComLengthResponse"
 
 
+class XComSequenceIndexResult(BaseModel):
+    root: JsonValue
+    type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"
+
+    @classmethod
+    def from_response(cls, response: XComSequenceIndexResponse) -> 
XComSequenceIndexResult:
+        return cls(root=response.root, type="XComSequenceIndexResult")
+
+
+class XComSequenceSliceResult(BaseModel):
+    root: list[JsonValue]
+    type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"
+
+    @classmethod
+    def from_response(cls, response: XComSequenceSliceResponse) -> 
XComSequenceSliceResult:
+        return cls(root=response.root, type="XComSequenceSliceResult")
+
+
 class ConnectionResult(ConnectionResponse):
     type: Literal["ConnectionResult"] = "ConnectionResult"
 
@@ -352,8 +372,10 @@ ToTask = Annotated[
         TICount,
         TaskStatesResult,
         VariableResult,
-        XComResult,
         XComCountResponse,
+        XComResult,
+        XComSequenceIndexResult,
+        XComSequenceSliceResult,
         OKResponse,
     ],
     Field(discriminator="type"),
@@ -451,6 +473,17 @@ class GetXComSequenceItem(BaseModel):
     type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
 
 
+class GetXComSequenceSlice(BaseModel):
+    key: str
+    dag_id: str
+    run_id: str
+    task_id: str
+    start: int | None
+    stop: int | None
+    step: int | None
+    type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"
+
+
 class SetXCom(BaseModel):
     key: str
     value: Annotated[
@@ -616,6 +649,7 @@ ToSupervisor = Annotated[
         GetXCom,
         GetXComCount,
         GetXComSequenceItem,
+        GetXComSequenceSlice,
         PutVariable,
         RescheduleTask,
         RetryTask,
diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py 
b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
index 0fbfcf39498..9cf9acfac81 100644
--- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
+++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+import collections
 import itertools
 from collections.abc import Iterator, Sequence
 from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
@@ -30,6 +31,11 @@ if TYPE_CHECKING:
 
 T = TypeVar("T")
 
+# This is used to wrap values from the API so the structure is compatible with
+# ``XCom.deserialize_value``. We don't want to wrap the API values in a nested
+# {"value": value} dict since it wastes bandwidth.
+_XComWrapper = collections.namedtuple("_XComWrapper", "value")
+
 log = structlog.get_logger(logger_name=__name__)
 
 
@@ -98,7 +104,7 @@ class LazyXComSequence(Sequence[T]):
             if isinstance(msg, ErrorResponse):
                 raise RuntimeError(msg)
             if not isinstance(msg, XComCountResponse):
-                raise TypeError(f"Got unexpected response to GetXComCount: 
{msg}")
+                raise TypeError(f"Got unexpected response to GetXComCount: 
{msg!r}")
             self._len = msg.len
         return self._len
 
@@ -109,41 +115,42 @@ class LazyXComSequence(Sequence[T]):
     def __getitem__(self, key: slice) -> Sequence[T]: ...
 
     def __getitem__(self, key: int | slice) -> T | Sequence[T]:
-        if not isinstance(key, (int, slice)):
-            raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
-
-        if isinstance(key, slice):
-            raise TypeError("slice is not implemented yet")
-        # TODO...
-        # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
-        # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
-        # start and stop have different signs (i.e. one is positive and another negative).
-        # start, stop, reverse = _coerce_slice(key)
-        # if start >= 0:
-        #     if stop is None:
-        #         stmt = self._select_asc.offset(start)
-        #     elif stop >= 0:
-        #         stmt = self._select_asc.slice(start, stop)
-        #     else:
-        #         stmt = self._select_asc.slice(start, len(self) + stop)
-        #     rows = [self._process_row(row) for row in self._session.execute(stmt)]
-        #     if reverse:
-        #         rows.reverse()
-        # else:
-        #     if stop is None:
-        #         stmt = self._select_desc.limit(-start)
-        #     elif stop < 0:
-        #         stmt = self._select_desc.slice(-stop, -start)
-        #     else:
-        #         stmt = self._select_desc.slice(len(self) - stop, -start)
-        #     rows = [self._process_row(row) for row in self._session.execute(stmt)]
-        #     if not reverse:
-        #         rows.reverse()
-        # return rows
-        from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
+        from airflow.sdk.execution_time.comms import (
+            ErrorResponse,
+            GetXComSequenceItem,
+            GetXComSequenceSlice,
+            XComSequenceIndexResult,
+            XComSequenceSliceResult,
+        )
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
         from airflow.sdk.execution_time.xcom import XCom
 
+        if isinstance(key, slice):
+            start, stop, step = _coerce_slice(key)
+            with SUPERVISOR_COMMS.lock:
+                source = (xcom_arg := self._xcom_arg).operator
+                SUPERVISOR_COMMS.send_request(
+                    log=log,
+                    msg=GetXComSequenceSlice(
+                        key=xcom_arg.key,
+                        dag_id=source.dag_id,
+                        task_id=source.task_id,
+                        run_id=self._ti.run_id,
+                        start=start,
+                        stop=stop,
+                        step=step,
+                    ),
+                )
+                msg = SUPERVISOR_COMMS.get_message()
+                if not isinstance(msg, XComSequenceSliceResult):
+                    raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}")
+            return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root]
+
+        if not isinstance(key, int):
+            if (index := getattr(key, "__index__", None)) is None:
+                raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}")
+            key = index()  # accept integer-like keys (e.g. numpy ints), as built-in sequences do
+
         with SUPERVISOR_COMMS.lock:
             source = (xcom_arg := self._xcom_arg).operator
             SUPERVISOR_COMMS.send_request(
@@ -157,13 +164,14 @@ class LazyXComSequence(Sequence[T]):
                 ),
             )
             msg = SUPERVISOR_COMMS.get_message()
-
-        if not isinstance(msg, XComResult):
+        if isinstance(msg, ErrorResponse):
             raise IndexError(key)
-        return XCom.deserialize_value(msg)
+        if not isinstance(msg, XComSequenceIndexResult):
+            raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}")
+        return XCom.deserialize_value(_XComWrapper(msg.root))
 
 
-def _coerce_index(value: Any) -> int | None:
+def _coerce_slice_index(value: Any) -> int | None:
     """
     Check slice attribute's type and convert it to int.
 
@@ -177,17 +185,13 @@ def _coerce_index(value: Any) -> int | None:
     raise TypeError("slice indices must be integers or None or have an 
__index__ method")
 
 
-def _coerce_slice(key: slice) -> tuple[int, int | None, bool]:
+def _coerce_slice(key: slice) -> tuple[int | None, int | None, int | None]:
     """
     Check slice content and convert it for SQL.
 
     See CPython documentation on this:
     https://docs.python.org/3/reference/datamodel.html#slice-objects
     """
-    if key.step is None or key.step == 1:
-        reverse = False
-    elif key.step == -1:
-        reverse = True
-    else:
-        raise ValueError("non-trivial slice step not supported")
-    return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse
+    if (step := _coerce_slice_index(key.step)) == 0:
+        raise ValueError("slice step cannot be zero")
+    return _coerce_slice_index(key.start), _coerce_slice_index(key.stop), step
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 1006b861378..65d05cc023d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -60,7 +60,7 @@ from airflow.sdk.api.datamodels._generated import (
     TaskInstanceState,
     TaskStatesResponse,
     VariableResponse,
-    XComResponse,
+    XComSequenceIndexResponse,
 )
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import (
@@ -87,6 +87,7 @@ from airflow.sdk.execution_time.comms import (
     GetXCom,
     GetXComCount,
     GetXComSequenceItem,
+    GetXComSequenceSlice,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
@@ -103,6 +104,8 @@ from airflow.sdk.execution_time.comms import (
     VariableResult,
     XComCountResponse,
     XComResult,
+    XComSequenceIndexResult,
+    XComSequenceSliceResult,
 )
 from airflow.sdk.execution_time.secrets_masker import mask_secret
 
@@ -1108,10 +1111,15 @@ class ActivitySubprocess(WatchedSubprocess):
             xcom = self.client.xcoms.get_sequence_item(
                 msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
             )
-            if isinstance(xcom, XComResponse):
-                resp = XComResult.from_xcom_response(xcom)
+            if isinstance(xcom, XComSequenceIndexResponse):
+                resp = XComSequenceIndexResult.from_response(xcom)
             else:
                 resp = xcom
+        elif isinstance(msg, GetXComSequenceSlice):
+            xcoms = self.client.xcoms.get_sequence_slice(
+                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, 
msg.stop, msg.step
+            )
+            resp = XComSequenceSliceResult.from_response(xcoms)
         elif isinstance(msg, DeferTask):
             self._terminal_state = TaskInstanceState.DEFERRED
             self._rendered_map_index = msg.rendered_map_index
diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py 
b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
index 2430f85f35e..e4943196a09 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
@@ -28,8 +28,10 @@ from airflow.sdk.execution_time.comms import (
     ErrorResponse,
     GetXComCount,
     GetXComSequenceItem,
+    GetXComSequenceSlice,
     XComCountResponse,
-    XComResult,
+    XComSequenceIndexResult,
+    XComSequenceSliceResult,
 )
 from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
 from airflow.sdk.execution_time.xcom import resolve_xcom_backend
@@ -75,7 +77,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence):
     it = iter(lazy_sequence)
 
     mock_supervisor_comms.get_message.side_effect = [
-        XComResult(key="return_value", value="f"),
+        XComSequenceIndexResult(root="f"),
         ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": 
"sorry!"}),
     ]
     assert list(it) == ["f"]
@@ -104,7 +106,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence):
 
 
 def test_getitem_index(mock_supervisor_comms, lazy_sequence):
-    mock_supervisor_comms.get_message.return_value = 
XComResult(key="return_value", value="f")
+    mock_supervisor_comms.get_message.return_value = 
XComSequenceIndexResult(root="f")
     assert lazy_sequence[4] == "f"
     assert mock_supervisor_comms.send_request.mock_calls == [
         call(
@@ -121,12 +123,12 @@ def test_getitem_index(mock_supervisor_comms, 
lazy_sequence):
 
 
 @conf_vars({("core", "xcom_backend"): 
"task_sdk.execution_time.test_lazy_sequence.CustomXCom"})
-def test_getitem_calls_correct_deserialise(mock_supervisor_comms, 
lazy_sequence):
-    mock_supervisor_comms.get_message.return_value = 
XComResult(key="return_value", value="some-value")
+def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, 
lazy_sequence):
+    mock_supervisor_comms.get_message.return_value = 
XComSequenceIndexResult(root="some-value")
 
     xcom = resolve_xcom_backend()
     assert xcom.__name__ == "CustomXCom"
-    airflow.sdk.execution_time.xcom.XCom = xcom
+    monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom)
 
     assert lazy_sequence[4] == "Made with CustomXCom: some-value"
     assert mock_supervisor_comms.send_request.mock_calls == [
@@ -163,3 +165,22 @@ def test_getitem_indexerror(mock_supervisor_comms, 
lazy_sequence):
             ),
         ),
     ]
+
+
+def test_getitem_slice(mock_supervisor_comms, lazy_sequence):
+    mock_supervisor_comms.get_message.return_value = 
XComSequenceSliceResult(root=[6, 4, 1])
+    assert lazy_sequence[:5] == [6, 4, 1]
+    assert mock_supervisor_comms.send_request.mock_calls == [
+        call(
+            log=ANY,
+            msg=GetXComSequenceSlice(
+                key="return_value",
+                dag_id="dag",
+                task_id="task",
+                run_id="run",
+                start=None,
+                stop=5,
+                step=None,
+            ),
+        ),
+    ]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 4696908bae7..86a2e747e0f 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -76,6 +76,7 @@ from airflow.sdk.execution_time.comms import (
     GetVariable,
     GetXCom,
     GetXComSequenceItem,
+    GetXComSequenceSlice,
     OKResponse,
     PrevSuccessfulDagRunResult,
     PutVariable,
@@ -91,6 +92,8 @@ from airflow.sdk.execution_time.comms import (
     TriggerDagRun,
     VariableResult,
     XComResult,
+    XComSequenceIndexResult,
+    XComSequenceSliceResult,
 )
 from airflow.sdk.execution_time.supervisor import (
     BUFFER_SIZE,
@@ -1618,11 +1621,11 @@ class TestHandleRequest:
                     task_id="test_task",
                     offset=0,
                 ),
-                
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+                b'{"root":"test_value","type":"XComSequenceIndexResult"}\n',
                 "xcoms.get_sequence_item",
                 ("test_dag", "test_run", "test_task", "test_key", 0),
                 {},
-                XComResult(key="test_key", value="test_value"),
+                XComSequenceIndexResult(root="test_value"),
                 None,
                 id="get_xcom_seq_item",
             ),
@@ -1642,6 +1645,24 @@ class TestHandleRequest:
                 None,
                 id="get_xcom_seq_item_not_found",
             ),
+            pytest.param(
+                GetXComSequenceSlice(
+                    key="test_key",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                ),
+                b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n',
+                "xcoms.get_sequence_slice",
+                ("test_dag", "test_run", "test_task", "test_key", None, None, 
None),
+                {},
+                XComSequenceSliceResult(root=["foo", "bar"]),
+                None,
+                id="get_xcom_seq_slice",
+            ),
         ],
     )
     def test_handle_requests(


Reply via email to