This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new f8abdcd7353 Implement slice on LazyXComSequence (#50117)
f8abdcd7353 is described below
commit f8abdcd735368f830ec037ad10e06675807f9485
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Mon Jun 2 16:13:01 2025 +0800
Implement slice on LazyXComSequence (#50117)
---
.../api_fastapi/execution_api/datamodels/xcom.py | 14 ++-
.../api_fastapi/execution_api/routes/xcoms.py | 132 ++++++++++++++++++++-
.../execution_api/versions/head/test_xcoms.py | 78 ++++++++++--
.../versions/v2025_04_28/test_xcom.py | 107 +++++++++++++++++
task-sdk/src/airflow/sdk/api/client.py | 29 ++++-
.../src/airflow/sdk/api/datamodels/_generated.py | 26 +++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 36 +++++-
.../airflow/sdk/execution_time/lazy_sequence.py | 94 ++++++++-------
.../src/airflow/sdk/execution_time/supervisor.py | 14 ++-
.../task_sdk/execution_time/test_lazy_sequence.py | 33 +++++-
.../task_sdk/execution_time/test_supervisor.py | 25 +++-
11 files changed, 516 insertions(+), 72 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
index ae7ddd26761..4df3e3f74f0 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/xcom.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import sys
from typing import Any
-from pydantic import JsonValue
+from pydantic import JsonValue, RootModel
from airflow.api_fastapi.core_api.base import BaseModel
@@ -36,3 +36,15 @@ class XComResponse(BaseModel):
key: str
value: JsonValue
"""The returned XCom value in a JSON-compatible format."""
+
+
+class XComSequenceIndexResponse(RootModel):
+ """XCom schema with minimal structure for index-based access."""
+
+ root: JsonValue
+
+
+class XComSequenceSliceResponse(RootModel):
+ """XCom schema with minimal structure for slice-based access."""
+
+ root: list[JsonValue]
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index d4c8c5160cd..a9ed4a5b48d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -27,7 +27,11 @@ from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select
from airflow.api_fastapi.common.db.common import SessionDep
-from airflow.api_fastapi.execution_api.datamodels.xcom import XComResponse
+from airflow.api_fastapi.execution_api.datamodels.xcom import (
+ XComResponse,
+ XComSequenceIndexResponse,
+ XComSequenceSliceResponse,
+)
from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
@@ -184,6 +188,132 @@ def get_xcom(
return XComResponse(key=key, value=result.value)
+@router.get(
+ "/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}",
+ description="Get a single XCom value from a mapped task by sequence index",
+)
+def get_mapped_xcom_by_index(
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ key: str,
+ offset: int,
+ session: SessionDep,
+) -> XComSequenceIndexResponse:
+ xcom_query = XComModel.get_many(
+ run_id=run_id,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ session=session,
+ )
+ xcom_query = xcom_query.order_by(None)
+ if offset >= 0:
+ xcom_query = xcom_query.order_by(XComModel.map_index.asc()).offset(offset)
+ else:
+ xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
+
+ if (result := xcom_query.limit(1).first()) is None:
+ message = (
+ f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
+ )
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND,
+ detail={"reason": "not_found", "message": message},
+ )
+ return XComSequenceIndexResponse(result.value)
+
+
+class GetXComSliceFilterParams(BaseModel):
+ """Class to house slice params."""
+
+ start: int | None = None
+ stop: int | None = None
+ step: int | None = None
+
+
+@router.get(
+ "/{dag_id}/{run_id}/{task_id}/{key}/slice",
+ description="Get XCom values from a mapped task by sequence slice",
+)
+def get_mapped_xcom_by_slice(
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ key: str,
+ params: Annotated[GetXComSliceFilterParams, Query()],
+ session: SessionDep,
+) -> XComSequenceSliceResponse:
+ query = XComModel.get_many(
+ run_id=run_id,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ session=session,
+ )
+ query = query.order_by(None)
+
+ step = params.step or 1
+
+ # We want to optimize negative slicing (e.g. seq[-10:]) by not doing an
+ # additional COUNT query if possible. This is possible unless both start and
+ # stop are explicitly given and have different signs.
+ if (start := params.start) is None:
+ if (stop := params.stop) is None:
+ if step >= 0:
+ query = query.order_by(XComModel.map_index.asc())
+ else:
+ query = query.order_by(XComModel.map_index.desc())
+ step = -step
+ elif stop >= 0:
+ query = query.order_by(XComModel.map_index.asc())
+ if step >= 0:
+ query = query.limit(stop)
+ else:
+ query = query.offset(stop + 1)
+ else:
+ query = query.order_by(XComModel.map_index.desc())
+ step = -step
+ if step > 0:
+ query = query.limit(-stop - 1)
+ else:
+ query = query.offset(-stop)
+ elif start >= 0:
+ query = query.order_by(XComModel.map_index.asc())
+ if (stop := params.stop) is None:
+ if step >= 0:
+ query = query.offset(start)
+ else:
+ query = query.limit(start + 1)
+ else:
+ if stop < 0:
+ stop += get_query_count(query, session=session)
+ if step >= 0:
+ query = query.slice(start, stop)
+ else:
+ query = query.slice(stop + 1, start + 1)
+ else:
+ query = query.order_by(XComModel.map_index.desc())
+ step = -step
+ if (stop := params.stop) is None:
+ if step > 0:
+ query = query.offset(-start - 1)
+ else:
+ query = query.limit(-start)
+ else:
+ if stop >= 0:
+ stop -= get_query_count(query, session=session)
+ if step > 0:
+ query = query.slice(-1 - start, -1 - stop)
+ else:
+ query = query.slice(-stop, -start)
+
+ values = [row.value for row in query.with_entities(XComModel.value)]
+ if step != 1:
+ values = values[::step]
+ return XComSequenceSliceResponse(values)
+
+
if sys.version_info < (3, 12):
# zmievsa/cadwyn#262
# Setting this to "Any" doesn't have any impact on the API as it has to be parsed as valid JSON regardless
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
index 951fbf5cbae..1b10e81cd23 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
@@ -1,5 +1,4 @@
# Licensed to the Apache Software Foundation (ASF) under one
-# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
@@ -20,6 +19,7 @@ from __future__ import annotations
import contextlib
import logging
+import urllib.parse
import httpx
import pytest
@@ -148,12 +148,12 @@ class TestXComsGetEndpoint:
},
id="-4",
),
- pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
- pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
- pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
- pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
- pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
- pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+ pytest.param(-3, 200, "f", id="-3"),
+ pytest.param(-2, 200, "o", id="-2"),
+ pytest.param(-1, 200, "b", id="-1"),
+ pytest.param(0, 200, "f", id="0"),
+ pytest.param(1, 200, "o", id="1"),
+ pytest.param(2, 200, "b", id="2"),
pytest.param(
3,
404,
@@ -207,10 +207,72 @@ class TestXComsGetEndpoint:
session.add(x)
session.commit()
- response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+ response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/item/{offset}")
assert response.status_code == expected_status
assert response.json() == expected_json
+ @pytest.mark.parametrize(
+ "key",
+ [
+ pytest.param(slice(None, None, None), id=":"),
+ pytest.param(slice(None, None, -2), id="::-2"),
+ pytest.param(slice(None, 2, None), id=":2"),
+ pytest.param(slice(None, 2, -1), id=":2:-1"),
+ pytest.param(slice(None, -2, None), id=":-2"),
+ pytest.param(slice(None, -2, -1), id=":-2:-1"),
+ pytest.param(slice(1, None, None), id="1:"),
+ pytest.param(slice(2, None, -1), id="2::-1"),
+ pytest.param(slice(1, 2, None), id="1:2"),
+ pytest.param(slice(2, 1, -1), id="2:1:-1"),
+ pytest.param(slice(1, -1, None), id="1:-1"),
+ pytest.param(slice(2, -2, -1), id="2:-2:-1"),
+ pytest.param(slice(-2, None, None), id="-2:"),
+ pytest.param(slice(-1, None, -1), id="-1::-1"),
+ pytest.param(slice(-2, -1, None), id="-2:-1"),
+ pytest.param(slice(-1, -3, -1), id="-1:-3:-1"),
+ ],
+ )
+ def test_xcom_get_with_slice(self, client, dag_maker, session, key):
+ xcom_values = ["f", None, "o", "b"]
+
+ class MyOperator(EmptyOperator):
+ def __init__(self, *, x, **kwargs):
+ super().__init__(**kwargs)
+ self.x = x
+
+ with dag_maker(dag_id="dag"):
+ MyOperator.partial(task_id="task").expand(x=xcom_values)
+ dag_run = dag_maker.create_dagrun(run_id="runid")
+ tis = {ti.map_index: ti for ti in dag_run.task_instances}
+
+ for map_index, db_value in enumerate(xcom_values):
+ if db_value is None: # We don't put None to XCom.
+ continue
+ ti = tis[map_index]
+ x = XComModel(
+ key="xcom_1",
+ value=db_value,
+ dag_run_id=ti.dag_run.id,
+ run_id=ti.run_id,
+ task_id=ti.task_id,
+ dag_id=ti.dag_id,
+ map_index=map_index,
+ )
+ session.add(x)
+ session.commit()
+
+ qs = {}
+ if key.start is not None:
+ qs["start"] = key.start
+ if key.stop is not None:
+ qs["stop"] = key.stop
+ if key.step is not None:
+ qs["step"] = key.step
+
+ response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1/slice?{urllib.parse.urlencode(qs)}")
+ assert response.status_code == 200
+ assert response.json() == ["f", "o", "b"][key]
+
class TestXComsSetEndpoint:
@pytest.mark.parametrize(
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py
new file mode 100644
index 00000000000..1de65d493d1
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/v2025_04_28/test_xcom.py
@@ -0,0 +1,107 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.models.xcom import XComModel
+from airflow.providers.standard.operators.empty import EmptyOperator
+
+pytestmark = pytest.mark.db_test
+
+
+class TestXComsGetEndpoint:
+ @pytest.mark.parametrize(
+ "offset, expected_status, expected_json",
+ [
+ pytest.param(
+ -4,
+ 404,
+ {
+ "detail": {
+ "reason": "not_found",
+ "message": (
+ "XCom with key='xcom_1' offset=-4 not found "
+ "for task 'task' in DAG run 'runid' of 'dag'"
+ ),
+ },
+ },
+ id="-4",
+ ),
+ pytest.param(-3, 200, {"key": "xcom_1", "value": "f"}, id="-3"),
+ pytest.param(-2, 200, {"key": "xcom_1", "value": "o"}, id="-2"),
+ pytest.param(-1, 200, {"key": "xcom_1", "value": "b"}, id="-1"),
+ pytest.param(0, 200, {"key": "xcom_1", "value": "f"}, id="0"),
+ pytest.param(1, 200, {"key": "xcom_1", "value": "o"}, id="1"),
+ pytest.param(2, 200, {"key": "xcom_1", "value": "b"}, id="2"),
+ pytest.param(
+ 3,
+ 404,
+ {
+ "detail": {
+ "reason": "not_found",
+ "message": (
+ "XCom with key='xcom_1' offset=3 not found "
+ "for task 'task' in DAG run 'runid' of 'dag'"
+ ),
+ },
+ },
+ id="3",
+ ),
+ ],
+ )
+ def test_xcom_get_with_offset(
+ self,
+ client,
+ dag_maker,
+ session,
+ offset,
+ expected_status,
+ expected_json,
+ ):
+ xcom_values = ["f", None, "o", "b"]
+
+ class MyOperator(EmptyOperator):
+ def __init__(self, *, x, **kwargs):
+ super().__init__(**kwargs)
+ self.x = x
+
+ with dag_maker(dag_id="dag"):
+ MyOperator.partial(task_id="task").expand(x=xcom_values)
+
+ dag_run = dag_maker.create_dagrun(run_id="runid")
+ tis = {ti.map_index: ti for ti in dag_run.task_instances}
+ for map_index, db_value in enumerate(xcom_values):
+ if db_value is None: # We don't put None to XCom.
+ continue
+ ti = tis[map_index]
+ x = XComModel(
+ key="xcom_1",
+ value=db_value,
+ dag_run_id=ti.dag_run.id,
+ run_id=ti.run_id,
+ task_id=ti.task_id,
+ dag_id=ti.dag_id,
+ map_index=map_index,
+ )
+ session.add(x)
+ session.commit()
+
+ response = client.get(f"/execution/xcoms/dag/runid/task/xcom_1?offset={offset}")
+ assert response.status_code == expected_status
+ assert response.json() == expected_json
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 1fd548319e7..b9d0a4511ea 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -58,6 +58,8 @@ from airflow.sdk.api.datamodels._generated import (
VariablePostBody,
VariableResponse,
XComResponse,
+ XComSequenceIndexResponse,
+ XComSequenceSliceResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
@@ -442,10 +444,9 @@ class XComOperations:
task_id: str,
key: str,
offset: int,
- ) -> XComResponse | ErrorResponse:
- params = {"offset": offset}
+ ) -> XComSequenceIndexResponse | ErrorResponse:
try:
- resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", params=params)
+ resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/item/{offset}")
except ServerResponseError as e:
if e.response.status_code == HTTPStatus.NOT_FOUND:
log.error(
@@ -469,7 +470,27 @@ class XComOperations:
},
)
raise
- return XComResponse.model_validate_json(resp.read())
+ return XComSequenceIndexResponse.model_validate_json(resp.read())
+
+ def get_sequence_slice(
+ self,
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ key: str,
+ start: int | None,
+ stop: int | None,
+ step: int | None,
+ ) -> XComSequenceSliceResponse:
+ params = {}
+ if start is not None:
+ params["start"] = start
+ if stop is not None:
+ params["stop"] = stop
+ if step is not None:
+ params["step"] = step
+ resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}/slice", params=params)
+ return XComSequenceSliceResponse.model_validate_json(resp.read())
class AssetOperations:
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 3efae80e5b6..f6b1c907ef5 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -25,7 +25,7 @@ from enum import Enum
from typing import Annotated, Any, Final, Literal
from uuid import UUID
-from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue
+from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue, RootModel
API_VERSION: Final[str] = "2025-05-20"
@@ -356,6 +356,30 @@ class XComResponse(BaseModel):
value: JsonValue
+class XComSequenceIndexResponse(RootModel[JsonValue]):
+ root: Annotated[
+ JsonValue,
+ Field(
+ description="XCom schema with minimal structure for index-based access.",
+ title="XComSequenceIndexResponse",
+ ),
+ ]
+
+
+class XComSequenceSliceResponse(RootModel[list[JsonValue]]):
+ """
+ XCom schema with minimal structure for slice-based access.
+ """
+
+ root: Annotated[
+ list[JsonValue],
+ Field(
+ description="XCom schema with minimal structure for slice-based access.",
+ title="XComSequenceSliceResponse",
+ ),
+ ]
+
+
class TaskInstance(BaseModel):
"""
Schema for TaskInstance model with minimal required fields needed for Runtime.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a25ba574582..ecc34852252 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -74,6 +74,8 @@ from airflow.sdk.api.datamodels._generated import (
TriggerDAGRunPayload,
VariableResponse,
XComResponse,
+ XComSequenceIndexResponse,
+ XComSequenceSliceResponse,
)
from airflow.sdk.exceptions import ErrorType
@@ -227,6 +229,24 @@ class XComCountResponse(BaseModel):
type: Literal["XComLengthResponse"] = "XComLengthResponse"
+class XComSequenceIndexResult(BaseModel):
+ root: JsonValue
+ type: Literal["XComSequenceIndexResult"] = "XComSequenceIndexResult"
+
+ @classmethod
+ def from_response(cls, response: XComSequenceIndexResponse) -> XComSequenceIndexResult:
+ return cls(root=response.root, type="XComSequenceIndexResult")
+
+
+class XComSequenceSliceResult(BaseModel):
+ root: list[JsonValue]
+ type: Literal["XComSequenceSliceResult"] = "XComSequenceSliceResult"
+
+ @classmethod
+ def from_response(cls, response: XComSequenceSliceResponse) -> XComSequenceSliceResult:
+ return cls(root=response.root, type="XComSequenceSliceResult")
+
+
class ConnectionResult(ConnectionResponse):
type: Literal["ConnectionResult"] = "ConnectionResult"
@@ -352,8 +372,10 @@ ToTask = Annotated[
TICount,
TaskStatesResult,
VariableResult,
- XComResult,
XComCountResponse,
+ XComResult,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
OKResponse,
],
Field(discriminator="type"),
@@ -451,6 +473,17 @@ class GetXComSequenceItem(BaseModel):
type: Literal["GetXComSequenceItem"] = "GetXComSequenceItem"
+class GetXComSequenceSlice(BaseModel):
+ key: str
+ dag_id: str
+ run_id: str
+ task_id: str
+ start: int | None
+ stop: int | None
+ step: int | None
+ type: Literal["GetXComSequenceSlice"] = "GetXComSequenceSlice"
+
+
class SetXCom(BaseModel):
key: str
value: Annotated[
@@ -616,6 +649,7 @@ ToSupervisor = Annotated[
GetXCom,
GetXComCount,
GetXComSequenceItem,
+ GetXComSequenceSlice,
PutVariable,
RescheduleTask,
RetryTask,
diff --git a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
index 0fbfcf39498..9cf9acfac81 100644
--- a/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
+++ b/task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations
+import collections
import itertools
from collections.abc import Iterator, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload
@@ -30,6 +31,11 @@ if TYPE_CHECKING:
T = TypeVar("T")
+# This is used to wrap values from the API so the structure is compatible with
+# ``XCom.deserialize_value``. We don't want to wrap the API values in a nested
+# {"value": value} dict since it wastes bandwidth.
+_XComWrapper = collections.namedtuple("_XComWrapper", "value")
+
log = structlog.get_logger(logger_name=__name__)
@@ -98,7 +104,7 @@ class LazyXComSequence(Sequence[T]):
if isinstance(msg, ErrorResponse):
raise RuntimeError(msg)
if not isinstance(msg, XComCountResponse):
- raise TypeError(f"Got unexpected response to GetXComCount: {msg}")
+ raise TypeError(f"Got unexpected response to GetXComCount: {msg!r}")
self._len = msg.len
return self._len
@@ -109,41 +115,42 @@ class LazyXComSequence(Sequence[T]):
def __getitem__(self, key: slice) -> Sequence[T]: ...
def __getitem__(self, key: int | slice) -> T | Sequence[T]:
- if not isinstance(key, (int, slice)):
- raise TypeError(f"Sequence indices must be integers or slices, not {type(key).__name__}")
-
- if isinstance(key, slice):
- raise TypeError("slice is not implemented yet")
- # TODO...
- # This implements the slicing syntax. We want to optimize negative slicing (e.g. seq[-10:]) by not
- # doing an additional COUNT query (via HEAD http request) if possible. We can do this unless the
- # start and stop have different signs (i.e. one is positive and another negative).
- # start, stop, reverse = _coerce_slice(key)
- # if start >= 0:
- # if stop is None:
- # stmt = self._select_asc.offset(start)
- # elif stop >= 0:
- # stmt = self._select_asc.slice(start, stop)
- # else:
- # stmt = self._select_asc.slice(start, len(self) + stop)
- # rows = [self._process_row(row) for row in self._session.execute(stmt)]
- # if reverse:
- # rows.reverse()
- # else:
- # if stop is None:
- # stmt = self._select_desc.limit(-start)
- # elif stop < 0:
- # stmt = self._select_desc.slice(-stop, -start)
- # else:
- # stmt = self._select_desc.slice(len(self) - stop, -start)
- # rows = [self._process_row(row) for row in self._session.execute(stmt)]
- # if not reverse:
- # rows.reverse()
- # return rows
- from airflow.sdk.execution_time.comms import GetXComSequenceItem, XComResult
+ from airflow.sdk.execution_time.comms import (
+ ErrorResponse,
+ GetXComSequenceItem,
+ GetXComSequenceSlice,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
+ )
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
from airflow.sdk.execution_time.xcom import XCom
+ if isinstance(key, slice):
+ start, stop, step = _coerce_slice(key)
+ with SUPERVISOR_COMMS.lock:
+ source = (xcom_arg := self._xcom_arg).operator
+ SUPERVISOR_COMMS.send_request(
+ log=log,
+ msg=GetXComSequenceSlice(
+ key=xcom_arg.key,
+ dag_id=source.dag_id,
+ task_id=source.task_id,
+ run_id=self._ti.run_id,
+ start=start,
+ stop=stop,
+ step=step,
+ ),
+ )
+ msg = SUPERVISOR_COMMS.get_message()
+ if not isinstance(msg, XComSequenceSliceResult):
+ raise TypeError(f"Got unexpected response to GetXComSequenceSlice: {msg!r}")
+ return [XCom.deserialize_value(_XComWrapper(value)) for value in msg.root]
+
+ if not isinstance(key, int):
+ if (index := getattr(key, "__index__", None)) is not None:
+ key = index()
+ raise TypeError(f"Sequence indices must be integers or slices not {type(key).__name__}")
+
with SUPERVISOR_COMMS.lock:
source = (xcom_arg := self._xcom_arg).operator
SUPERVISOR_COMMS.send_request(
@@ -157,13 +164,14 @@ class LazyXComSequence(Sequence[T]):
),
)
msg = SUPERVISOR_COMMS.get_message()
-
- if not isinstance(msg, XComResult):
+ if isinstance(msg, ErrorResponse):
raise IndexError(key)
- return XCom.deserialize_value(msg)
+ if not isinstance(msg, XComSequenceIndexResult):
+ raise TypeError(f"Got unexpected response to GetXComSequenceItem: {msg!r}")
+ return XCom.deserialize_value(_XComWrapper(msg.root))
-def _coerce_index(value: Any) -> int | None:
+def _coerce_slice_index(value: Any) -> int | None:
"""
Check slice attribute's type and convert it to int.
@@ -177,17 +185,13 @@ def _coerce_index(value: Any) -> int | None:
raise TypeError("slice indices must be integers or None or have an __index__ method")
-def _coerce_slice(key: slice) -> tuple[int, int | None, bool]:
+def _coerce_slice(key: slice) -> tuple[int | None, int | None, int | None]:
"""
Check slice content and convert it for SQL.
See CPython documentation on this:
https://docs.python.org/3/reference/datamodel.html#slice-objects
"""
- if key.step is None or key.step == 1:
- reverse = False
- elif key.step == -1:
- reverse = True
- else:
- raise ValueError("non-trivial slice step not supported")
- return _coerce_index(key.start) or 0, _coerce_index(key.stop), reverse
+ if (step := _coerce_slice_index(key.step)) == 0:
+ raise ValueError("slice step cannot be zero")
+ return _coerce_slice_index(key.start), _coerce_slice_index(key.stop), step
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 1006b861378..65d05cc023d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -60,7 +60,7 @@ from airflow.sdk.api.datamodels._generated import (
TaskInstanceState,
TaskStatesResponse,
VariableResponse,
- XComResponse,
+ XComSequenceIndexResponse,
)
from airflow.sdk.exceptions import ErrorType
from airflow.sdk.execution_time.comms import (
@@ -87,6 +87,7 @@ from airflow.sdk.execution_time.comms import (
GetXCom,
GetXComCount,
GetXComSequenceItem,
+ GetXComSequenceSlice,
PrevSuccessfulDagRunResult,
PutVariable,
RescheduleTask,
@@ -103,6 +104,8 @@ from airflow.sdk.execution_time.comms import (
VariableResult,
XComCountResponse,
XComResult,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
)
from airflow.sdk.execution_time.secrets_masker import mask_secret
@@ -1108,10 +1111,15 @@ class ActivitySubprocess(WatchedSubprocess):
xcom = self.client.xcoms.get_sequence_item(
msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
)
- if isinstance(xcom, XComResponse):
- resp = XComResult.from_xcom_response(xcom)
+ if isinstance(xcom, XComSequenceIndexResponse):
+ resp = XComSequenceIndexResult.from_response(xcom)
else:
resp = xcom
+ elif isinstance(msg, GetXComSequenceSlice):
+ xcoms = self.client.xcoms.get_sequence_slice(
+ msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.start, msg.stop, msg.step
+ )
+ resp = XComSequenceSliceResult.from_response(xcoms)
elif isinstance(msg, DeferTask):
self._terminal_state = TaskInstanceState.DEFERRED
self._rendered_map_index = msg.rendered_map_index
diff --git a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
index 2430f85f35e..e4943196a09 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_lazy_sequence.py
@@ -28,8 +28,10 @@ from airflow.sdk.execution_time.comms import (
ErrorResponse,
GetXComCount,
GetXComSequenceItem,
+ GetXComSequenceSlice,
XComCountResponse,
- XComResult,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
)
from airflow.sdk.execution_time.lazy_sequence import LazyXComSequence
from airflow.sdk.execution_time.xcom import resolve_xcom_backend
@@ -75,7 +77,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence):
it = iter(lazy_sequence)
mock_supervisor_comms.get_message.side_effect = [
- XComResult(key="return_value", value="f"),
+ XComSequenceIndexResult(root="f"),
ErrorResponse(error=ErrorType.XCOM_NOT_FOUND, detail={"oops": "sorry!"}),
]
assert list(it) == ["f"]
@@ -104,7 +106,7 @@ def test_iter(mock_supervisor_comms, lazy_sequence):
def test_getitem_index(mock_supervisor_comms, lazy_sequence):
- mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="f")
+ mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="f")
assert lazy_sequence[4] == "f"
assert mock_supervisor_comms.send_request.mock_calls == [
call(
@@ -121,12 +123,12 @@ def test_getitem_index(mock_supervisor_comms, lazy_sequence):
@conf_vars({("core", "xcom_backend"): "task_sdk.execution_time.test_lazy_sequence.CustomXCom"})
-def test_getitem_calls_correct_deserialise(mock_supervisor_comms, lazy_sequence):
- mock_supervisor_comms.get_message.return_value = XComResult(key="return_value", value="some-value")
+def test_getitem_calls_correct_deserialise(monkeypatch, mock_supervisor_comms, lazy_sequence):
+ mock_supervisor_comms.get_message.return_value = XComSequenceIndexResult(root="some-value")
xcom = resolve_xcom_backend()
assert xcom.__name__ == "CustomXCom"
- airflow.sdk.execution_time.xcom.XCom = xcom
+ monkeypatch.setattr(airflow.sdk.execution_time.xcom, "XCom", xcom)
assert lazy_sequence[4] == "Made with CustomXCom: some-value"
assert mock_supervisor_comms.send_request.mock_calls == [
@@ -163,3 +165,22 @@ def test_getitem_indexerror(mock_supervisor_comms, lazy_sequence):
),
),
]
+
+
+def test_getitem_slice(mock_supervisor_comms, lazy_sequence):
+ mock_supervisor_comms.get_message.return_value = XComSequenceSliceResult(root=[6, 4, 1])
+ assert lazy_sequence[:5] == [6, 4, 1]
+ assert mock_supervisor_comms.send_request.mock_calls == [
+ call(
+ log=ANY,
+ msg=GetXComSequenceSlice(
+ key="return_value",
+ dag_id="dag",
+ task_id="task",
+ run_id="run",
+ start=None,
+ stop=5,
+ step=None,
+ ),
+ ),
+ ]
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 4696908bae7..86a2e747e0f 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -76,6 +76,7 @@ from airflow.sdk.execution_time.comms import (
GetVariable,
GetXCom,
GetXComSequenceItem,
+ GetXComSequenceSlice,
OKResponse,
PrevSuccessfulDagRunResult,
PutVariable,
@@ -91,6 +92,8 @@ from airflow.sdk.execution_time.comms import (
TriggerDagRun,
VariableResult,
XComResult,
+ XComSequenceIndexResult,
+ XComSequenceSliceResult,
)
from airflow.sdk.execution_time.supervisor import (
BUFFER_SIZE,
@@ -1618,11 +1621,11 @@ class TestHandleRequest:
task_id="test_task",
offset=0,
),
- b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
+ b'{"root":"test_value","type":"XComSequenceIndexResult"}\n',
"xcoms.get_sequence_item",
("test_dag", "test_run", "test_task", "test_key", 0),
{},
- XComResult(key="test_key", value="test_value"),
+ XComSequenceIndexResult(root="test_value"),
None,
id="get_xcom_seq_item",
),
@@ -1642,6 +1645,24 @@ class TestHandleRequest:
None,
id="get_xcom_seq_item_not_found",
),
+ pytest.param(
+ GetXComSequenceSlice(
+ key="test_key",
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ start=None,
+ stop=None,
+ step=None,
+ ),
+ b'{"root":["foo","bar"],"type":"XComSequenceSliceResult"}\n',
+ "xcoms.get_sequence_slice",
+ ("test_dag", "test_run", "test_task", "test_key", None, None, None),
+ {},
+ XComSequenceSliceResult(root=["foo", "bar"]),
+ None,
+ id="get_xcom_seq_slice",
+ ),
],
)
def test_handle_requests(