This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f57db717a31 AIP-84 Add ability to update dag run note in PATCH dag_run endpoint (#43508)
f57db717a31 is described below

commit f57db717a31b2735699504a6f087bf94ef82fc66
Author: Kalyan R <[email protected]>
AuthorDate: Fri Nov 8 18:38:57 2024 +0530

    AIP-84 Add ability to update dag run note in PATCH dag_run endpoint (#43508)
    
    * include dag_run_note in update_mask
    
    * add dag run note
    
    * state can be none
    
    * add test
    
    * Fix tests
    
    * handle edge cases
    
    * add tests
    
    * remove joinedload
    
    * fix update_mask checks
    
    * fix tests
    
    * fix
    
    * remove async
    
    * undo async
    
    * fix
    
    ---------
    
    Co-authored-by: pierrejeambrun <[email protected]>
---
 .../api_fastapi/core_api/openapi/v1-generated.yaml | 16 ++--
 .../api_fastapi/core_api/routes/public/dag_run.py  | 34 +++++----
 .../api_fastapi/core_api/serializers/dag_run.py    |  3 +-
 airflow/ui/openapi-gen/queries/common.ts           |  4 +-
 airflow/ui/openapi-gen/queries/queries.ts          |  8 +-
 airflow/ui/openapi-gen/requests/schemas.gen.ts     | 22 +++++-
 airflow/ui/openapi-gen/requests/services.gen.ts    | 12 +--
 airflow/ui/openapi-gen/requests/types.gen.ts       |  9 ++-
 .../core_api/routes/public/test_dag_run.py         | 88 ++++++++++++++++++----
 9 files changed, 142 insertions(+), 54 deletions(-)

diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml 
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 0e9221444af..9b52f3bc003 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1318,9 +1318,9 @@ paths:
     patch:
       tags:
       - DagRun
-      summary: Patch Dag Run State
+      summary: Patch Dag Run
       description: Modify a DAG Run.
-      operationId: patch_dag_run_state
+      operationId: patch_dag_run
       parameters:
       - name: dag_id
         in: path
@@ -3694,10 +3694,16 @@ components:
     DAGRunPatchBody:
       properties:
         state:
-          $ref: '#/components/schemas/DAGRunPatchStates'
+          anyOf:
+          - $ref: '#/components/schemas/DAGRunPatchStates'
+          - type: 'null'
+        note:
+          anyOf:
+          - type: string
+            maxLength: 1000
+          - type: 'null'
+          title: Note
       type: object
-      required:
-      - state
       title: DAGRunPatchBody
       description: DAG Run Serializer for PATCH requests.
     DAGRunPatchStates:
diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py 
b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index b05ed2ba113..7778d7778fa 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -99,7 +99,7 @@ def delete_dag_run(dag_id: str, dag_run_id: str, session: 
Annotated[Session, Dep
         ]
     ),
 )
-def patch_dag_run_state(
+def patch_dag_run(
     dag_id: str,
     dag_run_id: str,
     patch_body: DAGRunPatchBody,
@@ -121,23 +121,29 @@ def patch_dag_run_state(
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"Dag with id {dag_id} 
was not found")
 
     if update_mask:
-        if update_mask != ["state"]:
-            raise HTTPException(
-                status.HTTP_400_BAD_REQUEST, "Only `state` field can be 
updated through the REST API"
-            )
+        data = patch_body.model_dump(include=set(update_mask))
     else:
-        update_mask = ["state"]
+        data = patch_body.model_dump()
 
-    for attr_name in update_mask:
+    for attr_name, attr_value in data.items():
         if attr_name == "state":
-            state = getattr(patch_body, attr_name)
-            if state == DAGRunPatchStates.SUCCESS:
-                set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, 
commit=True)
-            elif state == DAGRunPatchStates.QUEUED:
-                set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, 
commit=True)
+            attr_value = getattr(patch_body, "state")
+            if attr_value == DAGRunPatchStates.SUCCESS:
+                set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, 
commit=True, session=session)
+            elif attr_value == DAGRunPatchStates.QUEUED:
+                set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, 
commit=True, session=session)
+            elif attr_value == DAGRunPatchStates.FAILED:
+                set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, 
commit=True, session=session)
+        elif attr_name == "note":
+            # Once Authentication is implemented in this FastAPI app,
+            # user id will be added when updating dag run note
+            # Refer to https://github.com/apache/airflow/issues/43534
+            dag_run = session.get(DagRun, dag_run.id)
+            if dag_run.dag_run_note is None:
+                dag_run.note = (attr_value, None)
             else:
