This is an automated email from the ASF dual-hosted git repository.

amoghdesai pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ef86d4e4813 Add support to get XComs with 'include_prior_dates' (#48440)
ef86d4e4813 is described below

commit ef86d4e481399bda2e8bf0a9cfeb4f68302f05ff
Author: Amogh Desai <[email protected]>
AuthorDate: Fri Mar 28 15:09:40 2025 +0530

    Add support to get XComs with 'include_prior_dates' (#48440)
    
    * Add support to get XComs with 'include_prior_dates'
    
    * adding API versioning
---
 .../api_fastapi/execution_api/routes/xcoms.py      | 25 ++++++++++--
 .../api_fastapi/execution_api/versions/__init__.py |  7 +++-
 .../execution_api/versions/v2025_03_26.py          | 14 ++++++-
 airflow-core/src/airflow/models/xcom.py            | 14 ++++++-
 airflow-core/tests/unit/models/test_xcom.py        | 44 ++++++++++++++++++++--
 task-sdk/src/airflow/sdk/api/client.py             | 10 ++++-
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  1 +
 .../src/airflow/sdk/execution_time/supervisor.py   |  4 +-
 .../src/airflow/sdk/execution_time/task_runner.py  |  1 +
 task-sdk/src/airflow/sdk/execution_time/xcom.py    |  1 +
 task-sdk/tests/task_sdk/api/test_client.py         | 24 ++++++++++++
 .../task_sdk/execution_time/test_supervisor.py     | 21 +++++++++--
 .../task_sdk/execution_time/test_task_runner.py    |  1 +
 13 files changed, 149 insertions(+), 18 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index e75d111c0fe..2f28acc4bc1 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -22,7 +22,7 @@ import sys
 from typing import Annotated, Any
 
 from fastapi import Body, Depends, HTTPException, Path, Query, Request, 
Response, status
-from pydantic import JsonValue
+from pydantic import BaseModel, JsonValue
 from sqlalchemy import delete
 from sqlalchemy.sql.selectable import Select
 
@@ -122,6 +122,13 @@ def head_xcom(
     response.headers["Content-Range"] = f"map_indexes {count}"
 
 
+class GetXcomFilterParams(BaseModel):
+    """Class to house the params that can optionally be set for Get XCom."""
+
+    map_index: int = -1
+    include_prior_dates: bool = False
+
+
 @router.get(
     "/{dag_id}/{run_id}/{task_id}/{key}",
     description="Get a single XCom Value",
@@ -131,13 +138,22 @@ def get_xcom(
     run_id: str,
     task_id: str,
     key: str,
-    xcom_query: Annotated[Select, Depends(xcom_query)],
-    map_index: Annotated[int, Query()] = -1,
+    session: SessionDep,
+    params: Annotated[GetXcomFilterParams, Query()],
 ) -> XComResponse:
     """Get an Airflow XCom from database - not other XCom Backends."""
     # The xcom_query allows no map_index to be passed. This endpoint should 
always return just a single item,
     # so we override that query value
-    xcom_query = xcom_query.filter(XComModel.map_index == map_index)
+    xcom_query = XComModel.get_many(
+        run_id=run_id,
+        key=key,
+        task_ids=task_id,
+        dag_ids=dag_id,
+        map_indexes=params.map_index,
+        include_prior_dates=params.include_prior_dates,
+        session=session,
+    )
+    xcom_query = xcom_query.filter(XComModel.map_index == params.map_index)
     # We use `BaseXCom.get_many` to fetch XComs directly from the database, 
bypassing the XCom Backend.
     # This avoids deserialization via the backend (e.g., from a remote storage 
like S3) and instead
     # retrieves the raw serialized value from the database. By not relying on 
`XCom.get_many` or `XCom.get_one`
@@ -145,6 +161,7 @@ def get_xcom(
     # performance hits from retrieving large data files into the API server.
     result = xcom_query.limit(1).first()
     if result is None:
+        map_index = params.map_index
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
             detail={
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index af93aab29ba..5fb011edad5 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -19,10 +19,13 @@ from __future__ import annotations
 
 from cadwyn import HeadVersion, Version, VersionBundle
 
-from airflow.api_fastapi.execution_api.versions.v2025_03_26 import 
RemoveTIRuntimeChecksEndpoint
+from airflow.api_fastapi.execution_api.versions.v2025_03_26 import (
+    AddIncludePriorDatesParam,
+    RemoveTIRuntimeChecksEndpoint,
+)
 
 bundle = VersionBundle(
     HeadVersion(),
-    Version("2025-03-26", RemoveTIRuntimeChecksEndpoint),
+    Version("2025-03-26", RemoveTIRuntimeChecksEndpoint, 
AddIncludePriorDatesParam),
     Version("2025-03-19"),
 )
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
index 98d9b985399..728960b0f6d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_03_26.py
@@ -17,7 +17,9 @@
 
 from __future__ import annotations
 
-from cadwyn import VersionChange, endpoint
+from cadwyn import VersionChange, endpoint, schema
+
+from airflow.api_fastapi.execution_api.routes.xcoms import GetXcomFilterParams
 
 
 class RemoveTIRuntimeChecksEndpoint(VersionChange):
@@ -27,3 +29,13 @@ class RemoveTIRuntimeChecksEndpoint(VersionChange):
     instructions_to_migrate_to_previous_version = (
         endpoint("/task-instances/{task_instance_id}/runtime-checks", 
["POST"]).existed,
     )
+
+
+class AddIncludePriorDatesParam(VersionChange):
+    """Add the `include_prior_dates` query parameter to the GET XCom API."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        schema(GetXcomFilterParams).field("include_prior_dates").didnt_exist,
+    )
diff --git a/airflow-core/src/airflow/models/xcom.py b/airflow-core/src/airflow/models/xcom.py
index e2fdf49ad6f..88549d65eb5 100644
--- a/airflow-core/src/airflow/models/xcom.py
+++ b/airflow-core/src/airflow/models/xcom.py
@@ -31,6 +31,7 @@ from sqlalchemy import (
     PrimaryKeyConstraint,
     String,
     delete,
+    func,
     select,
     text,
 )
@@ -303,8 +304,17 @@ class XComModel(TaskInstanceDependencies):
             query = query.filter(cls.map_index == map_indexes)
 
         if include_prior_dates:
-            dr = session.query(DagRun.logical_date).filter(DagRun.run_id == 
run_id).subquery()
-            query = query.filter(cls.logical_date <= dr.c.logical_date)
+            dr = (
+                session.query(
+                    func.coalesce(DagRun.logical_date, 
DagRun.run_after).label("logical_date_or_run_after")
+                )
+                .filter(DagRun.run_id == run_id)
+                .subquery()
+            )
+
+            query = query.filter(
+                func.coalesce(DagRun.logical_date, DagRun.run_after) <= 
dr.c.logical_date_or_run_after
+            )
         else:
             query = query.filter(cls.run_id == run_id)
 
diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py
index a7c6d7845d2..35391d6a6d9 100644
--- a/airflow-core/tests/unit/models/test_xcom.py
+++ b/airflow-core/tests/unit/models/test_xcom.py
@@ -58,17 +58,20 @@ def reset_db():
 
 @pytest.fixture
 def task_instance_factory(request, session: Session):
-    def func(*, dag_id, task_id, logical_date):
+    def func(*, dag_id, task_id, logical_date, run_after=None):
         run_id = DagRun.generate_run_id(
-            run_type=DagRunType.SCHEDULED, logical_date=logical_date, 
run_after=logical_date
+            run_type=DagRunType.SCHEDULED,
+            logical_date=logical_date,
+            run_after=run_after if run_after is not None else logical_date,
         )
+        interval = (logical_date, logical_date) if logical_date else None
         run = DagRun(
             dag_id=dag_id,
             run_type=DagRunType.SCHEDULED,
             run_id=run_id,
             logical_date=logical_date,
-            data_interval=(logical_date, logical_date),
-            run_after=logical_date,
+            data_interval=interval,
+            run_after=run_after if run_after is not None else logical_date,
         )
         session.add(run)
         ti = TaskInstance(EmptyOperator(task_id=task_id), run_id=run_id)
@@ -220,6 +223,25 @@ class TestXComGet:
 
         return ti1, ti2
 
+    @pytest.fixture
+    def tis_for_xcom_get_one_from_prior_date_without_logical_date(
+        self, task_instance_factory, push_simple_json_xcom
+    ):
+        date1 = timezone.datetime(2021, 12, 3, 4, 56)
+        ti1 = task_instance_factory(dag_id="dag", logical_date=None, 
task_id="task_1", run_after=date1)
+        ti2 = task_instance_factory(
+            dag_id="dag",
+            logical_date=None,
+            run_after=date1 + datetime.timedelta(days=1),
+            task_id="task_1",
+        )
+
+        # The earlier run pushes an XCom, but not the later run, but the later
+        # run can get this earlier XCom with ``include_prior_dates``.
+        push_simple_json_xcom(ti=ti1, key="xcom_1", value={"key": "value"})
+
+        return ti1, ti2
+
     def test_xcom_get_one_from_prior_date(self, session, 
tis_for_xcom_get_one_from_prior_date):
         _, ti2 = tis_for_xcom_get_one_from_prior_date
         retrieved_value = XComModel.get_many(
@@ -232,6 +254,20 @@ class TestXComGet:
         ).first()
         assert XComModel.deserialize_value(retrieved_value) == {"key": "value"}
 
+    def test_xcom_get_one_from_prior_date_with_no_logical_dates(
+        self, session, 
tis_for_xcom_get_one_from_prior_date_without_logical_date
+    ):
+        _, ti2 = tis_for_xcom_get_one_from_prior_date_without_logical_date
+        retrieved_value = XComModel.get_many(
+            run_id=ti2.run_id,
+            key="xcom_1",
+            task_ids="task_1",
+            dag_ids="dag",
+            include_prior_dates=True,
+            session=session,
+        ).first()
+        assert XComModel.deserialize_value(retrieved_value) == {"key": "value"}
+
     @pytest.fixture
     def setup_for_xcom_get_many_single_argument_value(self, task_instance, 
push_simple_json_xcom):
         push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key": 
"value"})
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 0ab6bae0344..a6d757101fb 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -274,7 +274,13 @@ class XComOperations:
         return int(content_range[len("map_indexes ") :])
 
     def get(
-        self, dag_id: str, run_id: str, task_id: str, key: str, map_index: int 
| None = None
+        self,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        key: str,
+        map_index: int | None = None,
+        include_prior_dates: bool = False,
     ) -> XComResponse:
         """Get a XCom value from the API server."""
         # TODO: check if we need to use map_index as params in the uri
@@ -282,6 +288,8 @@ class XComOperations:
         params = {}
         if map_index is not None and map_index >= 0:
             params.update({"map_index": map_index})
+        if include_prior_dates:
+            params.update({"include_prior_dates": include_prior_dates})
         try:
             resp = self.client.get(f"xcoms/{dag_id}/{run_id}/{task_id}/{key}", 
params=params)
         except ServerResponseError as e:
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py
index a667bc49f1a..ab7dfac1d39 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -319,6 +319,7 @@ class GetXCom(BaseModel):
     run_id: str
     task_id: str
     map_index: int | None = None
+    include_prior_dates: bool = False
     type: Literal["GetXCom"] = "GetXCom"
 
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index deceae7fb43..f8272e39c75 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -896,7 +896,9 @@ class ActivitySubprocess(WatchedSubprocess):
             else:
                 resp = var.model_dump_json().encode()
         elif isinstance(msg, GetXCom):
-            xcom = self.client.xcoms.get(msg.dag_id, msg.run_id, msg.task_id, 
msg.key, msg.map_index)
+            xcom = self.client.xcoms.get(
+                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, 
msg.include_prior_dates
+            )
             xcom_result = XComResult.from_xcom_response(xcom)
             resp = xcom_result.model_dump_json().encode()
         elif isinstance(msg, GetXComCount):
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index de3229fcba7..f64582f5cbc 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -341,6 +341,7 @@ class RuntimeTaskInstance(TaskInstance):
                 task_id=t_id,
                 dag_id=dag_id,
                 map_index=m_idx,
+                include_prior_dates=include_prior_dates,
             )
             xcoms.append(value if value else default)
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/xcom.py b/task-sdk/src/airflow/sdk/execution_time/xcom.py
index a10d831e2ee..abb964907f1 100644
--- a/task-sdk/src/airflow/sdk/execution_time/xcom.py
+++ b/task-sdk/src/airflow/sdk/execution_time/xcom.py
@@ -241,6 +241,7 @@ class BaseXCom:
                 task_id=task_id,
                 run_id=run_id,
                 map_index=map_index,
+                include_prior_dates=include_prior_dates,
             ),
         )
 
diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py
index 5549178a794..409c140972d 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -511,6 +511,30 @@ class TestXCOMOperations:
         assert result.key == "test_key"
         assert result.value == "test_value"
 
+    def test_xcom_get_success_with_include_prior_dates(self):
+        # Simulate a successful response from the server when getting an xcom 
with include_prior_dates passed
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            if request.url.path == "/xcoms/dag_id/run_id/task_id/key" and 
request.url.params.get(
+                "include_prior_dates"
+            ):
+                return httpx.Response(
+                    status_code=201,
+                    json={"key": "test_key", "value": "test_value"},
+                )
+            return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.xcoms.get(
+            dag_id="dag_id",
+            run_id="run_id",
+            task_id="task_id",
+            key="key",
+            include_prior_dates=True,
+        )
+        assert isinstance(result, XComResponse)
+        assert result.key == "test_key"
+        assert result.value == "test_value"
+
     @mock.patch("time.sleep", return_value=None)
     def test_xcom_get_500_error(self, mock_sleep):
         # Simulate a successful response from the server returning a 500 error
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index f1011bb9131..fcdeb46b9bd 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -1016,7 +1016,7 @@ class TestHandleRequest:
                 GetXCom(dag_id="test_dag", run_id="test_run", 
task_id="test_task", key="test_key"),
                 
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
                 "xcoms.get",
-                ("test_dag", "test_run", "test_task", "test_key", None),
+                ("test_dag", "test_run", "test_task", "test_key", None, False),
                 {},
                 XComResult(key="test_key", value="test_value"),
                 id="get_xcom",
@@ -1027,7 +1027,7 @@ class TestHandleRequest:
                 ),
                 
b'{"key":"test_key","value":"test_value","type":"XComResult"}\n',
                 "xcoms.get",
-                ("test_dag", "test_run", "test_task", "test_key", 2),
+                ("test_dag", "test_run", "test_task", "test_key", 2, False),
                 {},
                 XComResult(key="test_key", value="test_value"),
                 id="get_xcom_map_index",
@@ -1036,11 +1036,26 @@ class TestHandleRequest:
                 GetXCom(dag_id="test_dag", run_id="test_run", 
task_id="test_task", key="test_key"),
                 b'{"key":"test_key","value":null,"type":"XComResult"}\n',
                 "xcoms.get",
-                ("test_dag", "test_run", "test_task", "test_key", None),
+                ("test_dag", "test_run", "test_task", "test_key", None, False),
                 {},
                 XComResult(key="test_key", value=None, type="XComResult"),
                 id="get_xcom_not_found",
             ),
+            pytest.param(
+                GetXCom(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    key="test_key",
+                    include_prior_dates=True,
+                ),
+                b'{"key":"test_key","value":null,"type":"XComResult"}\n',
+                "xcoms.get",
+                ("test_dag", "test_run", "test_task", "test_key", None, True),
+                {},
+                XComResult(key="test_key", value=None, type="XComResult"),
+                id="get_xcom_include_prior_dates",
+            ),
             pytest.param(
                 SetXCom(
                     dag_id="test_dag",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 418a48a37e4..808083d1495 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -1577,6 +1577,7 @@ class TestXComAfterTaskExecution:
             task_id="pull_task",
             run_id="test_run",
             map_index=-1,
+            include_prior_dates=False,
         )
 
         assert not any(

Reply via email to