This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1c61d13b45 Migrate patch dag to FastAPI API (#42469)
1c61d13b45 is described below

commit 1c61d13b454178b156bdce497d2949ff2ba3f7da
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Thu Sep 26 18:16:12 2024 +0800

    Migrate patch dag to FastAPI API (#42469)
---
 airflow/api_connexion/endpoints/dag_endpoint.py |  1 +
 airflow/api_fastapi/openapi/v1-generated.yaml   | 59 +++++++++++++++++++++++--
 airflow/api_fastapi/serializers/dags.py         | 10 ++++-
 airflow/api_fastapi/views/public/dags.py        | 34 ++++++++++++--
 airflow/ui/openapi-gen/queries/common.ts        |  3 ++
 airflow/ui/openapi-gen/queries/queries.ts       | 55 ++++++++++++++++++++++-
 airflow/ui/openapi-gen/requests/schemas.gen.ts  | 19 ++++++--
 airflow/ui/openapi-gen/requests/services.gen.ts | 32 ++++++++++++++
 airflow/ui/openapi-gen/requests/types.gen.ts    | 34 +++++++++++++-
 airflow/ui/src/pages/DagsList.tsx               |  4 +-
 tests/api_fastapi/views/public/test_dags.py     | 20 +++++++++
 11 files changed, 254 insertions(+), 17 deletions(-)

diff --git a/airflow/api_connexion/endpoints/dag_endpoint.py b/airflow/api_connexion/endpoints/dag_endpoint.py
index 08d36f9978..6fca5ae7c9 100644
--- a/airflow/api_connexion/endpoints/dag_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_endpoint.py
@@ -141,6 +141,7 @@ def get_dags(
         raise BadRequest("DAGCollectionSchema error", detail=str(e))
 
 
+@mark_fastapi_migration_done
 @security.requires_access_dag("PUT")
 @action_logging
 @provide_session
diff --git a/airflow/api_fastapi/openapi/v1-generated.yaml b/airflow/api_fastapi/openapi/v1-generated.yaml
index b0037b372b..6d77056d05 100644
--- a/airflow/api_fastapi/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/openapi/v1-generated.yaml
@@ -123,13 +123,56 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /public/dags/{dag_id}:
+    patch:
+      tags:
+      - DAG
+      summary: Patch Dag
+      description: Update the specific DAG.
+      operationId: patch_dag_public_dags__dag_id__patch
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: update_mask
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: array
+            items:
+              type: string
+          - type: 'null'
+          title: Update Mask
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/DAGPatchBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/DAGResponse'
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
 components:
   schemas:
     DAGCollectionResponse:
       properties:
         dags:
           items:
-            $ref: '#/components/schemas/DAGModelResponse'
+            $ref: '#/components/schemas/DAGResponse'
           type: array
           title: Dags
         total_entries:
@@ -141,7 +184,17 @@ components:
       - total_entries
       title: DAGCollectionResponse
       description: DAG Collection serializer for responses.
-    DAGModelResponse:
+    DAGPatchBody:
+      properties:
+        is_paused:
+          type: boolean
+          title: Is Paused
+      type: object
+      required:
+      - is_paused
+      title: DAGPatchBody
+      description: Dag Serializer for updatable body.
+    DAGResponse:
       properties:
         dag_id:
           type: string
@@ -292,7 +345,7 @@ components:
       - next_dagrun_create_after
       - owners
       - file_token
-      title: DAGModelResponse
+      title: DAGResponse
       description: DAG serializer for responses.
     DagTagPydantic:
       properties:
diff --git a/airflow/api_fastapi/serializers/dags.py b/airflow/api_fastapi/serializers/dags.py
index 264f549e29..59b47bdef9 100644
--- a/airflow/api_fastapi/serializers/dags.py
+++ b/airflow/api_fastapi/serializers/dags.py
@@ -31,7 +31,7 @@ from airflow.configuration import conf
 from airflow.serialization.pydantic.dag import DagTagPydantic
 
 
-class DAGModelResponse(BaseModel):
+class DAGResponse(BaseModel):
     """DAG serializer for responses."""
 
     dag_id: str
@@ -82,8 +82,14 @@ class DAGModelResponse(BaseModel):
         return serializer.dumps(self.fileloc)
 
 
+class DAGPatchBody(BaseModel):
+    """Dag Serializer for updatable body."""
+
+    is_paused: bool
+
+
 class DAGCollectionResponse(BaseModel):
     """DAG Collection serializer for responses."""
 
-    dags: list[DAGModelResponse]
+    dags: list[DAGResponse]
     total_entries: int
diff --git a/airflow/api_fastapi/views/public/dags.py b/airflow/api_fastapi/views/public/dags.py
index a1957d3073..433e5ef862 100644
--- a/airflow/api_fastapi/views/public/dags.py
+++ b/airflow/api_fastapi/views/public/dags.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from fastapi import APIRouter, Depends, HTTPException
+from fastapi import APIRouter, Depends, HTTPException, Query
 from sqlalchemy import select
 from sqlalchemy.orm import Session
 from typing_extensions import Annotated
@@ -34,7 +34,7 @@ from airflow.api_fastapi.parameters import (
     QueryTagsFilter,
     SortParam,
 )
-from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGModelResponse
+from airflow.api_fastapi.serializers.dags import DAGCollectionResponse, DAGPatchBody, DAGResponse
 from airflow.models import DagModel
 from airflow.utils.db import get_query_count
 
@@ -43,7 +43,6 @@ dags_router = APIRouter(tags=["DAG"])
 
 @dags_router.get("/dags")
 async def get_dags(
-    *,
     limit: QueryLimit,
     offset: QueryOffset,
     tags: QueryTagsFilter,
@@ -74,8 +73,35 @@ async def get_dags(
 
     try:
         return DAGCollectionResponse(
-            dags=[DAGModelResponse.model_validate(dag, from_attributes=True) for dag in dags],
+            dags=[DAGResponse.model_validate(dag, from_attributes=True) for dag in dags],
             total_entries=total_entries,
         )
     except ValueError as e:
         raise HTTPException(400, f"DAGCollectionSchema error: {str(e)}")
+
+
+@dags_router.patch("/dags/{dag_id}")
+async def patch_dag(
+    dag_id: str,
+    patch_body: DAGPatchBody,
+    session: Annotated[Session, Depends(get_session)],
+    update_mask: list[str] | None = Query(None),
+) -> DAGResponse:
+    """Update the specific DAG."""
+    dag = session.get(DagModel, dag_id)
+
+    if dag is None:
+        raise HTTPException(404, f"Dag with id: {dag_id} was not found")
+
+    if update_mask:
+        if update_mask != ["is_paused"]:
+            raise HTTPException(400, "Only `is_paused` field can be updated through the REST API")
+
+    else:
+        update_mask = ["is_paused"]
+
+    for attr_name in update_mask:
+        attr_value = getattr(patch_body, attr_name)
+        setattr(dag, attr_name, attr_value)
+
+    return DAGResponse.model_validate(dag, from_attributes=True)
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 143ec83c55..2818b48a33 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -72,3 +72,6 @@ export const UseDagServiceGetDagsPublicDagsGetKeyFn = (
     },
   ]),
 ];
+export type DagServicePatchDagPublicDagsDagIdPatchMutationResult = Awaited<
+  ReturnType<typeof DagService.patchDagPublicDagsDagIdPatch>
+>;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 9dce528f2a..2a0c6b6821 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -1,7 +1,13 @@
 // generated with @7nohe/[email protected]
-import { useQuery, UseQueryOptions } from "@tanstack/react-query";
+import {
+  useMutation,
+  UseMutationOptions,
+  useQuery,
+  UseQueryOptions,
+} from "@tanstack/react-query";
 
 import { DagService, DatasetService } from "../requests/services.gen";
+import { DAGPatchBody } from "../requests/types.gen";
 import * as Common from "./common";
 
 /**
@@ -110,3 +116,50 @@ export const useDagServiceGetDagsPublicDagsGet = <
       }) as TData,
     ...options,
   });
+/**
+ * Patch Dag
+ * Update the specific DAG.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns DAGResponse Successful Response
+ * @throws ApiError
+ */
+export const useDagServicePatchDagPublicDagsDagIdPatch = <
+  TData = Common.DagServicePatchDagPublicDagsDagIdPatchMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: string;
+        requestBody: DAGPatchBody;
+        updateMask?: string[];
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: string;
+      requestBody: DAGPatchBody;
+      updateMask?: string[];
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, requestBody, updateMask }) =>
+      DagService.patchDagPublicDagsDagIdPatch({
+        dagId,
+        requestBody,
+        updateMask,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 64cddab30b..83d3670507 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -4,7 +4,7 @@ export const $DAGCollectionResponse = {
   properties: {
     dags: {
       items: {
-        $ref: "#/components/schemas/DAGModelResponse",
+        $ref: "#/components/schemas/DAGResponse",
       },
       type: "array",
       title: "Dags",
@@ -20,7 +20,20 @@ export const $DAGCollectionResponse = {
   description: "DAG Collection serializer for responses.",
 } as const;
 
-export const $DAGModelResponse = {
+export const $DAGPatchBody = {
+  properties: {
+    is_paused: {
+      type: "boolean",
+      title: "Is Paused",
+    },
+  },
+  type: "object",
+  required: ["is_paused"],
+  title: "DAGPatchBody",
+  description: "Dag Serializer for updatable body.",
+} as const;
+
+export const $DAGResponse = {
   properties: {
     dag_id: {
       type: "string",
@@ -271,7 +284,7 @@ export const $DAGModelResponse = {
     "owners",
     "file_token",
   ],
-  title: "DAGModelResponse",
+  title: "DAGResponse",
   description: "DAG serializer for responses.",
 } as const;
 
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index e0786e9137..a4c36d5990 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -7,6 +7,8 @@ import type {
   NextRunDatasetsUiNextRunDatasetsDagIdGetResponse,
   GetDagsPublicDagsGetData,
   GetDagsPublicDagsGetResponse,
+  PatchDagPublicDagsDagIdPatchData,
+  PatchDagPublicDagsDagIdPatchResponse,
 } from "./types.gen";
 
 export class DatasetService {
@@ -72,4 +74,34 @@ export class DagService {
       },
     });
   }
+
+  /**
+   * Patch Dag
+   * Update the specific DAG.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.requestBody
+   * @param data.updateMask
+   * @returns DAGResponse Successful Response
+   * @throws ApiError
+   */
+  public static patchDagPublicDagsDagIdPatch(
+    data: PatchDagPublicDagsDagIdPatchData,
+  ): CancelablePromise<PatchDagPublicDagsDagIdPatchResponse> {
+    return __request(OpenAPI, {
+      method: "PATCH",
+      url: "/public/dags/{dag_id}",
+      path: {
+        dag_id: data.dagId,
+      },
+      query: {
+        update_mask: data.updateMask,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        422: "Validation Error",
+      },
+    });
+  }
 }
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index 917dca6626..2f6bc263d4 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -4,14 +4,21 @@
  * DAG Collection serializer for responses.
  */
 export type DAGCollectionResponse = {
-  dags: Array<DAGModelResponse>;
+  dags: Array<DAGResponse>;
   total_entries: number;
 };
 
+/**
+ * Dag Serializer for updatable body.
+ */
+export type DAGPatchBody = {
+  is_paused: boolean;
+};
+
 /**
  * DAG serializer for responses.
  */
-export type DAGModelResponse = {
+export type DAGResponse = {
   dag_id: string;
   dag_display_name: string;
   is_paused: boolean;
@@ -83,6 +90,14 @@ export type GetDagsPublicDagsGetData = {
 
 export type GetDagsPublicDagsGetResponse = DAGCollectionResponse;
 
+export type PatchDagPublicDagsDagIdPatchData = {
+  dagId: string;
+  requestBody: DAGPatchBody;
+  updateMask?: Array<string> | null;
+};
+
+export type PatchDagPublicDagsDagIdPatchResponse = DAGResponse;
+
 export type $OpenApiTs = {
   "/ui/next_run_datasets/{dag_id}": {
     get: {
@@ -116,4 +131,19 @@ export type $OpenApiTs = {
       };
     };
   };
+  "/public/dags/{dag_id}": {
+    patch: {
+      req: PatchDagPublicDagsDagIdPatchData;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: DAGResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
+  };
 };
diff --git a/airflow/ui/src/pages/DagsList.tsx b/airflow/ui/src/pages/DagsList.tsx
index e93f281c50..fe764f117e 100644
--- a/airflow/ui/src/pages/DagsList.tsx
+++ b/airflow/ui/src/pages/DagsList.tsx
@@ -31,7 +31,7 @@ import { type ChangeEventHandler, useCallback } from "react";
 import { useSearchParams } from "react-router-dom";
 
 import { useDagServiceGetDagsPublicDagsGet } from "openapi/queries";
-import type { DAGModelResponse } from "openapi/requests/types.gen";
+import type { DAGResponse } from "openapi/requests/types.gen";
 
 import { DataTable } from "../components/DataTable";
 import { useTableURLState } from "../components/DataTable/useTableUrlState";
@@ -39,7 +39,7 @@ import { QuickFilterButton } from "../components/QuickFilterButton";
 import { SearchBar } from "../components/SearchBar";
 import { pluralize } from "../utils/pluralize";
 
-const columns: Array<ColumnDef<DAGModelResponse>> = [
+const columns: Array<ColumnDef<DAGResponse>> = [
   {
     accessorKey: "dag_id",
     cell: ({ row }) => row.original.dag_display_name,
diff --git a/tests/api_fastapi/views/public/test_dags.py b/tests/api_fastapi/views/public/test_dags.py
index dfba5437a8..b508a14483 100644
--- a/tests/api_fastapi/views/public/test_dags.py
+++ b/tests/api_fastapi/views/public/test_dags.py
@@ -115,3 +115,23 @@ def test_get_dags(test_client, query_params, expected_total_entries, expected_id
 
     assert body["total_entries"] == expected_total_entries
     assert [dag["dag_id"] for dag in body["dags"]] == expected_ids
+
+
+@pytest.mark.parametrize(
+    "query_params, dag_id, body, expected_status_code, expected_is_paused",
+    [
+        ({}, "fake_dag_id", {"is_paused": True}, 404, None),
+        ({"update_mask": ["field_1", "is_paused"]}, DAG1_ID, {"is_paused": True}, 400, None),
+        ({}, DAG1_ID, {"is_paused": True}, 200, True),
+        ({}, DAG1_ID, {"is_paused": False}, 200, False),
+        ({"update_mask": ["is_paused"]}, DAG1_ID, {"is_paused": True}, 200, True),
+        ({"update_mask": ["is_paused"]}, DAG1_ID, {"is_paused": False}, 200, False),
+    ],
+)
+def test_patch_dag(test_client, query_params, dag_id, body, expected_status_code, expected_is_paused):
+    response = test_client.patch(f"/public/dags/{dag_id}", json=body, params=query_params)
+
+    assert response.status_code == expected_status_code
+    if expected_status_code == 200:
+        body = response.json()
+        assert body["is_paused"] == expected_is_paused

Reply via email to