This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e4b344b2f AIP-84 Migrate Modify Dag Run endpoint to FastAPI (#42973)
3e4b344b2f is described below

commit 3e4b344b2fb051dc49cba3cbfa6e623cde271657
Author: Kalyan R <[email protected]>
AuthorDate: Tue Oct 29 21:13:42 2024 +0530

    AIP-84 Migrate Modify Dag Run endpoint to FastAPI (#42973)
    
    * add modify_dag_run
    
    * add tests
    
    * Update airflow/api_fastapi/views/public/dag_run.py
    
    * fix
    
    * Update airflow/api_fastapi/routes/public/dag_run.py
    
    Co-authored-by: Pierre Jeambrun <[email protected]>
    
    * Update airflow/api_fastapi/serializers/dag_run.py
    
    Co-authored-by: Pierre Jeambrun <[email protected]>
    
    * use dagbag
    
    * replace patch with put
    
    * refactor
    
    * use put in tests
    
    * modify to patch
    
    * add update_mask
    
    * refactor update to patch
    
    ---------
    
    Co-authored-by: Pierre Jeambrun <[email protected]>
---
 .../api_connexion/endpoints/dag_run_endpoint.py    |  1 +
 .../api_fastapi/core_api/openapi/v1-generated.yaml | 89 ++++++++++++++++++++++
 .../api_fastapi/core_api/routes/public/dag_run.py  | 57 +++++++++++++-
 .../api_fastapi/core_api/serializers/dag_run.py    | 15 ++++
 airflow/ui/openapi-gen/queries/common.ts           |  3 +
 airflow/ui/openapi-gen/queries/queries.ts          | 52 +++++++++++++
 airflow/ui/openapi-gen/requests/schemas.gen.ts     | 19 +++++
 airflow/ui/openapi-gen/requests/services.gen.ts    | 38 +++++++++
 airflow/ui/openapi-gen/requests/types.gen.ts       | 50 ++++++++++++
 .../core_api/routes/public/test_dag_run.py         | 50 ++++++++++++
 10 files changed, 371 insertions(+), 3 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index e0d8357561..8ebb2b44e2 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -373,6 +373,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
     raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID: 
'{run_id}' already exists")
 
 
+@mark_fastapi_migration_done
 @security.requires_access_dag("PUT", DagAccessEntity.RUN)
 @provide_session
 @action_logging
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 3d1b1611ad..ae5fc9e117 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1156,6 +1156,78 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+    patch:
+      tags:
+      - DagRun
+      summary: Patch Dag Run State
+      description: Modify a DAG Run.
+      operationId: patch_dag_run_state
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: dag_run_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Run Id
+      - name: update_mask
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: array
+            items:
+              type: string
+          - type: 'null'
+          title: Update Mask
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/DAGRunPatchBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DAGRunResponse'
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /public/monitor/health:
     get:
       tags:
@@ -2079,6 +2151,23 @@ components:
       - file_token
       title: DAGResponse
       description: DAG serializer for responses.
+    DAGRunPatchBody:
+      properties:
+        state:
+          $ref: '#/components/schemas/DAGRunPatchStates'
+      type: object
+      required:
+      - state
+      title: DAGRunPatchBody
+      description: DAG Run Serializer for PATCH requests.
+    DAGRunPatchStates:
+      type: string
+      enum:
+      - queued
+      - success
+      - failed
+      title: DAGRunPatchStates
+      description: Enum for DAG Run states when updating a DAG Run.
     DAGRunResponse:
       properties:
         run_id:
diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 035d1b7fd7..02780d6088 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -17,16 +17,25 @@
 
 from __future__ import annotations
 
-from fastapi import Depends, HTTPException
+from fastapi import Depends, HTTPException, Query, Request
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from typing_extensions import Annotated
 
+from airflow.api.common.mark_tasks import (
+    set_dag_run_state_to_failed,
+    set_dag_run_state_to_queued,
+    set_dag_run_state_to_success,
+)
 from airflow.api_fastapi.common.db.common import get_session
 from airflow.api_fastapi.common.router import AirflowRouter
 from airflow.api_fastapi.core_api.openapi.exceptions import 
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.serializers.dag_run import DAGRunResponse
-from airflow.models import DagRun
+from airflow.api_fastapi.core_api.serializers.dag_run import (
+    DAGRunPatchBody,
+    DAGRunPatchStates,
+    DAGRunResponse,
+)
+from airflow.models import DAG, DagRun
 
 dag_run_router = AirflowRouter(tags=["DagRun"], 
prefix="/dags/{dag_id}/dagRuns")
 
@@ -57,3 +66,45 @@ async def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Sessio
         )
 
     session.delete(dag_run)
+
+
+@dag_run_router.patch("/{dag_run_id}", 
responses=create_openapi_http_exception_doc([400, 401, 403, 404]))
+async def patch_dag_run_state(
+    dag_id: str,
+    dag_run_id: str,
+    patch_body: DAGRunPatchBody,
+    session: Annotated[Session, Depends(get_session)],
+    request: Request,
+    update_mask: list[str] | None = Query(None),
+) -> DAGRunResponse:
+    """Modify a DAG Run."""
+    dag_run = session.scalar(select(DagRun).filter_by(dag_id=dag_id, 
run_id=dag_run_id))
+    if dag_run is None:
+        raise HTTPException(
+            404, f"The DagRun with dag_id: `{dag_id}` and run_id: 
`{dag_run_id}` was not found"
+        )
+
+    dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+
+    if not dag:
+        raise HTTPException(404, f"Dag with id {dag_id} was not found")
+
+    if update_mask:
+        if update_mask != ["state"]:
+            raise HTTPException(400, "Only `state` field can be updated 
through the REST API")
+    else:
+        update_mask = ["state"]
+
+    for attr_name in update_mask:
+        if attr_name == "state":
+            state = getattr(patch_body, attr_name)
+            if state == DAGRunPatchStates.SUCCESS:
+                set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, 
commit=True)
+            elif state == DAGRunPatchStates.QUEUED:
+                set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, 
commit=True)
+            else:
+                set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, 
commit=True)
+
+    dag_run = session.get(DagRun, dag_run.id)
+
+    return DAGRunResponse.model_validate(dag_run, from_attributes=True)
diff --git a/airflow/api_fastapi/core_api/serializers/dag_run.py b/airflow/api_fastapi/core_api/serializers/dag_run.py
index 4622fac645..1557690561 100644
--- a/airflow/api_fastapi/core_api/serializers/dag_run.py
+++ b/airflow/api_fastapi/core_api/serializers/dag_run.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime
+from enum import Enum
 
 from pydantic import BaseModel, Field
 
