This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 23f141d5b2a AIP-72: Add "XCom" POST endpoint for Execution API (#44101)
23f141d5b2a is described below

commit 23f141d5b2aa41eb8bbecc4fecb75473f168d1e6
Author: Kaxil Naik <[email protected]>
AuthorDate: Mon Nov 18 15:32:59 2024 +0000

    AIP-72: Add "XCom" POST endpoint for Execution API (#44101)
    
    closes https://github.com/apache/airflow/issues/44100
    
    Follow-up of https://github.com/apache/airflow/pull/43894
---
 airflow/api_fastapi/execution_api/app.py           |  1 +
 airflow/api_fastapi/execution_api/routes/xcoms.py  | 82 ++++++++++++++++++++--
 .../src/airflow/sdk/api/datamodels/_generated.py   |  4 ++
 .../api_fastapi/execution_api/routes/test_xcoms.py | 49 ++++++++++++-
 4 files changed, 129 insertions(+), 7 deletions(-)

diff --git a/airflow/api_fastapi/execution_api/app.py b/airflow/api_fastapi/execution_api/app.py
index e019e8f14f3..61283dc2cf8 100644
--- a/airflow/api_fastapi/execution_api/app.py
+++ b/airflow/api_fastapi/execution_api/app.py
@@ -58,6 +58,7 @@ def create_task_execution_api_app(app: FastAPI) -> FastAPI:
             description=app.description,
             version=app.version,
             routes=app.routes,
+            servers=app.servers,
         )
 
         extra_schemas = get_extra_schemas()
diff --git a/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow/api_fastapi/execution_api/routes/xcoms.py
index 083947923dc..12dc50c7b97 100644
--- a/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -21,7 +21,8 @@ import json
 import logging
 from typing import Annotated
 
-from fastapi import Depends, HTTPException, Query, status
+from fastapi import Body, Depends, HTTPException, Query, status
+from pydantic import Json
 from sqlalchemy.orm import Session
 
 from airflow.api_fastapi.common.db.common import get_session
@@ -33,7 +34,10 @@ from airflow.models.xcom import BaseXCom
 
 # TODO: Add dependency on JWT token
 router = AirflowRouter(
-    responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
+    responses={
+        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
+        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the XCom"},
+    },
 )
 
 log = logging.getLogger(__name__)
@@ -41,10 +45,7 @@ log = logging.getLogger(__name__)
 
 @router.get(
     "/{dag_id}/{run_id}/{task_id}/{key}",
-    responses={
-        status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"},
-        status.HTTP_403_FORBIDDEN: {"description": "Task does not have access to the XCom"},
-    },
+    responses={status.HTTP_404_NOT_FOUND: {"description": "XCom not found"}},
 )
 def get_xcom(
     dag_id: str,
@@ -105,6 +106,75 @@ def get_xcom(
     return XComResponse(key=key, value=xcom_value)
 
 
+@router.post(
+    "/{dag_id}/{run_id}/{task_id}/{key}",
+    status_code=status.HTTP_201_CREATED,
+    responses={
+        status.HTTP_400_BAD_REQUEST: {"description": "Invalid request body"},
+    },
+)
+def set_xcom(
+    dag_id: str,
+    run_id: str,
+    task_id: str,
+    key: str,
+    value: Annotated[
+        Json,
+        Body(
+            description="A JSON-formatted string representing the value to set for the XCom.",
+            openapi_examples={
+                "simple_value": {
+                    "summary": "Simple value",
+                    "value": '"value1"',
+                },
+                "dict_value": {
+                    "summary": "Dictionary value",
+                    "value": '{"key2": "value2"}',
+                },
+                "list_value": {
+                    "summary": "List value",
+                    "value": '["value1"]',
+                },
+            },
+        ),
+    ],
+    token: deps.TokenDep,
+    session: Annotated[Session, Depends(get_session)],
+    map_index: Annotated[int, Query()] = -1,
+):
+    """Set an Airflow XCom."""
+    if not has_xcom_access(key, token):
+        raise HTTPException(
+            status_code=status.HTTP_403_FORBIDDEN,
+            detail={
+                "reason": "access_denied",
+                "message": f"Task does not have access to set XCom key '{key}'",
+            },
+        )
+
+    # We use `BaseXCom.set` to set XComs directly to the database, bypassing the XCom Backend.
+    try:
+        BaseXCom.set(
+            key=key,
+            value=value,
+            dag_id=dag_id,
+            task_id=task_id,
+            run_id=run_id,
+            session=session,
+            map_index=map_index,
+        )
+    except TypeError as e:
+        raise HTTPException(
+            status_code=status.HTTP_400_BAD_REQUEST,
+            detail={
+                "reason": "invalid_format",
+                "message": f"XCom value is not a valid JSON: {e}",
+            },
+        )
+
+    return {"message": "XCom successfully set"}
+
+
 def has_xcom_access(xcom_key: str, token: TIToken) -> bool:
     """Check if the task has access to the XCom."""
     # TODO: Placeholder for actual implementation
diff --git a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
index c1d10f74d4a..0cfd3a12c40 100644
--- a/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task_sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -123,6 +123,10 @@ class XComResponse(BaseModel):
 
 
 class TaskInstance(BaseModel):
+    """
+    Schema for TaskInstance model with minimal required fields needed for Runtime.
+    """
+
     id: Annotated[UUID, Field(title="Id")]
     task_id: Annotated[str, Field(title="Task Id")]
     dag_id: Annotated[str, Field(title="Dag Id")]
diff --git a/tests/api_fastapi/execution_api/routes/test_xcoms.py b/tests/api_fastapi/execution_api/routes/test_xcoms.py
index a7cdde6c64f..d9d33f28d44 100644
--- a/tests/api_fastapi/execution_api/routes/test_xcoms.py
+++ b/tests/api_fastapi/execution_api/routes/test_xcoms.py
@@ -73,7 +73,6 @@ class TestXComsGetEndpoint:
         with mock.patch("airflow.api_fastapi.execution_api.routes.xcoms.has_xcom_access", return_value=False):
             response = client.get("/execution/xcoms/dag/runid/task/xcom_perms")
 
-        # Assert response status code and detail for access denied
         assert response.status_code == 403
         assert response.json() == {
             "detail": {
@@ -81,3 +80,51 @@ class TestXComsGetEndpoint:
                 "message": "Task does not have access to XCom key 'xcom_perms'",
             }
         }
+
+
+class TestXComsSetEndpoint:
+    @pytest.mark.parametrize(
+        ("value", "expected_value"),
+        [
+            ('"value1"', "value1"),
+            ('{"key2": "value2"}', {"key2": "value2"}),
+            ('{"key2": "value2", "key3": ["value3"]}', {"key2": "value2", "key3": ["value3"]}),
+            ('["value1"]', ["value1"]),
+        ],
+    )
+    def test_xcom_set(self, client, create_task_instance, session, value, expected_value):
+        """
+        Test that XCom value is set correctly. The value is passed as a JSON string in the request body.
+        This is then validated via Pydantic.Json type in the request body and converted to
+        a Python object before being sent to XCom.set. XCom.set then uses json.dumps to
+        serialize it and store the value in the database. This is done so that Task SDK in multiple
+        languages can use the same API to set XCom values.
+        """
+        ti = create_task_instance()
+        session.commit()
+
+        response = client.post(
+            f"/execution/xcoms/{ti.dag_id}/{ti.run_id}/{ti.task_id}/xcom_1",
+            json=value,
+        )
+
+        assert response.status_code == 201
+        assert response.json() == {"message": "XCom successfully set"}
+
+        xcom = session.query(XCom).filter_by(task_id=ti.task_id, dag_id=ti.dag_id, key="xcom_1").first()
+        assert xcom.value == expected_value
+
+    def test_xcom_access_denied(self, client):
+        with mock.patch("airflow.api_fastapi.execution_api.routes.xcoms.has_xcom_access", return_value=False):
+            response = client.post(
+                "/execution/xcoms/dag/runid/task/xcom_perms",
+                json='"value1"',
+            )
+
+        assert response.status_code == 403
+        assert response.json() == {
+            "detail": {
+                "reason": "access_denied",
+                "message": "Task does not have access to set XCom key 'xcom_perms'",
+            }
+        }

Reply via email to