This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 98f615961ea Add create XCom endpoint in RestAPI (#46042)
98f615961ea is described below

commit 98f615961ea907ef0b7456635cde9eac700a57cf
Author: Shubham Raj <[email protected]>
AuthorDate: Tue Feb 4 21:01:36 2025 +0530

    Add create XCom endpoint in RestAPI (#46042)
    
    * create xcom api
    
    * Add tests for create xcom API
    
    * Small tweak
    
    ---------
    
    Co-authored-by: pierrejeambrun <[email protected]>
---
 airflow/api_fastapi/core_api/datamodels/xcom.py    |  10 +-
 .../api_fastapi/core_api/openapi/v1-generated.yaml |  86 +++++++++++++++++
 airflow/api_fastapi/core_api/routes/public/xcom.py |  85 +++++++++++++++-
 airflow/ui/openapi-gen/queries/common.ts           |   3 +
 airflow/ui/openapi-gen/queries/queries.ts          |  47 +++++++++
 airflow/ui/openapi-gen/requests/schemas.gen.ts     |  22 +++++
 airflow/ui/openapi-gen/requests/services.gen.ts    |  34 +++++++
 airflow/ui/openapi-gen/requests/types.gen.ts       |  47 +++++++++
 .../core_api/routes/public/test_xcom.py            | 107 ++++++++++++++++++---
 9 files changed, 426 insertions(+), 15 deletions(-)

diff --git a/airflow/api_fastapi/core_api/datamodels/xcom.py b/airflow/api_fastapi/core_api/datamodels/xcom.py
index b63db3ff87d..f874f8bdeed 100644
--- a/airflow/api_fastapi/core_api/datamodels/xcom.py
+++ b/airflow/api_fastapi/core_api/datamodels/xcom.py
@@ -21,7 +21,7 @@ from typing import Any
 
 from pydantic import field_validator
 
-from airflow.api_fastapi.core_api.base import BaseModel
+from airflow.api_fastapi.core_api.base import BaseModel, StrictBaseModel
 
 
 class XComResponse(BaseModel):
@@ -57,3 +57,11 @@ class XComCollectionResponse(BaseModel):
 
     xcom_entries: list[XComResponse]
     total_entries: int
+
+
+class XComCreateBody(StrictBaseModel):
+    """Payload serializer for creating an XCom entry."""
+
+    # NOTE: the docstring above is published verbatim as this schema's OpenAPI
+    # description, so keep it short. StrictBaseModel makes the generated schema
+    # forbid additional properties (``additionalProperties: false``).
+
+    # XCom key; together with the task instance it identifies the entry.
+    key: str
+    # Any value the JSON request body can carry; the create route serializes it
+    # with XCom.serialize_value before storing it.
+    value: Any
+    # Optional, defaults to -1 (presumably "not a mapped task instance" — confirm
+    # against TaskInstance/XCom map_index semantics).
+    map_index: int = -1
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index b5b4a91ee5f..81a799dc513 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -4480,6 +4480,74 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+    post:
+      # File name (v1-generated.yaml) suggests this spec is generated from the FastAPI
+      # route create_xcom_entry; fix the route, then regenerate — confirm codegen flow.
+      tags:
+      - XCom
+      summary: Create Xcom Entry
+      description: Create an XCom entry.
+      operationId: create_xcom_entry
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: task_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Task Id
+      - name: dag_run_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Run Id
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/XComCreateBody'
+      responses:
+        # NOTE(review): the route also answers 409 Conflict when an XCom with the same
+        # key already exists for the task instance, but no '409' response is listed here.
+        '201':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/XComResponseNative'
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}:
     get:
       tags:
@@ -10484,6 +10552,24 @@ components:
       - total_entries
       title: XComCollectionResponse
       description: XCom Collection serializer for responses.
+    XComCreateBody:
+      # Schema of the XComCreateBody Pydantic model: unknown fields are rejected
+      # (additionalProperties: false) and map_index may be omitted (default -1).
+      properties:
+        key:
+          type: string
+          title: Key
+        value:
+          title: Value
+        map_index:
+          type: integer
+          title: Map Index
+          default: -1
+      additionalProperties: false
+      type: object
+      required:
+      - key
+      - value
+      title: XComCreateBody
+      description: Payload serializer for creating an XCom entry.
     XComResponse:
       properties:
         key:
diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py
index b8fa6456e57..e1ef40685d3 100644
--- a/airflow/api_fastapi/core_api/routes/public/xcom.py
+++ b/airflow/api_fastapi/core_api/routes/public/xcom.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 import copy
 from typing import Annotated
 
-from fastapi import HTTPException, Query, status
+from fastapi import HTTPException, Query, Request, status
 from sqlalchemy import and_, select
 
 from airflow.api_fastapi.common.db.common import SessionDep, paginated_select
@@ -27,11 +27,13 @@ from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.datamodels.xcom import (
     XComCollectionResponse,
+    XComCreateBody,
     XComResponseNative,
     XComResponseString,
 )
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
-from airflow.models import DagRun as DR, XCom
+from airflow.exceptions import TaskNotFound
+from airflow.models import DAG, DagRun as DR, XCom
 from airflow.settings import conf
 
 xcom_router = AirflowRouter(
@@ -141,3 +143,82 @@ def get_xcom_entries(
     query = query.order_by(XCom.dag_id, XCom.task_id, XCom.run_id, 
XCom.map_index, XCom.key)
     xcoms = session.scalars(query)
     return XComCollectionResponse(xcom_entries=xcoms, 
total_entries=total_entries)
+
+
+@xcom_router.post(
+    "",
+    status_code=status.HTTP_201_CREATED,
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+        ]
+    ),
+)
+def create_xcom_entry(
+    dag_id: str,
+    task_id: str,
+    dag_run_id: str,
+    request_body: XComCreateBody,
+    session: SessionDep,
+    request: Request,
+) -> XComResponseNative:
+    """Create an XCom entry."""
+    # Validate DAG ID
+    dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+    if not dag:
+        raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with ID: 
`{dag_id}` was not found")
+
+    # Validate Task ID
+    try:
+        dag.get_task(task_id)
+    except TaskNotFound:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"Task with ID: `{task_id}` not found 
in DAG: `{dag_id}`"
+        )
+
+    # Validate DAG Run ID
+    dag_run = dag.get_dagrun(dag_run_id, session)
+    if not dag_run:
+        raise HTTPException(
+            status.HTTP_404_NOT_FOUND, f"DAG Run with ID: `{dag_run_id}` not 
found for DAG: `{dag_id}`"
+        )
+
+    # Check existing XCom
+    if XCom.get_one(
+        key=request_body.key,
+        task_id=task_id,
+        dag_id=dag_id,
+        run_id=dag_run_id,
+        map_index=request_body.map_index,
+        session=session,
+    ):
+        raise HTTPException(
+            status_code=status.HTTP_409_CONFLICT,
+            detail=f"The XCom with key: `{request_body.key}` with mentioned 
task instance already exists.",
+        )
+
+    # Create XCom entry
+    XCom.set(
+        dag_id=dag_id,
+        task_id=task_id,
+        run_id=dag_run_id,
+        key=request_body.key,
+        value=XCom.serialize_value(request_body.value),
+        map_index=request_body.map_index,
+        session=session,
+    )
+
+    xcom = session.scalar(
+        select(XCom)
+        .filter(
+            XCom.dag_id == dag_id,
+            XCom.task_id == task_id,
+            XCom.run_id == dag_run_id,
+            XCom.key == request_body.key,
+            XCom.map_index == request_body.map_index,
+        )
+        .limit(1)
+    )
+
+    return XComResponseNative.model_validate(xcom)
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index dfc5ce38ded..4a0ee27e830 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1668,6 +1668,9 @@ export type TaskInstanceServicePostClearTaskInstancesMutationResult = Awaited<
   ReturnType<typeof TaskInstanceService.postClearTaskInstances>
 >;
 export type PoolServicePostPoolMutationResult = Awaited<ReturnType<typeof 