@@ -25,6 +26,20 @@ from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 
+class DAGRunPatchStates(str, Enum):
+    """Enum for DAG Run states when updating a DAG Run."""
+
+    QUEUED = DagRunState.QUEUED
+    SUCCESS = DagRunState.SUCCESS
+    FAILED = DagRunState.FAILED
+
+
+class DAGRunPatchBody(BaseModel):
+    """DAG Run Serializer for PATCH requests."""
+
+    state: DAGRunPatchStates
+
+
 class DAGRunResponse(BaseModel):
     """DAG Run serializer for responses."""
 
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 959b476718..b8c83b1525 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -441,6 +441,9 @@ export type DagServicePatchDagMutationResult = Awaited<
 export type VariableServicePatchVariableMutationResult = Awaited<
   ReturnType<typeof VariableService.patchVariable>
 >;
+export type DagRunServicePatchDagRunStateMutationResult = Awaited<
+  ReturnType<typeof DagRunService.patchDagRunState>
+>;
 export type PoolServicePatchPoolMutationResult = Awaited<
   ReturnType<typeof PoolService.patchPool>
 >;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index a3aed2e793..a0d6a65853 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -23,6 +23,7 @@ import {
 } from "../requests/services.gen";
 import {
   DAGPatchBody,
+  DAGRunPatchBody,
   DagRunState,
   PoolPatchBody,
   PoolPostBody,
@@ -948,6 +949,57 @@ export const useVariableServicePatchVariable = <
       }) as unknown as Promise<TData>,
     ...options,
   });