-                set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id, 
commit=True)
+                dag_run.dag_run_note.content = attr_value
 
-    session.refresh(dag_run)
+    dag_run = session.get(DagRun, dag_run.id)
 
     return DAGRunResponse.model_validate(dag_run, from_attributes=True)
diff --git a/airflow/api_fastapi/core_api/serializers/dag_run.py 
b/airflow/api_fastapi/core_api/serializers/dag_run.py
index 15576905611..759c4399fbd 100644
--- a/airflow/api_fastapi/core_api/serializers/dag_run.py
+++ b/airflow/api_fastapi/core_api/serializers/dag_run.py
@@ -37,7 +37,8 @@ class DAGRunPatchStates(str, Enum):
 class DAGRunPatchBody(BaseModel):
     """DAG Run Serializer for PATCH requests."""
 
-    state: DAGRunPatchStates
+    state: DAGRunPatchStates | None = None
+    note: str | None = Field(None, max_length=1000)
 
 
 class DAGRunResponse(BaseModel):
diff --git a/airflow/ui/openapi-gen/queries/common.ts 
b/airflow/ui/openapi-gen/queries/common.ts
index 953bb16291f..cfbb945b4e5 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -935,8 +935,8 @@ export type DagServicePatchDagMutationResult = Awaited<
 export type ConnectionServicePatchConnectionMutationResult = Awaited<
   ReturnType<typeof ConnectionService.patchConnection>
 >;
-export type DagRunServicePatchDagRunStateMutationResult = Awaited<
-  ReturnType<typeof DagRunService.patchDagRunState>
+export type DagRunServicePatchDagRunMutationResult = Awaited<
+  ReturnType<typeof DagRunService.patchDagRun>
 >;
 export type PoolServicePatchPoolMutationResult = Awaited<
   ReturnType<typeof PoolService.patchPool>
diff --git a/airflow/ui/openapi-gen/queries/queries.ts 
b/airflow/ui/openapi-gen/queries/queries.ts
index c6f4bd09dd6..3a8d508a8c4 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -1912,7 +1912,7 @@ export const useConnectionServicePatchConnection = <
     ...options,
   });
 /**
- * Patch Dag Run State
+ * Patch Dag Run
  * Modify a DAG Run.
  * @param data The data for the request.
  * @param data.dagId
@@ -1922,8 +1922,8 @@ export const useConnectionServicePatchConnection = <
  * @returns DAGRunResponse Successful Response
  * @throws ApiError
  */