PoolService.postPool>>;
+// Resolved result of XcomService.createXcomEntry; the default `TData` of
+// useXcomServiceCreateXcomEntry.
+export type XcomServiceCreateXcomEntryMutationResult = Awaited<
+  ReturnType<typeof XcomService.createXcomEntry>
+>;
 export type VariableServicePostVariableMutationResult = Awaited<
   ReturnType<typeof VariableService.postVariable>
 >;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 320c226f2a7..3985dd4f5fa 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -51,6 +51,7 @@ import {
   TaskInstancesBatchBody,
   TriggerDAGRunPostBody,
   VariableBody,
+  XComCreateBody,
 } from "../requests/types.gen";
 import * as Common from "./common";
 
@@ -3164,6 +3165,52 @@ export const usePoolServicePostPool = <
     mutationFn: ({ requestBody }) => PoolService.postPool({ requestBody }) as 
unknown as Promise<TData>,
     ...options,
   });
+/**
+ * Create Xcom Entry
+ * Create an XCom entry.
+ * @param data The data for the request (the variables passed to `mutate`).
+ * @param data.dagId
+ * @param data.taskId
+ * @param data.dagRunId
+ * @param data.requestBody
+ * @param options React Query mutation options; everything except `mutationFn`, which is
+ * fixed to XcomService.createXcomEntry.
+ * @returns XComResponseNative Successful Response
+ * @throws ApiError
+ */
+export const useXcomServiceCreateXcomEntry = <
+  TData = Common.XcomServiceCreateXcomEntryMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: string;
+        dagRunId: string;
+        requestBody: XComCreateBody;
+        taskId: string;
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: string;
+      dagRunId: string;
+      requestBody: XComCreateBody;
+      taskId: string;
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, dagRunId, requestBody, taskId }) =>
+      XcomService.createXcomEntry({ dagId, dagRunId, requestBody, taskId }) as unknown as Promise<TData>,
+    ...options,
+  });
 /**
  * Post Variable
  * Create a variable.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index bbb7a033c31..080561f6186 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -6022,6 +6022,28 @@ export const $XComCollectionResponse = {
   description: "XCom Collection serializer for responses.",
 } as const;
 
+// JSON Schema for XComCreateBody, matching components.schemas.XComCreateBody in the
+// OpenAPI spec: `key` and `value` are required, `map_index` defaults to -1, and unknown
+// properties are rejected.
+export const $XComCreateBody = {
+  properties: {
+    key: {
+      type: "string",
+      title: "Key",
+    },
+    value: {
+      title: "Value",
+    },
+    map_index: {
+      type: "integer",
+      title: "Map Index",
+      default: -1,
+    },
+  },
+  additionalProperties: false,
+  type: "object",
+  required: ["key", "value"],
+  title: "XComCreateBody",
+  description: "Payload serializer for creating an XCom entry.",
+} as const;
+
 export const $XComResponse = {
   properties: {
     key: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index 02ce4c2a2c4..30136de4a56 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -177,6 +177,8 @@ import type {
   GetXcomEntryResponse,
   GetXcomEntriesData,
   GetXcomEntriesResponse,
+  CreateXcomEntryData,
+  CreateXcomEntryResponse,
   GetTasksData,
   GetTasksResponse,
   GetTaskData,
@@ -3015,6 +3017,38 @@ export class XcomService {
       },
     });
   }
+
+  /**
+   * Create Xcom Entry
+   * Create an XCom entry.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.taskId
+   * @param data.dagRunId
+   * @param data.requestBody
+   * @returns XComResponseNative Successful Response
+   * @throws ApiError
+   */
+  public static createXcomEntry(data: CreateXcomEntryData): CancelablePromise<CreateXcomEntryResponse> {
+    return __request(OpenAPI, {
+      method: "POST",
+      url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries",
+      path: {
+        dag_id: data.dagId,
+        task_id: data.taskId,
+        dag_run_id: data.dagRunId,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      // NOTE(review): the server also answers 409 Conflict when the key already exists for
+      // the task instance. It is absent here because it is absent from the OpenAPI spec this
+      // client was generated from; regenerate once the spec lists it.
+      errors: {
+        400: "Bad Request",
+        401: "Unauthorized",
+        403: "Forbidden",
+        404: "Not Found",
+        422: "Validation Error",
+      },
+    });
+  }
 }
 
 export class TaskService {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index 4d7a5ccc2d5..31a392466c8 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1478,6 +1478,15 @@ export type XComCollectionResponse = {
   total_entries: number;
 };
 
+/**
+ * Payload serializer for creating an XCom entry.
+ */
+export type XComCreateBody = {
+  /** XCom key; must not already exist for the target task instance. */
+  key: string;
+  /** Any JSON value; serialized server-side before it is stored. */
+  value: unknown;
+  /** Optional; the server defaults it to -1. */
+  map_index?: number;
+};
+
 /**
  * Serializer for a xcom item.
  */
@@ -2307,6 +2316,15 @@ export type GetXcomEntriesData = {
 
 export type GetXcomEntriesResponse = XComCollectionResponse;
 
+/**
+ * Input of XcomService.createXcomEntry: `dagId`, `dagRunId` and `taskId` fill the URL
+ * path; `requestBody` is sent as the JSON body.
+ */
+export type CreateXcomEntryData = {
+  dagId: string;
+  dagRunId: string;
+  requestBody: XComCreateBody;
+  taskId: string;
+};
+
+/** Body of the 201 response returned by XcomService.createXcomEntry. */
+export type CreateXcomEntryResponse = XComResponseNative;
+
 export type GetTasksData = {
   dagId: string;
   orderBy?: string;
@@ -4718,6 +4736,35 @@ export type $OpenApiTs = {
         422: HTTPValidationError;
       };
     };
+    post: {
+      req: CreateXcomEntryData;
+      res: {
+        /**
+         * Successful Response
+         */
+        201: XComResponseNative;
+        /**
+         * Bad Request
+         */
+        400: HTTPExceptionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Not Found
+         */
+        404: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
   };
   "/public/dags/{dag_id}/tasks": {
     get: {
diff --git a/tests/api_fastapi/core_api/routes/public/test_xcom.py b/tests/api_fastapi/core_api/routes/public/test_xcom.py
index 103ce049426..7c6ce2c71e6 100644
--- a/tests/api_fastapi/core_api/routes/public/test_xcom.py
+++ b/tests/api_fastapi/core_api/routes/public/test_xcom.py
@@ -20,6 +20,7 @@ from unittest import mock
 
 import pytest
 
+from airflow.api_fastapi.core_api.datamodels.xcom import XComCreateBody
 from airflow.models import XCom
 from airflow.models.dag import DagModel
 from airflow.models.dagrun import DagRun
@@ -64,18 +65,18 @@ def _create_xcom(key, value, backend, session=None) -> None:
 
 
 @provide_session
-def _create_dag_run(session=None) -> None:
-    dagrun = DagRun(
-        dag_id=TEST_DAG_ID,
+def _create_dag_run(dag_maker, session=None) -> None:
+    """
+    Create the test DAG and a manual DAG run for the XCom endpoint tests.
+
+    Builds ``TEST_DAG_ID`` with a single ``EmptyOperator`` task (``TEST_TASK_ID``) via
+    ``dag_maker``, creates the MANUAL DAG run ``run_id`` at ``logical_date_parsed``, then
+    syncs the DAG to the database, merges its ``DagModel`` into ``session`` and commits.
+
+    :param dag_maker: the ``dag_maker`` pytest fixture requested by the test class ``setup``.
+    :param session: SQLAlchemy session; ``@provide_session`` provides one when omitted.
+    """
+    with dag_maker(TEST_DAG_ID, schedule=None, start_date=logical_date_parsed):
+        EmptyOperator(task_id=TEST_TASK_ID)
+    dag_maker.create_dagrun(
         run_id=run_id,
-        logical_date=logical_date_parsed,
-        start_date=logical_date_parsed,
         run_type=DagRunType.MANUAL,
+        logical_date=logical_date_parsed,
     )
-    session.add(dagrun)
-    ti = TaskInstance(EmptyOperator(task_id=TEST_TASK_ID), run_id=run_id)
-    ti.dag_id = TEST_DAG_ID
-    session.add(ti)
+
+    # Persist the DAG to the database; presumably this is what lets the API's DAG bag
+    # resolve TEST_DAG_ID during requests — confirm against the test client setup.
+    dag_maker.sync_dagbag_to_db()
+    session.merge(dag_maker.dag_model)
+    session.commit()
 
 
 class CustomXCom(BaseXCom):
@@ -95,9 +96,9 @@ class TestXComEndpoint:
         clear_db_xcom()
 
     @pytest.fixture(autouse=True)
-    def setup(self) -> None:
+    def setup(self, dag_maker) -> None:
         self.clear_db()
-        _create_dag_run()
+        _create_dag_run(dag_maker)
 
     def teardown_method(self) -> None:
         self.clear_db()
@@ -214,7 +215,7 @@ class TestGetXComEntry(TestXComEndpoint):
 
 class TestGetXComEntries(TestXComEndpoint):
     @pytest.fixture(autouse=True)
-    def setup(self) -> None:
+    def setup(self, dag_maker) -> None:
         self.clear_db()
 
     def test_should_respond_200(self, test_client):
@@ -496,3 +497,85 @@ class TestPaginationGetXComEntries(TestXComEndpoint):
         assert response_data["total_entries"] == 10
         conn_ids = [conn["key"] for conn in response_data["xcom_entries"] if 
conn]
         assert conn_ids == expected_xcom_ids
+
+
+class TestCreateXComEntry(TestXComEndpoint):
+    @pytest.mark.parametrize(
+        "dag_id, task_id, dag_run_id, request_body, expected_status, 
expected_detail",
+        [
+            # Test case: Valid input, should succeed with 201 CREATED
+            pytest.param(
+                TEST_DAG_ID,
+                TEST_TASK_ID,
+                run_id,
+                XComCreateBody(key=TEST_XCOM_KEY, value=TEST_XCOM_VALUE),
+                201,
+                None,
+                id="valid-xcom-entry",
+            ),
+            # Test case: DAG not found
+            pytest.param(
+                "invalid-dag-id",
+                TEST_TASK_ID,
+                run_id,
+                XComCreateBody(key=TEST_XCOM_KEY, value=TEST_XCOM_VALUE),
+                404,
+                "Dag with ID: `invalid-dag-id` was not found",
+                id="dag-not-found",
+            ),
+            # Test case: Task not found in DAG
+            pytest.param(
+                TEST_DAG_ID,
+                "invalid-task-id",
+                run_id,
+                XComCreateBody(key=TEST_XCOM_KEY, value=TEST_XCOM_VALUE),
+                404,
+                f"Task with ID: `invalid-task-id` not found in DAG: 
`{TEST_DAG_ID}`",
+                id="task-not-found",
+            ),
+            # Test case: DAG Run not found
+            pytest.param(
+                TEST_DAG_ID,
+                TEST_TASK_ID,
+                "invalid-dag-run-id",
+                XComCreateBody(key=TEST_XCOM_KEY, value=TEST_XCOM_VALUE),
+                404,
+                f"DAG Run with ID: `invalid-dag-run-id` not found for DAG: 
`{TEST_DAG_ID}`",
+                id="dag-run-not-found",
+            ),
+            # Test case: XCom entry already exists
+            pytest.param(
+                TEST_DAG_ID,
+                TEST_TASK_ID,
+                run_id,
+                XComCreateBody(key=TEST_XCOM_KEY, value=TEST_XCOM_VALUE),
+                409,
+                f"The XCom with key: `{TEST_XCOM_KEY}` with mentioned task 
instance already exists.",
+                id="xcom-already-exists",
+            ),
+        ],
+    )
+    def test_create_xcom_entry(
+        self, dag_id, task_id, dag_run_id, request_body, expected_status, 
expected_detail, test_client
+    ):
+        # Pre-create an XCom entry to test conflict case
+        if expected_status == 409:
+            self._create_xcom(TEST_XCOM_KEY, TEST_XCOM_VALUE)
+
+        response = test_client.post(
+            
f"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries",
+            json=request_body.dict(),
+        )
+
+        assert response.status_code == expected_status
+        if expected_detail:
+            assert response.json()["detail"] == expected_detail
+        elif expected_status == 201:
+            # Validate the created XCom response
+            current_data = response.json()
+            assert current_data["key"] == request_body.key
+            assert current_data["value"] == 
XCom.serialize_value(request_body.value)
+            assert current_data["dag_id"] == dag_id
+            assert current_data["task_id"] == task_id
+            assert current_data["run_id"] == dag_run_id
+            assert current_data["map_index"] == request_body.map_index

Reply via email to