This is an automated email from the ASF dual-hosted git repository.
amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new ef86d4e4813 Add support to get XComs with 'include_prior_dates' (#48440)
ef86d4e4813 is described below
commit ef86d4e481399bda2e8bf0a9cfeb4f68302f05ff
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Mar 28 15:09:40 2025 +0530
Add support to get XComs with 'include_prior_dates' (#48440)
* Add support to get XComs with 'include_prior_dates'
* Add API versioning for the new 'include_prior_dates' query parameter
---
.../api_fastapi/execution_api/routes/xcoms.py | 25 ++++++++++--
.../api_fastapi/execution_api/versions/__init__.py | 7 +++-
.../execution_api/versions/v2025_03_26.py | 14 ++++++-
airflow-core/src/airflow/models/xcom.py | 14 ++++++-
airflow-core/tests/unit/models/test_xcom.py | 44 ++++++++++++++++++++--
task-sdk/src/airflow/sdk/api/client.py | 10 ++++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 1 +
.../src/airflow/sdk/execution_time/supervisor.py | 4 +-
.../src/airflow/sdk/execution_time/task_runner.py | 1 +
task-sdk/src/airflow/sdk/execution_time/xcom.py | 1 +
task-sdk/tests/task_sdk/api/test_client.py | 24 ++++++++++++
.../task_sdk/execution_time/test_supervisor.py | 21 +++++++++--
.../task_sdk/execution_time/test_task_runner.py | 1 +
13 files changed, 149 insertions(+), 18 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index e75d111c0fe..2f28acc4bc1 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -22,7 +22,7 @@ import sys
from typing import Annotated, Any
from fastapi import Body, Depends, HTTPException, Path, Query, Request,
Response, status
-from pydantic import JsonValue
+from pydantic import BaseModel, JsonValue
from sqlalchemy import delete
from sqlalchemy.sql.selectable import Select
@@ -122,6 +122,13 @@ def head_xcom(
response.headers["Content-Range"] = f"map_indexes {count}"
+class GetXcomFilterParams(BaseModel):
+ """Class to house the params that can optionally be set for Get XCom."""
+
+ map_index: int = -1
+ include_prior_dates: bool = False
+
+
@router.get(
"/{dag_id}/{run_id}/{task_id}/{key}",
description="Get a single XCom Value",
@@ -131,13 +138,22 @@ def get_xcom(
run_id: str,
task_id: str,
key: str,
- xcom_query: Annotated[Select, Depends(xcom_query)],
- map_index: Annotated[int, Query()] = -1,
+ session: SessionDep,
+ params: Annotated[GetXcomFilterParams, Query()],
) -> XComResponse:
"""Get an Airflow XCom from database - not other XCom Backends."""
# The xcom_query allows no map_index to be passed. This endpoint should
always return just a single item,
# so we override that query value
- xcom_query = xcom_query.filter(XComModel.map_index == map_index)
+ xcom_query = XComModel.get_many(
+ run_id=run_id,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ map_indexes=params.map_index,
+ include_prior_dates=params.include_prior_dates,
+ session=session,
+ )
+ xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
# We use `BaseXCom.get_many` to fetch XComs directly from the database,
bypassing the XCom Backend.
# This avoids deserialization via the backend (e.g., from a remote storage
like S3) and instead
# retrieves the raw serialized value from the database. By not relying on
`XCom.get_many` or `XCom.get_one`
@@ -145,6 +161,7 @@ def get_xcom(
# performance hits from retrieving large data files into the API server.
result = xcom_query.limit(1).first()
if result is None:
+ map_index = params.map_index
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index af93aab29ba..5fb011edad5 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -19,10 +19,13 @@ from __future__ import annotations
from cadwyn import HeadVersion, Version, VersionBundle
-from airflow.api_fastapi.execution_api.versions.v2025_03_26 import
RemoveTIRuntimeChecksEndpoint
+from airflow.api_fastapi.execution_api.versions.v2025_03_26 import (
+ AddIncludePriorDatesParam,
+ RemoveTIRuntimeChecksEndpoint,
+)
bundle = VersionBundle(
HeadVersion(),
- Version("2025-03-26", RemoveTIRuntimeChecksEndpoint),
+ Version("2025-03-26", RemoveTIRuntimeChecksEndpoint,
AddIncludePriorDatesParam),
Version("2025-03-19"),
)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
index 98d9b985399..728960b0f6d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
@@ -17,7 +17,9 @@
from __future__ import annotations
-from cadwyn import VersionChange, endpoint
+from cadwyn import VersionChange, endpoint, schema
+
+from airflow.api_fastapi.execution_api.routes.xcoms import GetXcomFilterParams
class RemoveTIRuntimeChecksEndpoint(VersionChange):
@@ -27,3 +29,13 @@ class RemoveTIRuntimeChecksEndpoint(VersionChange):
instructions_to_migrate_to_previous_version = (
endpoint("/task-instances/{task_instance_id}/runtime-checks",
["POST"]).existed,
)
+
+
+class AddIncludePriorDatesParam(VersionChange):
+ """Add the `include_prior_dates` query parameter to the GET XCom API."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version = (
+ schema(GetXcomFilterParams).field("include_prior_dates").didnt_exist,
+ )
diff --git a/airflow-core/src/airflow/models/xcom.py
b/airflow-core/src/airflow/models/xcom.py
index e2fdf49ad6f..88549d65eb5 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -31,6 +31,7 @@ from sqlalchemy import (
PrimaryKeyConstraint,
String,
delete,
+ func,
select,
text,
)
@@ -303,8 +304,17 @@ class XComModel(TaskInstanceDependencies):
query = query.filter(cls.map_index == map_indexes)
if include_prior_dates:
- dr = session.query(DagRun.logical_date).filter(DagRun.run_id ==
run_id).subquery()
- query = query.filter(cls.logical_date <= dr.c.logical_date)
+ dr = (
+ session.query(
+ func.coalesce(DagRun.logical_date,
DagRun.run_after).label("logical_date_or_run_after")
+ )
+ .filter(DagRun.run_id == run_id)
+ .subquery()
+ )
+
+ query = query.filter(
+ func.coalesce(DagRun.logical_date, DagRun.run_after) <=
dr.c.logical_date_or_run_after
+ )
else:
query = query.filter(cls.run_id == run_id)
diff --git a/airflow-core/tests/unit/models/test_xcom.py
b/airflow-core/tests/unit/models/test_xcom.py
index a7c6d7845d2..35391d6a6d9 100644
--- a/airflow-core/tests/unit/models/test_xcom.py
+++ b/airflow-core/tests/unit/models/test_xcom.py
@@ -58,17 +58,20 @@ def reset_db():
@pytest.fixture
def task_instance_factory(request, session: Session):
- def func(*, dag_id, task_id, logical_date):
+ def func(*, dag_id, task_id, logical_date, run_after=None):
run_id = DagRun.generate_run_id(
- run_type=DagRunType.SCHEDULED, logical_date=logical_date,
run_after=logical_date
+ run_type=DagRunType.SCHEDULED,
+ logical_date=logical_date,
+ run_after=run_after if run_after is not None else logical_date,
)
+ interval = (logical_date, logical_date) if logical_date else None
run = DagRun(
dag_id=dag_id,
run_type=DagRunType.SCHEDULED,
run_id=run_id,
logical_date=logical_date,
- data_interval=(logical_date, logical_date),
- run_after=logical_date,
+ data_interval=interval,
+ run_after=run_after if run_after is not None else logical_date,
)
session.add(run)
ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
@@ -220,6 +223,25 @@ class TestXComGet:
return ti1, ti2
+ @pytest.fixture
+ def tis_for_xcom_get_one_from_prior_date_without_logical_date(
+ self, task_instance_factory, push_simple_json_xcom
+ ):
+ date1 = timezone.datetime(2021, 12, 3, 4, 56)
+ ti1 = task_instance_factory(dag_id="dag", logical_date=None,
task_id="task_1", run_after=date1)
+ ti2 = task_instance_factory(
+ dag_id="dag",
+ logical_date=None,
+ run_after=date1 + datetime.timedelta(days=1),
+ task_id="task_1",
+ )
+
+ # The earlier run pushes an XCom, but not the later run, but the later
+ # run can get this earlier XCom with ``include_prior_dates``.
+ push_simple_json_xcom(ti=ti1, key="xcom_1", value={"key": "value"})
+
+ return ti1, ti2
+
def test_xcom_get_one_from_prior_date(self, session,
tis_for_xcom_get_one_from_prior_date):
_, ti2 = tis_for_xcom_get_one_from_prior_date
retrieved_value = XComModel.get_many(
@@ -232,6 +254,20 @@ class TestXComGet:
).first()
assert XComModel.deserialize_value(retrieved_value) == {"key": "value"}
+ def test_xcom_get_one_from_prior_date_with_no_logical_dates(
+ self, session,
tis_for_xcom_get_one_from_prior_date_without_logical_date
+ ):
+ _, ti2 = tis_for_xcom_get_one_from_prior_date_without_logical_date
+ retrieved_value = XComModel.get_many(
+ run_id=ti2.run_id,
+ key="xcom_1",
+ task_ids="task_1",
+ dag_ids="dag",
+ include_prior_dates=True,
+ session=session,
+ ).first()
+ assert XComModel.deserialize_value(retrieved_value) == {"key": "value"}
+
@pytest.fixture
def setup_for_xcom_get_many_single_argument_value(self, task_instance,
push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key":
"value"})
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 0ab6bae0344..a6d757101fb 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -274,7 +274,13 @@ class XComOperations:
return int(content_range[len("map_indexes ") :])
def get(
- self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int
| None = None
+ self,
+ dag_id: str,
+ run_id: str,
+ task_id: str,
+ key: str,
+ map_index: int | None = None,
+ include_prior_dates: bool = False,
) -> XComResponse:
"""Get a XCom value from the API server."""
# TODO: check if we need to use map_index as params in the uri
@@ -282,6 +288,8 @@ class XComOperations:
params = {}
if map_index is not None and map_index >= 0:
params.update({"map_index": map_index})
+ if include_prior_dates:
+ params.update({"include_prior_dates": include_prior_dates})
try:
resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}",
params=params)
except ServerResponseError as e:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a667bc49f1a..ab7dfac1d39 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -319,6 +319,7 @@ class GetXCom(BaseModel):
run_id: str
task_id: str
map_index: int | None = None
+ include_prior_dates: bool = False
type: Literal["GetXCom"] = "GetXCom"
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index deceae7fb43..f8272e39c75 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -896,7 +896,9 @@ class ActivitySubprocess(WatchedSubprocess):
else:
resp = var.model_dump_json().encode()
elif isinstance(msg, GetXCom):
- xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id,
msg.key, msg.map_index)
+ xcom = self.client.xcoms.get(
+ msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index,
msg.include_prior_dates
+ )
xcom_result = XComResult.from_xcom_response(xcom)
resp = xcom_result.model_dump_json().encode()
elif isinstance(msg, GetXComCount):
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index de3229fcba7..f64582f5cbc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -341,6 +341,7 @@ class RuntimeTaskInstance(TaskInstance):
task_id=t_id,
dag_id=dag_id,
map_index=m_idx,
+ include_prior_dates=include_prior_dates,
)
xcoms.append(value if value else default)
diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py
b/task-sdk/src/airflow/sdk/execution_time/xcom.py
index a10d831e2ee..abb964907f1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/xcom.py
+++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py
@@ -241,6 +241,7 @@ class BaseXCom:
task_id=task_id,
run_id=run_id,
map_index=map_index,
+ include_prior_dates=include_prior_dates,
),
)
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index 5549178a794..409c140972d 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -511,6 +511,30 @@ class TestXCOMOperations:
assert result.key == "test_key"
assert result.value == "test_value"
+ def test_xcom_get_success_with_include_prior_dates(self):
+ # Simulate a successful response from the server when getting an xcom
with include_prior_dates passed
+ def handle_request(request: httpx.Request) -> httpx.Response:
+ if request.url.path == "/xcoms/dag_id/run_id/task_id/key" and
request.url.params.get(
+ "include_prior_dates"
+ ):
+ return httpx.Response(
+ status_code=201,
+ json={"key": "test_key", "value": "test_value"},
+ )
+ return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
+
+ client = make_client(transport=httpx.MockTransport(handle_request))
+ result = client.xcoms.get(
+ dag_id="dag_id",
+ run_id="run_id",
+ task_id="task_id",
+ key="key",
+ include_prior_dates=True,
+ )
+ assert isinstance(result, XComResponse)
+ assert result.key == "test_key"
+ assert result.value == "test_value"
+
@mock.patch("time.sleep", return_value=None)
def test_xcom_get_500_error(self, mock_sleep):
# Simulate a successful response from the server returning a 500 error
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index f1011bb9131..fcdeb46b9bd 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -1016,7 +1016,7 @@ class TestHandleRequest:
GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", None),
+ ("test_dag", "test_run", "test_task", "test_key", None, False),
{},
XComResult(key="test_key", value="test_value"),
id="get_xcom",
@@ -1027,7 +1027,7 @@ class TestHandleRequest:
),
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
"xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", 2),
+ ("test_dag", "test_run", "test_task", "test_key", 2, False),
{},
XComResult(key="test_key", value="test_value"),
id="get_xcom_map_index",
@@ -1036,11 +1036,26 @@ class TestHandleRequest:
GetXCom(dag_id="test_dag", run_id="test_run",
task_id="test_task", key="test_key"),
b'{"key":"test_key","value":null,"type":"XComResult"}\n',
"xcoms.get",
- ("test_dag", "test_run", "test_task", "test_key", None),
+ ("test_dag", "test_run", "test_task", "test_key", None, False),
{},
XComResult(key="test_key", value=None, type="XComResult"),
id="get_xcom_not_found",
),
+ pytest.param(
+ GetXCom(
+ dag_id="test_dag",
+ run_id="test_run",
+ task_id="test_task",
+ key="test_key",
+ include_prior_dates=True,
+ ),
+ b'{"key":"test_key","value":null,"type":"XComResult"}\n',
+ "xcoms.get",
+ ("test_dag", "test_run", "test_task", "test_key", None, True),
+ {},
+ XComResult(key="test_key", value=None, type="XComResult"),
+ id="get_xcom_include_prior_dates",
+ ),
pytest.param(
SetXCom(
dag_id="test_dag",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 418a48a37e4..808083d1495 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1577,6 +1577,7 @@ class TestXComAfterTaskExecution:
task_id="pull_task",
run_id="test_run",
map_index=-1,
+ include_prior_dates=False,
)
assert not any(