+/**
+ * Patch Dag Run State
+ * Modify a DAG Run.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+export const useDagRunServicePatchDagRunState = <
+  TData = Common.DagRunServicePatchDagRunStateMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: string;
+        dagRunId: string;
+        requestBody: DAGRunPatchBody;
+        updateMask?: string[];
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: string;
+      dagRunId: string;
+      requestBody: DAGRunPatchBody;
+      updateMask?: string[];
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) =>
+      DagRunService.patchDagRunState({
+        dagId,
+        dagRunId,
+        requestBody,
+        updateMask,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
 /**
  * Patch Pool
  * Update a Pool.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index b1a5b267e1..3982407c5f 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -887,6 +887,25 @@ export const $DAGResponse = {
   description: "DAG serializer for responses.",
 } as const;
 
+export const $DAGRunPatchBody = {
+  properties: {
+    state: {
+      $ref: "#/components/schemas/DAGRunPatchStates",
+    },
+  },
+  type: "object",
+  required: ["state"],
+  title: "DAGRunPatchBody",
+  description: "DAG Run Serializer for PATCH requests.",
+} as const;
+
+export const $DAGRunPatchStates = {
+  type: "string",
+  enum: ["queued", "success", "failed"],
+  title: "DAGRunPatchStates",
+  description: "Enum for DAG Run states when updating a DAG Run.",
+} as const;
+
 export const $DAGRunResponse = {
   properties: {
     run_id: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index 4db1e052a2..56207631d1 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -43,6 +43,8 @@ import type {
   GetDagRunResponse,
   DeleteDagRunData,
   DeleteDagRunResponse,
+  PatchDagRunStateData,
+  PatchDagRunStateResponse,
   GetHealthResponse,
   DeletePoolData,
   DeletePoolResponse,
@@ -672,6 +674,42 @@ export class DagRunService {
       },
     });
   }
+
+  /**
+   * Patch Dag Run State
+   * Modify a DAG Run.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.dagRunId
+   * @param data.requestBody
+   * @param data.updateMask
+   * @returns DAGRunResponse Successful Response
+   * @throws ApiError
+   */
+  public static patchDagRunState(
+    data: PatchDagRunStateData,
+  ): CancelablePromise<PatchDagRunStateResponse> {
+    return __request(OpenAPI, {
+      method: "PATCH",
+      url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}",
+      path: {
+        dag_id: data.dagId,
+        dag_run_id: data.dagRunId,
+      },
+      query: {
+        update_mask: data.updateMask,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        400: "Bad Request",
+        401: "Unauthorized",
+        403: "Forbidden",
+        404: "Not Found",
+        422: "Validation Error",
+      },
+    });
+  }
 }
 
 export class MonitorService {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index 288916b392..a3b1d8e6be 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -154,6 +154,18 @@ export type DAGResponse = {
   readonly file_token: string;
 };
 
+/**
+ * DAG Run Serializer for PATCH requests.
+ */
+export type DAGRunPatchBody = {
+  state: DAGRunPatchStates;
+};
+
+/**
+ * Enum for DAG Run states when updating a DAG Run.
+ */
+export type DAGRunPatchStates = "queued" | "success" | "failed";
+
 /**
  * DAG Run serializer for responses.
  */
@@ -680,6 +692,15 @@ export type DeleteDagRunData = {
 
 export type DeleteDagRunResponse = void;
 
+export type PatchDagRunStateData = {
+  dagId: string;
+  dagRunId: string;
+  requestBody: DAGRunPatchBody;
+  updateMask?: Array<string> | null;
+};
+
+export type PatchDagRunStateResponse = DAGRunResponse;
+
 export type GetHealthResponse = HealthInfoSchema;
 
 export type DeletePoolData = {
@@ -1236,6 +1257,35 @@ export type $OpenApiTs = {
         422: HTTPValidationError;
       };
     };
+    patch: {
+      req: PatchDagRunStateData;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: DAGRunResponse;
+        /**
+         * Bad Request
+         */
+        400: HTTPExceptionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Not Found
+         */
+        404: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
   };
   "/public/monitor/health": {
     get: {
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 554bc73eba..dfd48af2fa 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -138,6 +138,56 @@ class TestGetDagRun:
         assert body["detail"] == "The DagRun with dag_id: `test_dag1` and 
run_id: `invalid` was not found"
 
 
+class TestPatchDagRun:
+    @pytest.mark.parametrize(
+        "dag_id, run_id, state, response_state",
+        [
+            (DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
+            (DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
+            (DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
+        ],
+    )
+    def test_patch_dag_run(self, test_client, dag_id, run_id, state, 
response_state):
+        response = 
test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": 
state})
+        assert response.status_code == 200
+        body = response.json()
+        assert body["dag_id"] == dag_id
+        assert body["run_id"] == run_id
+        assert body["state"] == response_state
+
+    @pytest.mark.parametrize(
+        "query_params,patch_body, expected_status_code",
+        [
+            ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200),
+            ({}, {"state": DagRunState.SUCCESS}, 200),
+            ({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400),
+        ],
+    )
+    def test_patch_dag_run_with_update_mask(
+        self, test_client, query_params, patch_body, expected_status_code
+    ):
+        response = test_client.patch(
+            f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", 
params=query_params, json=patch_body
+        )
+        assert response.status_code == expected_status_code
+
+    def test_patch_dag_run_not_found(self, test_client):
+        response = test_client.patch(
+            f"/public/dags/{DAG1_ID}/dagRuns/invalid", json={"state": 
DagRunState.SUCCESS}
+        )
+        assert response.status_code == 404
+        body = response.json()
+        assert body["detail"] == "The DagRun with dag_id: `test_dag1` and 
run_id: `invalid` was not found"
+
+    def test_patch_dag_run_bad_request(self, test_client):
+        response = test_client.patch(
+            f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state": 
"running"}
+        )
+        assert response.status_code == 422
+        body = response.json()
+        assert body["detail"][0]["msg"] == "Input should be 'queued', 
'success' or 'failed'"
+
+
 class TestDeleteDagRun:
     def test_delete_dag_run(self, test_client):
         response = 
test_client.delete(f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}")

Reply via email to