This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new 9a18b9f05c1 [v3-0-test] Implement offset to get the xcom for a given
task by offset. (#50011) (#50048)
9a18b9f05c1 is described below
commit 9a18b9f05c10546f1400aef708c144e97bd0b629
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Thu May 1 01:49:25 2025 +0530
[v3-0-test] Implement offset to get the xcom for a given task by offset.
(#50011) (#50048)
(cherry picked from commit 58c736a197b15949210faea705132f74a65b8eec)
Co-authored-by: Tzu-ping Chung <[email protected]>
Co-authored-by: Karthikeyan Singaravelan <[email protected]>
---
.../api_fastapi/execution_api/routes/xcoms.py | 30 +++--
.../execution_api/versions/head/test_xcoms.py | 81 +++++++++++++
.../tests/unit/models/test_mappedoperator.py | 12 +-
task-sdk/src/airflow/sdk/api/client.py | 36 ++++++
task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 2 +-
task-sdk/src/airflow/sdk/execution_time/comms.py | 10 ++
.../airflow/sdk/execution_time/lazy_sequence.py | 103 ++++++++--------
.../src/airflow/sdk/execution_time/supervisor.py | 10 ++
.../task_sdk/execution_time/test_lazy_sequence.py | 131 +++++++++++++++++++++
.../task_sdk/execution_time/test_supervisor.py | 31 +++++
10 files changed, 386 insertions(+), 60 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index 53a0582015a..9f5e7d686a3 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -126,6 +126,7 @@ class GetXcomFilterParams(BaseModel):
map_index: int = -1
include_prior_dates: bool = False
+ offset: int | None = None
@router.get(
@@ -141,18 +142,23 @@ def get_xcom(
params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
- # The xcom_query allows no map_index to be passed. This endpoint should
always return just a single item,
- # so we override that query value
xcom_query = XComModel.get_many(
run_id=run_id,
key=key,
task_ids=task_id,
dag_ids=dag_id,
- map_indexes=params.map_index,
include_prior_dates=params.include_prior_dates,
session=session,
)
- xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
+ if params.offset is not None:
+ xcom_query = xcom_query.filter(XComModel.value.is_not(None)).order_by(None)
+ if params.offset >= 0:
+ xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(params.offset)
+ else:
+ xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - params.offset)
+ else:
+ xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
+
# We use `BaseXCom.get_many` to fetch XComs directly from the database,
bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage
like S3) and instead
# retrieves the raw serialized value from the database. By not relying on
`XCom.get_many` or `XCom.get_one`
@@ -160,13 +166,19 @@ def get_xcom(
# performance hits from retrieving large data files into the API server.
result = xcom_query.limit(1).first()
if result is None:
- map_index = params.map_index
+ if params.offset is None:
+ message = (
+ f"XCom with {key=} map_index={params.map_index} not found for "
+ f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+ )
+ else:
+ message = (
+ f"XCom with {key=} offset={params.offset} not found for "
+ f"task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+ )
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
- detail={
- "reason": "not_found",
- "message": f"XCom with {key=} {map_index=} not found for task
{task_id!r} in DAG run {run_id!r} of {dag_id!r}",
- },
+ detail={"reason": "not_found", "message": message},
)
return XComResponse(key=key, value=result.value)
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
index c2b49841b3a..951fbf5cbae 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
@@ -29,6 +29,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import
XComResponse
from airflow.models.dagrun import DagRun
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
+from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.serialization.serde import deserialize, serialize
from airflow.utils.session import create_session
@@ -130,6 +131,86 @@ class TestXComsGetEndpoint:
}
assert any(msg.startswith("Checking read XCom access") for msg in
caplog.messages)
+ @pytest.mark.parametrize(
+ "offset, expected_status, expected_json",
+ [
+ pytest.param(
+ -4,
+ 404,
+ {
+ "detail": {
+ "reason": "not_found",
+ "message": (
+ "XCom with key='xcom_1' offset=-4 not found "
+ "for task 'task' in DAG run 'runid' of 'dag'"
+ ),
+ },
+ },
+ id="-4",
+ ),
+ pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
+ pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
+ pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
+ pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
+ pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
+ pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+ pytest.param(
+ 3,
+ 404,
+ {
+ "detail": {
+ "reason": "not_found",
+ "message": (
+ "XCom with key='xcom_1' offset=3 not found "
+ "for task 'task' in DAG run 'runid' of 'dag'"
+ ),
+ },
+ },
+ id="3",
+ ),
+ ],
+ )
+ def test_xcom_get_with_offset(
+ self,
+ client,
+ dag_maker,
+ session,
+ offset,
+ expected_status,
+ expected_json,
+ ):
+ xcom_values = ["f", None, "o", "b"]
+
+ class MyOperator(EmptyOperator):
+ def __init__(self, *, x, **kwargs):
+ super().__init__(**kwargs)
+ self.x = x
+
+ with dag_maker(dag_id="dag"):
+ MyOperator.partial(task_id="task").expand(x=xcom_values)
+ dag_run = dag_maker.create_dagrun(run_id="runid")
+ tis = {ti.map_index: ti for ti in dag_run.task_instances}
+
+ for map_index, db_value in enumerate(xcom_values):
+ if db_value is None: # We don't put None to XCom.
+ continue
+ ti = tis[map_index]
+ x = XComModel(
+ key="xcom_1",
+ value=db_value,
+ dag_run_id=ti.dag_run.id,
+ run_id=ti.run_id,
+ task_id=ti.task_id,
+ dag_id=ti.dag_id,
+ map_index=map_index,
+ )
+ session.add(x)
+ session.commit()
+
+ response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+ assert response.status_code == expected_status
+ assert response.json() == expected_json
+
class TestXComsSetEndpoint:
@pytest.mark.parametrize(
diff --git a/airflow-core/tests/unit/models/test_mappedoperator.py
b/airflow-core/tests/unit/models/test_mappedoperator.py
index 7f2f3770bb4..dd35f9461dc 100644
--- a/airflow-core/tests/unit/models/test_mappedoperator.py
+++ b/airflow-core/tests/unit/models/test_mappedoperator.py
@@ -32,7 +32,7 @@ from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk import setup, task, task_group, teardown
-from airflow.sdk.execution_time.comms import XComCountResponse
+from airflow.sdk.execution_time.comms import XComCountResponse, XComResult
from airflow.utils.state import TaskInstanceState
from airflow.utils.task_group import TaskGroup
from airflow.utils.trigger_rule import TriggerRule
@@ -1270,8 +1270,16 @@ class TestMappedSetupTeardown:
) as supervisor_comms:
# TODO: TaskSDK: this is a bit of a hack that we need to stub this
at all. `dag.test()` should
# really work without this!
- supervisor_comms.get_message.return_value =
XComCountResponse(len=3)
+ supervisor_comms.get_message.side_effect = [
+ XComCountResponse(len=3),
+ XComResult(key="return_value", value=1),
+ XComCountResponse(len=3),
+ XComResult(key="return_value", value=2),
+ XComCountResponse(len=3),
+ XComResult(key="return_value", value=3),
+ ]
dr = dag.test()
+ assert supervisor_comms.get_message.call_count == 6
states = self.get_states(dr)
expected = {
"tg_1.my_pre_setup": "success",
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index c64f721b9ae..399954ddfcb 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -429,6 +429,42 @@ class XComOperations:
# decouple from the server response string
return OKResponse(ok=True)
+ def get_sequence_item(
+ self,
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ key: str,
+ offset: int,
+ ) -> XComResponse | ErrorResponse:
+ params = {"offset": offset}
+ try:
+ resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
+ except ServerResponseError as e:
+ if e.response.status_code == HTTPStatus.NOT_FOUND:
+ log.error(
+ "XCom not found",
+ dag_id=dag_id,
+ run_id=run_id,
+ task_id=task_id,
+ key=key,
+ offset=offset,
+ detail=e.detail,
+ status_code=e.response.status_code,
+ )
+ return ErrorResponse(
+ error=ErrorType.XCOM_NOT_FOUND,
+ detail={
+ "dag_id": dag_id,
+ "run_id": run_id,
+ "task_id": task_id,
+ "key": key,
+ "offset": offset,
+ },
+ )
+ raise
+ return XComResponse.model_validate_json(resp.read())
+
class AssetOperations:
__slots__ = ("client",)
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index ab3a5a82b22..1adcb7efaa7 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -337,7 +337,7 @@ class PlainXComArg(XComArg):
task_id = self.operator.task_id
if self.operator.is_mapped:
- return LazyXComSequence[Any](xcom_arg=self, ti=ti)
+ return LazyXComSequence(xcom_arg=self, ti=ti)
tg = self.operator.get_closest_mapped_task_group()
result = None
if tg is None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index b4d68086b0c..039e2a5409c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -441,6 +441,15 @@ class GetXComCount(BaseModel):
type: Literal["GetNumberXComs"] = "GetNumberXComs"
+class GetXComSequenceItem(BaseModel):
+ key: str
+ dag_id: str
+ run_id: str
+ task_id: str
+ offset: int
+ type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
+
+
class SetXCom(BaseModel):
key: str
value: Annotated[
@@ -605,6 +614,7 @@ ToSupervisor = Annotated[
GetVariable,
GetXCom,
GetXComCount,
+ GetXComSequenceItem,
PutVariable,
RescheduleTask,
RetryTask,
diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
index 095ab051fb1..79822787f38 100644
--- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
+++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
@@ -43,10 +43,10 @@ class LazyXComIterator(Iterator[T]):
if self.index < 0:
# When iterating backwards, avoid extra HTTP request
raise StopIteration()
- val = self.seq._get_item(self.index)
- if val is None:
- # None isn't the best signal (it's bad in fact) but it's the best
we can do until https://github.com/apache/airflow/issues/46426
- raise StopIteration()
+ try:
+ val = self.seq[self.index]
+ except IndexError:
+ raise StopIteration from None
self.index += self.dir
return val
@@ -109,52 +109,59 @@ class LazyXComSequence(Sequence[T]):
def __getitem__(self, key: slice) -> Sequence[T]: ...
def __getitem__(self, key: int | slice) -> T | Sequence[T]:
- if isinstance(key, int):
- if key >= 0:
- return self._get_item(key)
- # val[-1] etc.
- return self._get_item(len(self) + key)
+ if not isinstance(key, (int, slice)):
+ raise TypeError(f"Sequence indices must be integers or slices, not
{type(key).__name__}")
if isinstance(key, slice):
- # This implements the slicing syntax. We want to optimize negative
slicing (e.g. seq[-10:]) by not
- # doing an additional COUNT query (via HEAD http request) if
possible. We can do this unless the
- # start and stop have different signs (i.e. one is positive and
another negative).
- ...
- """
- Todo?
- elif isinstance(key, slice):
- start, stop, reverse = _coerce_slice(key)
- if start >= 0:
- if stop is None:
- stmt = self._select_asc.offset(start)
- elif stop >= 0:
- stmt = self._select_asc.slice(start, stop)
- else:
- stmt = self._select_asc.slice(start, len(self) + stop)
- rows = [self._process_row(row) for row in
self._session.execute(stmt)]
- if reverse:
- rows.reverse()
- else:
- if stop is None:
- stmt = self._select_desc.limit(-start)
- elif stop < 0:
- stmt = self._select_desc.slice(-stop, -start)
- else:
- stmt = self._select_desc.slice(len(self) - stop, -start)
- rows = [self._process_row(row) for row in
self._session.execute(stmt)]
- if not reverse:
- rows.reverse()
- return rows
- """
- raise TypeError(f"Sequence indices must be integers or slices, not
{type(key).__name__}")
-
- def _get_item(self, index: int) -> T:
- # TODO: maybe we need to call SUPERVISOR_COMMS manually so we can
handle not found here?
- return self._ti.xcom_pull(
- task_ids=self._xcom_arg.operator.task_id,
- key=self._xcom_arg.key,
- map_indexes=index,
- )
+ raise TypeError("slice is not implemented yet")
+ # TODO...
+ # This implements the slicing syntax. We want to optimize negative
slicing (e.g. seq[-10:]) by not
+ # doing an additional COUNT query (via HEAD http request) if possible.
We can do this unless the
+ # start and stop have different signs (i.e. one is positive and
another negative).
+ # start, stop, reverse = _coerce_slice(key)
+ # if start >= 0:
+ # if stop is None:
+ # stmt = self._select_asc.offset(start)
+ # elif stop >= 0:
+ # stmt = self._select_asc.slice(start, stop)
+ # else:
+ # stmt = self._select_asc.slice(start, len(self) + stop)
+ # rows = [self._process_row(row) for row in
self._session.execute(stmt)]
+ # if reverse:
+ # rows.reverse()
+ # else:
+ # if stop is None:
+ # stmt = self._select_desc.limit(-start)
+ # elif stop < 0:
+ # stmt = self._select_desc.slice(-stop, -start)
+ # else:
+ # stmt = self._select_desc.slice(len(self) - stop, -start)
+ # rows = [self._process_row(row) for row in
self._session.execute(stmt)]
+ # if not reverse:
+ # rows.reverse()
+ # return rows
+
+ from airflow.sdk.bases.xcom import BaseXCom
+ from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
+ from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+ with SUPERVISOR_COMMS.lock:
+ source = (xcom_arg := self._xcom_arg).operator
+ SUPERVISOR_COMMS.send_request(
+ log=log,
+ msg=GetXComSequenceItem(
+ key=xcom_arg.key,
+ dag_id=source.dag_id,
+ task_id=source.task_id,
+ run_id=self._ti.run_id,
+ offset=key,
+ ),
+ )
+ msg = SUPERVISOR_COMMS.get_message()
+
+ if not isinstance(msg, XComResult):
+ raise IndexError(key)
+ return BaseXCom.deserialize_value(msg)
def _coerce_index(value: Any) -> int | None:
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index c90d6ea5a02..ff804f3f769 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -59,6 +59,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskStatesResponse,
TerminalTIState,
VariableResponse,
+ XComResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
@@ -84,6 +85,7 @@ from airflow.sdk.execution_time.comms import (
GetVariable,
GetXCom,
GetXComCount,
+ GetXComSequenceItem,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
@@ -1034,6 +1036,14 @@ class ActivitySubprocess(WatchedSubprocess):
elif isinstance(msg, GetXComCount):
len = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id,
msg.key)
resp = XComCountResponse(len=len)
+ elif isinstance(msg, GetXComSequenceItem):
+ xcom = self.client.xcoms.get_sequence_item(
+ msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
+ )
+ if isinstance(xcom, XComResponse):
+ resp = XComResult.from_xcom_response(xcom)
+ else:
+ resp = xcom
elif isinstance(msg, DeferTask):
self._terminal_state = IntermediateTIState.DEFERRED
self.client.task_instances.defer(self.id, msg)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
new file mode 100644
index 00000000000..a42572e5df1
--- /dev/null
+++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest.mock import ANY, Mock, call
+
+import pytest
+
+from airflow.sdk.exceptions import ErrorType
+from airflow.sdk.execution_time.comms import (
+ ErrorResponse,
+ GetXComCount,
+ GetXComSequenceItem,
+ XComCountResponse,
+ XComResult,
+)
+from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
+
+
+@pytest.fixture
+def mock_operator():
+ return Mock(spec=["dag_id", "task_id"], dag_id="dag", task_id="task")
+
+
+@pytest.fixture
+def mock_xcom_arg(mock_operator):
+ return Mock(spec=["operator", "key"], operator=mock_operator, key="return_value")
+
+
+@pytest.fixture
+def mock_ti():
+ return Mock(spec=["run_id"], run_id="run")
+
+
+@pytest.fixture
+def lazy_sequence(mock_xcom_arg, mock_ti):
+ return LazyXComSequence(mock_xcom_arg, mock_ti)
+
+
+def test_len(mock_supervisor_comms, lazy_sequence):
+ mock_supervisor_comms.get_message.return_value = XComCountResponse(len=3)
+ assert len(lazy_sequence) == 3
+ assert mock_supervisor_comms.send_request.mock_calls == [
+ call(log=ANY, msg=GetXComCount(key="return_value", dag_id="dag",
task_id="task", run_id="run")),
+ ]
+
+
+def test_iter(mock_supervisor_comms, lazy_sequence):
+ it = iter(lazy_sequence)
+
+ mock_supervisor_comms.get_message.side_effect = [
+ XComResult(key="return_value", value="f"),
+ ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops":
"sorry!"}),
+ ]
+ assert list(it) == ["f"]
+ assert mock_supervisor_comms.send_request.mock_calls == [
+ call(
+ log=ANY,
+ msg=GetXComSequenceItem(
+ key="return_value",
+ dag_id="dag",
+ task_id="task",
+ run_id="run",
+ offset=0,
+ ),
+ ),
+ call(
+ log=ANY,
+ msg=GetXComSequenceItem(
+ key="return_value",
+ dag_id="dag",
+ task_id="task",
+ run_id="run",
+ offset=1,
+ ),
+ ),
+ ]
+
+
+def test_getitem_index(mock_supervisor_comms, lazy_sequence):
+ mock_supervisor_comms.get_message.return_value =
XComResult(key="return_value", value="f")
+ assert lazy_sequence[4] == "f"
+ assert mock_supervisor_comms.send_request.mock_calls == [
+ call(
+ log=ANY,
+ msg=GetXComSequenceItem(
+ key="return_value",
+ dag_id="dag",
+ task_id="task",
+ run_id="run",
+ offset=4,
+ ),
+ ),
+ ]
+
+
+def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
+ mock_supervisor_comms.get_message.return_value = ErrorResponse(
+ error=ErrorType.XCOM_NOT_FOUND,
+ detail={"oops": "sorry!"},
+ )
+ with pytest.raises(IndexError) as ctx:
+ lazy_sequence[4]
+ assert ctx.value.args == (4,)
+ assert mock_supervisor_comms.send_request.mock_calls == [
+ call(
+ log=ANY,
+ msg=GetXComSequenceItem(
+ key="return_value",
+ dag_id="dag",
+ task_id="task",
+ run_id="run",
+ offset=4,
+ ),
+ ),
+ ]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 4aff50a5fd4..a3eef617a4e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -74,6 +74,7 @@ from airflow.sdk.execution_time.comms import (
GetTICount,
GetVariable,
GetXCom,
+ GetXComSequenceItem,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
@@ -1436,6 +1437,36 @@ class TestHandleRequest:
TaskStatesResult(task_states={"run_id": {"task1": "success",
"task2": "failed"}}),
id="get_task_states",
),
+ pytest.param(
+ GetXComSequenceItem(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ offset=0,
+ ),
+ b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+ "xcoms.get_sequence_item",
+ ("test_dag", "test_run", "test_task", "test_key", 0),
+ {},
+ XComResult(key="test_key", value="test_value"),
+ id="get_xcom_seq_item",
+ ),
+ pytest.param(
+ GetXComSequenceItem(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ offset=2,
+ ),
+ b'{"error":"XCOM_NOT_FOUND","detail":null,"type":"ErrorResponse"}\n',
+ "xcoms.get_sequence_item",
+ ("test_dag", "test_run", "test_task", "test_key", 2),
+ {},
+ ErrorResponse(error=ErrorType.XCOM_NOT_FOUND),
+ id="get_xcom_seq_item_not_found",
+ ),
],
)
def test_handle_requests(