This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new 9a18b9f05c1 [v3-0-test] Implement offset to get the xcom for a given task by offset. (#50011) (#50048)
9a18b9f05c1 is described below

commit 9a18b9f05c10546f1400aef708c144e97bd0b629
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Thu May 1 01:49:25 2025 +0530

    [v3-0-test] Implement offset to get the xcom for a given task by offset. (#50011) (#50048)
    
    (cherry picked from commit 58c736a197b15949210faea705132f74a65b8eec)
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
    Co-authored-by: Karthikeyan Singaravelan <[email protected]>
---
 .../api_fastapi/execution_api/routes/xcoms.py      |  30 +++--
 .../execution_api/versions/head/test_xcoms.py      |  81 +++++++++++++
 .../tests/unit/models/test_mappedoperator.py       |  12 +-
 task-sdk/src/airflow/sdk/api/client.py             |  36 ++++++
 task-sdk/src/airflow/sdk/definitions/xcom_arg.py   |   2 +-
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  10 ++
 .../airflow/sdk/execution_time/lazy_sequence.py    | 103 ++++++++--------
 .../src/airflow/sdk/execution_time/supervisor.py   |  10 ++
 .../task_sdk/execution_time/test_lazy_sequence.py  | 131 +++++++++++++++++++++
 .../task_sdk/execution_time/test_supervisor.py     |  31 +++++
 10 files changed, 386 insertions(+), 60 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index 53a0582015a..9f5e7d686a3 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel):
 
     map_index: int = -1
     include_prior_dates: bool = False
+    offset: int | None = None
 
 
 @router.get(
@@ -141,18 +142,23 @@ def get_xcom(
     params: Annotated[GetXcomFilterParams, Query()],
 ) -> XComResponse:
     """Get an Airflow XCom from database - not other XCom Backends."""
-    # The xcom_query allows no map_index to be passed. This endpoint should always return just a single item,
-    # so we override that query value
     xcom_query = XComModel.get_many(
         run_id=run_id,
         key=key,
         task_ids=task_id,
         dag_ids=dag_id,
-        map_indexes=params.map_index,
         include_prior_dates=params.include_prior_dates,
         session=session,
     )
-    xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
+    if params.offset is not None:
+        xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
+        if params.offset >= 0:
+            xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
+        else:
+            xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
+    else:
+        xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
+
     # We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
     # This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
     # retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
@@ -160,13 +166,19 @@ def get_xcom(
     # performance hits from retrieving large data files into the API server.
     result = xcom_query.limit(1).first()
     if result is None:
-        map_index = params.map_index
+        if params.offset is None:
+            message = (
+                f"XCom with {key=} map_index={params.map_index} not found for "
+                f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+            )
+        else:
+            message = (
+                f"XCom with {key=} offset={params.offset} not found for "
+                f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+            )
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
-            detail={
-                "reason": "not_found",
-                "message": f"XCom with {key=} {map_index=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}",
-            },
+            detail={"reason": "not_found", "message": message},
         )
 
     return XComResponse(key=key, value=result.value)
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
index c2b49841b3a..951fbf5cbae 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
@@ -29,6 +29,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
 from airflow.models.dagrun import DagRun
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XComModel
+from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.serialization.serde import deserialize, serialize
 from airflow.utils.session import create_session
 
@@ -130,6 +131,86 @@ class TestXComsGetEndpoint:
         }
         assert any(msg.startswith("Checking read XCom access") for msg in caplog.messages)
 
+    @pytest.mark.parametrize(
+        "offset, expected_status, expected_json",
+        [
+            pytest.param(
+                -4,
+                404,
+                {
+                    "detail": {
+                        "reason": "not_found",
+                        "message": (
+                            "XCom with key='xcom_1' offset=-4 not found "
+                            "for task 'task' in DAG run 'runid' of 'dag'"
+                        ),
+                    },
+                },
+                id="-4",
+            ),
+            pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
+            pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
+            pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
+            pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
+            pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
+            pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+            pytest.param(
+                3,
+                404,
+                {
+                    "detail": {
+                        "reason": "not_found",
+                        "message": (
+                            "XCom with key='xcom_1' offset=3 not found "
+                            "for task 'task' in DAG run 'runid' of 'dag'"
+                        ),
+                    },
+                },
+                id="3",
+            ),
+        ],
+    )
+    def test_xcom_get_with_offset(
+        self,
+        client,
+        dag_maker,
+        session,
+        offset,
+        expected_status,
+        expected_json,
+    ):
+        xcom_values = ["f", None, "o", "b"]
+
+        class MyOperator(EmptyOperator):
+            def __init__(self, *, x, **kwargs):
+                super().__init__(**kwargs)
+                self.x = x
+
+        with dag_maker(dag_id="dag"):
+            MyOperator.partial(task_id="task").expand(x=xcom_values)
+        dag_run = dag_maker.create_dagrun(run_id="runid")
+        tis = {ti.map_index: ti for ti in dag_run.task_instances}
+
+        for map_index, db_value in enumerate(xcom_values):
+            if db_value is None:  # We don't put None to XCom.
+                continue
+            ti = tis[map_index]
+            x = XComModel(
+                key="xcom_1",
+                value=db_value,
+                dag_run_id=ti.dag_run.id,
+                run_id=ti.run_id,
+                task_id=ti.task_id,
+                dag_id=ti.dag_id,
+                map_index=map_index,
+            )
+            session.add(x)
+        session.commit()
+
+        response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+        assert response.status_code == expected_status
+        assert response.json() == expected_json
+
 
 class TestXComsSetEndpoint:
     @pytest.mark.parametrize(
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py b/airflow-core/tests/unit/models/test_mappedoperator.py
index 7f2f3770bb4..dd35f9461dc 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -32,7 +32,7 @@ from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskmap import TaskMap
 from airflow.providers.standard.operators.python import PythonOperator
 from airflow.sdk import setup, task, task_group, teardown
-from airflow.sdk.execution_time.comms import XComCountResponse
+from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
 from airflow.utils.state import TaskInstanceState
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.trigger_rule import TriggerRule
@@ -1270,8 +1270,16 @@ class TestMappedSetupTeardown:
         ) as supervisor_comms:
             # TODO: TaskSDK: this is a bit of a hack that we need to stub this at all. `dag.test()` should
             # really work without this!
-            supervisor_comms.get_message.return_value = XComCountResponse(len=3)
+            supervisor_comms.get_message.side_effect = [
+                XComCountResponse(len=3),
+                XComResult(key="return_value", value=1),
+                XComCountResponse(len=3),
+                XComResult(key="return_value", value=2),
+                XComCountResponse(len=3),
+                XComResult(key="return_value", value=3),
+            ]
             dr = dag.test()
+            assert supervisor_comms.get_message.call_count == 6
         states = self.get_states(dr)
         expected = {
             "tg_1.my_pre_setup": "success",
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index c64f721b9ae..399954ddfcb 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -429,6 +429,42 @@ class XComOperations:
         # decouple from the server response string
         return OKResponse(ok=True)
 
+    def get_sequence_item(
+        self,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        key: str,
+        offset: int,
+    ) -> XComResponse | ErrorResponse:
+        params = {"offset": offset}
+        try:
+            resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
+        except ServerResponseError as e:
+            if e.response.status_code == HTTPStatus.NOT_FOUND:
+                log.error(
+                    "XCom not found",
+                    dag_id=dag_id,
+                    run_id=run_id,
+                    task_id=task_id,
+                    key=key,
+                    offset=offset,
+                    detail=e.detail,
+                    status_code=e.response.status_code,
+                )
+                return ErrorResponse(
+                    error=ErrorType.XCOM_NOT_FOUND,
+                    detail={
+                        "dag_id": dag_id,
+                        "run_id": run_id,
+                        "task_id": task_id,
+                        "key": key,
+                        "offset": offset,
+                    },
+                )
+            raise
+        return XComResponse.model_validate_json(resp.read())
+
 
 class AssetOperations:
     __slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index ab3a5a82b22..1adcb7efaa7 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -337,7 +337,7 @@ class PlainXComArg(XComArg):
         task_id = self.operator.task_id
 
         if self.operator.is_mapped:
-            return LazyXComSequence[Any](xcom_arg=self, ti=ti)
+            return LazyXComSequence(xcom_arg=self, ti=ti)
         tg = self.operator.get_closest_mapped_task_group()
         result = None
         if tg is None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index b4d68086b0c..039e2a5409c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -441,6 +441,15 @@ class GetXComCount(BaseModel):
     type: Literal["GetNumberXComs"] = "GetNumberXComs"
 
 
+class GetXComSequenceItem(BaseModel):
+    key: str
+    dag_id: str
+    run_id: str
+    task_id: str
+    offset: int
+    type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
+
+
 class SetXCom(BaseModel):
     key: str
     value: Annotated[
@@ -605,6 +614,7 @@ ToSupervisor = Annotated[
         GetVariable,
         GetXCom,
         GetXComCount,
+        GetXComSequenceItem,
         PutVariable,
         RescheduleTask,
         RetryTask,
diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
index 095ab051fb1..79822787f38 100644
--- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
+++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
@@ -43,10 +43,10 @@ class LazyXComIterator(Iterator[T]):
         if self.index < 0:
             # When iterating backwards, avoid extra HTTP request
             raise StopIteration()
-        val = self.seq._get_item(self.index)
-        if val is None:
-            # None isn't the best signal (it's bad in fact) but it's the best we can do until https://github.com/apache/airflow/issues/46426
-            raise StopIteration()
+        try:
+            val = self.seq[self.index]
+        except IndexError:
+            raise StopIteration from None
         self.index += self.dir
         return val
 
@@ -109,52 +109,59 @@ class LazyXComSequence(Sequence[T]):
     def __getitem__(self, key: slice) -> Sequence[T]: ...
 
     def __getitem__(self, key: int | slice) -> T | Sequence[T]:
-        if isinstance(key, int):
-            if key >= 0:
-                return self._get_item(key)
-            # val[-1] etc.
-            return self._get_item(len(self) + key)
+        if not isinstance(key, (int, slice)):
+            raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
 
         if isinstance(key, slice):
-            # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
-            # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
-            # start and stop have different signs (i.e. one is positive and another negative).
-            ...
-        """
-        Todo?
-        elif isinstance(key, slice):
-            start, stop, reverse = _coerce_slice(key)
-            if start >= 0:
-                if stop is None:
-                    stmt = self._select_asc.offset(start)
-                elif stop >= 0:
-                    stmt = self._select_asc.slice(start, stop)
-                else:
-                    stmt = self._select_asc.slice(start, len(self) + stop)
-                rows = [self._process_row(row) for row in self._session.execute(stmt)]
-                if reverse:
-                    rows.reverse()
-            else:
-                if stop is None:
-                    stmt = self._select_desc.limit(-start)
-                elif stop < 0:
-                    stmt = self._select_desc.slice(-stop, -start)
-                else:
-                    stmt = self._select_desc.slice(len(self) - stop, -start)
-                rows = [self._process_row(row) for row in self._session.execute(stmt)]
-                if not reverse:
-                    rows.reverse()
-            return rows
-        """
-        raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
-
-    def _get_item(self, index: int) -> T:
-        # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can handle not found here?
-        return self._ti.xcom_pull(
-            task_ids=self._xcom_arg.operator.task_id,
-            key=self._xcom_arg.key,
-            map_indexes=index,
-        )
+            raise TypeError("slice is not implemented yet")
+        # TODO...
+        # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
+        # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
+        # start and stop have different signs (i.e. one is positive and another negative).
+        # start, stop, reverse = _coerce_slice(key)
+        # if start >= 0:
+        #     if stop is None:
+        #         stmt = self._select_asc.offset(start)
+        #     elif stop >= 0:
+        #         stmt = self._select_asc.slice(start, stop)
+        #     else:
+        #         stmt = self._select_asc.slice(start, len(self) + stop)
+        #     rows = [self._process_row(row) for row in self._session.execute(stmt)]
+        #     if reverse:
+        #         rows.reverse()
+        # else:
+        #     if stop is None:
+        #         stmt = self._select_desc.limit(-start)
+        #     elif stop < 0:
+        #         stmt = self._select_desc.slice(-stop, -start)
+        #     else:
+        #         stmt = self._select_desc.slice(len(self) - stop, -start)
+        #     rows = [self._process_row(row) for row in self._session.execute(stmt)]
+        #     if not reverse:
+        #         rows.reverse()
+        # return rows
+
+        from airflow.sdk.bases.xcom import BaseXCom
+        from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        with SUPERVISOR_COMMS.lock:
+            source = (xcom_arg := self._xcom_arg).operator
+            SUPERVISOR_COMMS.send_request(
+                log=log,
+                msg=GetXComSequenceItem(
+                    key=xcom_arg.key,
+                    dag_id=source.dag_id,
+                    task_id=source.task_id,
+                    run_id=self._ti.run_id,
+                    offset=key,
+                ),
+            )
+            msg = SUPERVISOR_COMMS.get_message()
+
+        if not isinstance(msg, XComResult):
+            raise IndexError(key)
+        return BaseXCom.deserialize_value(msg)
 
 
 def _coerce_index(value: Any) -> int | None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index c90d6ea5a02..ff804f3f769 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -59,6 +59,7 @@ from airflow.sdk.api.datamodels._generated import (
     TaskStatesResponse,
     TerminalTIState,
     VariableResponse,
+    XComResponse,
 )
 from airflow.sdk.exceptions import ErrorType
 from airflow.sdk.execution_time.comms import (
@@ -84,6 +85,7 @@ from airflow.sdk.execution_time.comms import (
     GetVariable,
     GetXCom,
     GetXComCount,
+    GetXComSequenceItem,
     PrevSuccessfulDagRunResult,
     PutVariable,
     RescheduleTask,
@@ -1034,6 +1036,14 @@ class ActivitySubprocess(WatchedSubprocess):
         elif isinstance(msg, GetXComCount):
             len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key)
             resp = XComCountResponse(len=len)
+        elif isinstance(msg, GetXComSequenceItem):
+            xcom = self.client.xcoms.get_sequence_item(
+                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
+            )
+            if isinstance(xcom, XComResponse):
+                resp = XComResult.from_xcom_response(xcom)
+            else:
+                resp = xcom
         elif isinstance(msg, DeferTask):
             self._terminal_state = IntermediateTIState.DEFERRED
             self.client.task_instances.defer(self.id, msg)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
new file mode 100644
index 00000000000..a42572e5df1
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import ANY, Mock, call
+
+import pytest
+
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import (
+    ErrorResponse,
+    GetXComCount,
+    GetXComSequenceItem,
+    XComCountResponse,
+    XComResult,
+)
+from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
+
+
+@pytest.fixture
+def mock_operator():
+    return Mock(spec=["dag_id", "task_id"], dag_id="dag", task_id="task")
+
+
+@pytest.fixture
+def mock_xcom_arg(mock_operator):
+    return Mock(spec=["operator", "key"], operator=mock_operator, key="return_value")
+
+
+@pytest.fixture
+def mock_ti():
+    return Mock(spec=["run_id"], run_id="run")
+
+
+@pytest.fixture
+def lazy_sequence(mock_xcom_arg, mock_ti):
+    return LazyXComSequence(mock_xcom_arg, mock_ti)
+
+
+def test_len(mock_supervisor_comms, lazy_sequence):
+    mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3)
+    assert len(lazy_sequence) == 3
+    assert mock_supervisor_comms.send_request.mock_calls == [
+        call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag", task_id="task", run_id="run")),
+    ]
+
+
+def test_iter(mock_supervisor_comms, lazy_sequence):
+    it = iter(lazy_sequence)
+
+    mock_supervisor_comms.get_message.side_effect = [
+        XComResult(key="return_value", value="f"),
+        ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}),
+    ]
+    assert list(it) == ["f"]
+    assert mock_supervisor_comms.send_request.mock_calls == [
+        call(
+            log=ANY,
+            msg=GetXComSequenceItem(
+                key="return_value",
+                dag_id="dag",
+                task_id="task",
+                run_id="run",
+                offset=0,
+            ),
+        ),
+        call(
+            log=ANY,
+            msg=GetXComSequenceItem(
+                key="return_value",
+                dag_id="dag",
+                task_id="task",
+                run_id="run",
+                offset=1,
+            ),
+        ),
+    ]
+
+
+def test_getitem_index(mock_supervisor_comms, lazy_sequence):
+    mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f")
+    assert lazy_sequence[4] == "f"
+    assert mock_supervisor_comms.send_request.mock_calls == [
+        call(
+            log=ANY,
+            msg=GetXComSequenceItem(
+                key="return_value",
+                dag_id="dag",
+                task_id="task",
+                run_id="run",
+                offset=4,
+            ),
+        ),
+    ]
+
+
+def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
+    mock_supervisor_comms.get_message.return_value = ErrorResponse(
+        error=ErrorType.XCOM_NOT_FOUND,
+        detail={"oops": "sorry!"},
+    )
+    with pytest.raises(IndexError) as ctx:
+        lazy_sequence[4]
+    assert ctx.value.args == (4,)
+    assert mock_supervisor_comms.send_request.mock_calls == [
+        call(
+            log=ANY,
+            msg=GetXComSequenceItem(
+                key="return_value",
+                dag_id="dag",
+                task_id="task",
+                run_id="run",
+                offset=4,
+            ),
+        ),
+    ]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 4aff50a5fd4..a3eef617a4e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -74,6 +74,7 @@ from airflow.sdk.execution_time.comms import (
     GetTICount,
     GetVariable,
     GetXCom,
+    GetXComSequenceItem,
     OKResponse,
     PrevSuccessfulDagRunResult,
     PutVariable,
@@ -1436,6 +1437,36 @@ class TestHandleRequest:
                 TaskStatesResult(task_states={"run_id": {"task1": "success", "task2": "failed"}}),
                 id="get_task_states",
             ),
+            pytest.param(
+                GetXComSequenceItem(
+                    key="test_key",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    offset=0,
+                ),
+                b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+                "xcoms.get_sequence_item",
+                ("test_dag", "test_run", "test_task", "test_key", 0),
+                {},
+                XComResult(key="test_key", value="test_value"),
+                id="get_xcom_seq_item",
+            ),
+            pytest.param(
+                GetXComSequenceItem(
+                    key="test_key",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    offset=2,
+                ),
+                b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n',
+                "xcoms.get_sequence_item",
+                ("test_dag", "test_run", "test_task", "test_key", 2),
+                {},
+                ErrorResponse(error=ErrorType.XCOM_NOT_FOUND),
+                id="get_xcom_seq_item_not_found",
+            ),
         ],
     )
     def test_handle_requests(

Reply via email to