-export const useDagRunServicePatchDagRunState = <
-  TData = Common.DagRunServicePatchDagRunStateMutationResult,
+export const useDagRunServicePatchDagRun = <
+  TData = Common.DagRunServicePatchDagRunMutationResult,
   TError = unknown,
   TContext = unknown,
 >(
@@ -1954,7 +1954,7 @@ export const useDagRunServicePatchDagRunState = <
     TContext
   >({
     mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) =>
-      DagRunService.patchDagRunState({
+      DagRunService.patchDagRun({
         dagId,
         dagRunId,
         requestBody,
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts 
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 44bd279a162..b8c43b7ac20 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -981,11 +981,29 @@ export const $DAGResponse = {
 export const $DAGRunPatchBody = {
   properties: {
     state: {
-      $ref: "#/components/schemas/DAGRunPatchStates",
+      anyOf: [
+        {
+          $ref: "#/components/schemas/DAGRunPatchStates",
+        },
+        {
+          type: "null",
+        },
+      ],
+    },
+    note: {
+      anyOf: [
+        {
+          type: "string",
+          maxLength: 1000,
+        },
+        {
+          type: "null",
+        },
+      ],
+      title: "Note",
     },
   },
   type: "object",
-  required: ["state"],
   title: "DAGRunPatchBody",
   description: "DAG Run Serializer for PATCH requests.",
 } as const;
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts 
b/airflow/ui/openapi-gen/requests/services.gen.ts
index 45be069d310..8a6cd3e4f70 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -49,8 +49,8 @@ import type {
   GetDagRunResponse,
   DeleteDagRunData,
   DeleteDagRunResponse,
-  PatchDagRunStateData,
-  PatchDagRunStateResponse,
+  PatchDagRunData,
+  PatchDagRunResponse,
   GetDagSourceData,
   GetDagSourceResponse,
   GetEventLogData,
@@ -794,7 +794,7 @@ export class DagRunService {
   }
 
   /**
-   * Patch Dag Run State
+   * Patch Dag Run
    * Modify a DAG Run.
    * @param data The data for the request.
    * @param data.dagId
@@ -804,9 +804,9 @@ export class DagRunService {
    * @returns DAGRunResponse Successful Response
    * @throws ApiError
    */
-  public static patchDagRunState(
-    data: PatchDagRunStateData,
-  ): CancelablePromise<PatchDagRunStateResponse> {
+  public static patchDagRun(
+    data: PatchDagRunData,
+  ): CancelablePromise<PatchDagRunResponse> {
     return __request(OpenAPI, {
       method: "PATCH",
       url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}",
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts 
b/airflow/ui/openapi-gen/requests/types.gen.ts
index 8dc0a3188ca..08f174e9b39 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -184,7 +184,8 @@ export type DAGResponse = {
  * DAG Run Serializer for PATCH requests.
  */
 export type DAGRunPatchBody = {
-  state: DAGRunPatchStates;
+  state?: DAGRunPatchStates | null;
+  note?: string | null;
 };
 
 /**
@@ -932,14 +933,14 @@ export type DeleteDagRunData = {
 
 export type DeleteDagRunResponse = void;
 
-export type PatchDagRunStateData = {
+export type PatchDagRunData = {
   dagId: string;
   dagRunId: string;
   requestBody: DAGRunPatchBody;
   updateMask?: Array<string> | null;
 };
 
-export type PatchDagRunStateResponse = DAGRunResponse;
+export type PatchDagRunResponse = DAGRunResponse;
 
 export type GetDagSourceData = {
   accept?: string;
@@ -1775,7 +1776,7 @@ export type $OpenApiTs = {
       };
     };
     patch: {
-      req: PatchDagRunStateData;
+      req: PatchDagRunData;
       res: {
         /**
          * Successful Response
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py 
b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 6c48cece798..64c3512e88b 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -50,7 +50,7 @@ DAG2_RUN1_TRIGGERED_BY = DagRunTriggeredByType.CLI
 DAG2_RUN2_TRIGGERED_BY = DagRunTriggeredByType.REST_API
 START_DATE = datetime(2024, 6, 15, 0, 0, tzinfo=timezone.utc)
 EXECUTION_DATE = datetime(2024, 6, 16, 0, 0, tzinfo=timezone.utc)
-DAG1_NOTE = "test_note"
+DAG1_RUN1_NOTE = "test_note"
 
 
 @pytest.fixture(autouse=True)
@@ -66,13 +66,13 @@ def setup(dag_maker, session=None):
         start_date=START_DATE,
     ):
         EmptyOperator(task_id="task_1")
-    dag1 = dag_maker.create_dagrun(
+    dag_run1 = dag_maker.create_dagrun(
         run_id=DAG1_RUN1_ID,
         state=DAG1_RUN1_STATE,
         run_type=DAG1_RUN1_RUN_TYPE,
         triggered_by=DAG1_RUN1_TRIGGERED_BY,
     )
-    dag1.note = (DAG1_NOTE, 1)
+    dag_run1.note = (DAG1_RUN1_NOTE, 1)
 
     dag_maker.create_dagrun(
         run_id=DAG1_RUN2_ID,
@@ -114,7 +114,14 @@ class TestGetDagRun:
     @pytest.mark.parametrize(
         "dag_id, run_id, state, run_type, triggered_by, dag_run_note",
         [
-            (DAG1_ID, DAG1_RUN1_ID, DAG1_RUN1_STATE, DAG1_RUN1_RUN_TYPE, 
DAG1_RUN1_TRIGGERED_BY, DAG1_NOTE),
+            (
+                DAG1_ID,
+                DAG1_RUN1_ID,
+                DAG1_RUN1_STATE,
+                DAG1_RUN1_RUN_TYPE,
+                DAG1_RUN1_TRIGGERED_BY,
+                DAG1_RUN1_NOTE,
+            ),
             (DAG1_ID, DAG1_RUN2_ID, DAG1_RUN2_STATE, DAG1_RUN2_RUN_TYPE, 
DAG1_RUN2_TRIGGERED_BY, None),
             (DAG2_ID, DAG2_RUN1_ID, DAG2_RUN1_STATE, DAG2_RUN1_RUN_TYPE, 
DAG2_RUN1_TRIGGERED_BY, None),
             (DAG2_ID, DAG2_RUN2_ID, DAG2_RUN2_STATE, DAG2_RUN2_RUN_TYPE, 
DAG2_RUN2_TRIGGERED_BY, None),
@@ -140,36 +147,85 @@ class TestGetDagRun:
 
 class TestPatchDagRun:
     @pytest.mark.parametrize(
-        "dag_id, run_id, state, response_state",
+        "dag_id, run_id, patch_body, response_body",
         [
-            (DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
-            (DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
-            (DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
+            (
+                DAG1_ID,
+                DAG1_RUN1_ID,
+                {"state": DagRunState.FAILED, "note": "new_note2"},
+                {"state": DagRunState.FAILED, "note": "new_note2"},
+            ),
+            (
+                DAG1_ID,
+                DAG1_RUN2_ID,
+                {"state": DagRunState.SUCCESS},
+                {"state": DagRunState.SUCCESS, "note": None},
+            ),
+            (
+                DAG2_ID,
+                DAG2_RUN1_ID,
+                {"state": DagRunState.QUEUED},
+                {"state": DagRunState.QUEUED, "note": None},
+            ),
+            (
+                DAG1_ID,
+                DAG1_RUN1_ID,
+                {"note": "updated note"},
+                {"state": DagRunState.SUCCESS, "note": "updated note"},
+            ),
+            (
+                DAG1_ID,
+                DAG1_RUN2_ID,
+                {"note": "new note", "state": DagRunState.FAILED},
+                {"state": DagRunState.FAILED, "note": "new note"},
+            ),
+            (DAG1_ID, DAG1_RUN2_ID, {"note": None}, {"state": 
DagRunState.FAILED, "note": None}),
         ],
     )
-    def test_patch_dag_run(self, test_client, dag_id, run_id, state, 
response_state):
-        response = 
test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state": 
state})
+    def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, 
response_body):
+        response = 
test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json=patch_body)
         assert response.status_code == 200
         body = response.json()
         assert body["dag_id"] == dag_id
         assert body["run_id"] == run_id
-        assert body["state"] == response_state
+        assert body.get("state") == response_body.get("state")
+        assert body.get("note") == response_body.get("note")
 
     @pytest.mark.parametrize(
-        "query_params, patch_body, expected_status_code",
+        "query_params, patch_body, response_body, expected_status_code",
         [
-            ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200),
-            ({}, {"state": DagRunState.SUCCESS}, 200),
-            ({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400),
+            ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 
{"state": "success"}, 200),
+            (
+                {"update_mask": ["note"]},
+                {"state": DagRunState.FAILED, "note": "new_note1"},
+                {"note": "new_note1", "state": "success"},
+                200,
+            ),
+            (
+                {},
+                {"state": DagRunState.FAILED, "note": "new_note2"},
+                {"note": "new_note2", "state": "failed"},
+                200,
+            ),
+            ({"update_mask": ["note"]}, {}, {"state": "success", "note": 
None}, 200),
+            (
+                {"update_mask": ["random"]},
+                {"state": DagRunState.FAILED},
+                {"state": "success", "note": "test_note"},
+                200,
+            ),
         ],
     )
     def test_patch_dag_run_with_update_mask(
-        self, test_client, query_params, patch_body, expected_status_code
+        self, test_client, query_params, patch_body, response_body, 
expected_status_code
     ):
         response = test_client.patch(
             f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", 
params=query_params, json=patch_body
         )
+        response_json = response.json()
         assert response.status_code == expected_status_code
+        for key, value in response_body.items():
+            assert response_json.get(key) == value
 
     def test_patch_dag_run_not_found(self, test_client):
         response = test_client.patch(

Reply via email to