This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3e4b344b2f AIP-84 Migrate Modify Dag Run endpoint to FastAPI (#42973)
3e4b344b2f is described below
commit 3e4b344b2fb051dc49cba3cbfa6e623cde271657
Author: Kalyan R <[email protected]>
AuthorDate: Tue Oct 29 21:13:42 2024 +0530
AIP-84 Migrate Modify Dag Run endpoint to FastAPI (#42973)
* add modify_dag_run
* add tests
* Update airflow/api_fastapi/views/public/dag_run.py
* fix
* Update airflow/api_fastapi/routes/public/dag_run.py
Co-authored-by: Pierre Jeambrun <[email protected]>
* Update airflow/api_fastapi/serializers/dag_run.py
Co-authored-by: Pierre Jeambrun <[email protected]>
* use dagbag
* replace patch with put
* refactor
* use put in tests
* modify to patch
* add update_mask
* refactor update to patch
---------
Co-authored-by: Pierre Jeambrun <[email protected]>
---
.../api_connexion/endpoints/dag_run_endpoint.py | 1 +
.../api_fastapi/core_api/openapi/v1-generated.yaml | 89 ++++++++++++++++++++++
.../api_fastapi/core_api/routes/public/dag_run.py | 57 +++++++++++++-
.../api_fastapi/core_api/serializers/dag_run.py | 15 ++++
airflow/ui/openapi-gen/queries/common.ts | 3 +
airflow/ui/openapi-gen/queries/queries.ts | 52 +++++++++++++
airflow/ui/openapi-gen/requests/schemas.gen.ts | 19 +++++
airflow/ui/openapi-gen/requests/services.gen.ts | 38 +++++++++
airflow/ui/openapi-gen/requests/types.gen.ts | 50 ++++++++++++
.../core_api/routes/public/test_dag_run.py | 50 ++++++++++++
10 files changed, 371 insertions(+), 3 deletions(-)
diff --git a/airflow/api_connexion/endpoints/dag_run_endpoint.py b/airflow/api_connexion/endpoints/dag_run_endpoint.py
index e0d8357561..8ebb2b44e2 100644
--- a/airflow/api_connexion/endpoints/dag_run_endpoint.py
+++ b/airflow/api_connexion/endpoints/dag_run_endpoint.py
@@ -373,6 +373,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
raise AlreadyExists(detail=f"DAGRun with DAG ID: '{dag_id}' and DAGRun ID:
'{run_id}' already exists")
+@mark_fastapi_migration_done
@security.requires_access_dag("PUT", DagAccessEntity.RUN)
@provide_session
@action_logging
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 3d1b1611ad..ae5fc9e117 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1156,6 +1156,78 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+ patch:
+ tags:
+ - DagRun
+ summary: Patch Dag Run State
+ description: Modify a DAG Run.
+ operationId: patch_dag_run_state
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Id
+ - name: dag_run_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Run Id
+ - name: update_mask
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: array
+ items:
+ type: string
+ - type: 'null'
+ title: Update Mask
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DAGRunPatchBody'
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/DAGRunResponse'
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/monitor/health:
get:
tags:
@@ -2079,6 +2151,23 @@ components:
- file_token
title: DAGResponse
description: DAG serializer for responses.
+ DAGRunPatchBody:
+ properties:
+ state:
+ $ref: '#/components/schemas/DAGRunPatchStates'
+ type: object
+ required:
+ - state
+ title: DAGRunPatchBody
+ description: DAG Run Serializer for PATCH requests.
+ DAGRunPatchStates:
+ type: string
+ enum:
+ - queued
+ - success
+ - failed
+ title: DAGRunPatchStates
+ description: Enum for DAG Run states when updating a DAG Run.
DAGRunResponse:
properties:
run_id:
diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 035d1b7fd7..02780d6088 100644
--- a/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -17,16 +17,25 @@
from __future__ import annotations
-from fastapi import Depends, HTTPException
+from fastapi import Depends, HTTPException, Query, Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from typing_extensions import Annotated
+from airflow.api.common.mark_tasks import (
+ set_dag_run_state_to_failed,
+ set_dag_run_state_to_queued,
+ set_dag_run_state_to_success,
+)
from airflow.api_fastapi.common.db.common import get_session
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.serializers.dag_run import DAGRunResponse
-from airflow.models import DagRun
+from airflow.api_fastapi.core_api.serializers.dag_run import (
+ DAGRunPatchBody,
+ DAGRunPatchStates,
+ DAGRunResponse,
+)
+from airflow.models import DAG, DagRun
dag_run_router = AirflowRouter(tags=["DagRun"],
prefix="/dags/{dag_id}/dagRuns")
@@ -57,3 +66,45 @@ async def delete_dag_run(dag_id: str, dag_run_id: str, session: Annotated[Sessio
)
session.delete(dag_run)
+
+
+@dag_run_router.patch("/{dag_run_id}",
responses=create_openapi_http_exception_doc([400, 401, 403, 404]))
+async def patch_dag_run_state(
+ dag_id: str,
+ dag_run_id: str,
+ patch_body: DAGRunPatchBody,
+ session: Annotated[Session, Depends(get_session)],
+ request: Request,
+ update_mask: list[str] | None = Query(None),
+) -> DAGRunResponse:
+ """Modify a DAG Run."""
+ dag_run = session.scalar(select(DagRun).filter_by(dag_id=dag_id,
run_id=dag_run_id))
+ if dag_run is None:
+ raise HTTPException(
+ 404, f"The DagRun with dag_id: `{dag_id}` and run_id:
`{dag_run_id}` was not found"
+ )
+
+ dag: DAG = request.app.state.dag_bag.get_dag(dag_id)
+
+ if not dag:
+ raise HTTPException(404, f"Dag with id {dag_id} was not found")
+
+ if update_mask:
+ if update_mask != ["state"]:
+ raise HTTPException(400, "Only `state` field can be updated
through the REST API")
+ else:
+ update_mask = ["state"]
+
+ for attr_name in update_mask:
+ if attr_name == "state":
+ state = getattr(patch_body, attr_name)
+ if state == DAGRunPatchStates.SUCCESS:
+ set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id,
commit=True)
+ elif state == DAGRunPatchStates.QUEUED:
+ set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id,
commit=True)
+ else:
+ set_dag_run_state_to_failed(dag=dag, run_id=dag_run.run_id,
commit=True)
+
+ dag_run = session.get(DagRun, dag_run.id)
+
+ return DAGRunResponse.model_validate(dag_run, from_attributes=True)
diff --git a/airflow/api_fastapi/core_api/serializers/dag_run.py b/airflow/api_fastapi/core_api/serializers/dag_run.py
index 4622fac645..1557690561 100644
--- a/airflow/api_fastapi/core_api/serializers/dag_run.py
+++ b/airflow/api_fastapi/core_api/serializers/dag_run.py
@@ -18,6 +18,7 @@
from __future__ import annotations
from datetime import datetime
+from enum import Enum
from pydantic import BaseModel, Field
@@ -25,6 +26,20 @@ from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunTriggeredByType, DagRunType
+class DAGRunPatchStates(str, Enum):
+ """Enum for DAG Run states when updating a DAG Run."""
+
+ QUEUED = DagRunState.QUEUED
+ SUCCESS = DagRunState.SUCCESS
+ FAILED = DagRunState.FAILED
+
+
+class DAGRunPatchBody(BaseModel):
+ """DAG Run Serializer for PATCH requests."""
+
+ state: DAGRunPatchStates
+
+
class DAGRunResponse(BaseModel):
"""DAG Run serializer for responses."""
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 959b476718..b8c83b1525 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -441,6 +441,9 @@ export type DagServicePatchDagMutationResult = Awaited<
export type VariableServicePatchVariableMutationResult = Awaited<
ReturnType<typeof VariableService.patchVariable>
>;
+export type DagRunServicePatchDagRunStateMutationResult = Awaited<
+ ReturnType<typeof DagRunService.patchDagRunState>
+>;
export type PoolServicePatchPoolMutationResult = Awaited<
ReturnType<typeof PoolService.patchPool>
>;
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index a3aed2e793..a0d6a65853 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -23,6 +23,7 @@ import {
} from "../requests/services.gen";
import {
DAGPatchBody,
+ DAGRunPatchBody,
DagRunState,
PoolPatchBody,
PoolPostBody,
@@ -948,6 +949,57 @@ export const useVariableServicePatchVariable = <
}) as unknown as Promise<TData>,
...options,
});
+/**
+ * Patch Dag Run State
+ * Modify a DAG Run.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+export const useDagRunServicePatchDagRunState = <
+ TData = Common.DagRunServicePatchDagRunStateMutationResult,
+ TError = unknown,
+ TContext = unknown,
+>(
+ options?: Omit<
+ UseMutationOptions<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ requestBody: DAGRunPatchBody;
+ updateMask?: string[];
+ },
+ TContext
+ >,
+ "mutationFn"
+ >,
+) =>
+ useMutation<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ requestBody: DAGRunPatchBody;
+ updateMask?: string[];
+ },
+ TContext
+ >({
+ mutationFn: ({ dagId, dagRunId, requestBody, updateMask }) =>
+ DagRunService.patchDagRunState({
+ dagId,
+ dagRunId,
+ requestBody,
+ updateMask,
+ }) as unknown as Promise<TData>,
+ ...options,
+ });
/**
* Patch Pool
* Update a Pool.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index b1a5b267e1..3982407c5f 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -887,6 +887,25 @@ export const $DAGResponse = {
description: "DAG serializer for responses.",
} as const;
+export const $DAGRunPatchBody = {
+ properties: {
+ state: {
+ $ref: "#/components/schemas/DAGRunPatchStates",
+ },
+ },
+ type: "object",
+ required: ["state"],
+ title: "DAGRunPatchBody",
+ description: "DAG Run Serializer for PATCH requests.",
+} as const;
+
+export const $DAGRunPatchStates = {
+ type: "string",
+ enum: ["queued", "success", "failed"],
+ title: "DAGRunPatchStates",
+ description: "Enum for DAG Run states when updating a DAG Run.",
+} as const;
+
export const $DAGRunResponse = {
properties: {
run_id: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index 4db1e052a2..56207631d1 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -43,6 +43,8 @@ import type {
GetDagRunResponse,
DeleteDagRunData,
DeleteDagRunResponse,
+ PatchDagRunStateData,
+ PatchDagRunStateResponse,
GetHealthResponse,
DeletePoolData,
DeletePoolResponse,
@@ -672,6 +674,42 @@ export class DagRunService {
},
});
}
+
+ /**
+ * Patch Dag Run State
+ * Modify a DAG Run.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns DAGRunResponse Successful Response
+ * @throws ApiError
+ */
+ public static patchDagRunState(
+ data: PatchDagRunStateData,
+ ): CancelablePromise<PatchDagRunStateResponse> {
+ return __request(OpenAPI, {
+ method: "PATCH",
+ url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}",
+ path: {
+ dag_id: data.dagId,
+ dag_run_id: data.dagRunId,
+ },
+ query: {
+ update_mask: data.updateMask,
+ },
+ body: data.requestBody,
+ mediaType: "application/json",
+ errors: {
+ 400: "Bad Request",
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 422: "Validation Error",
+ },
+ });
+ }
}
export class MonitorService {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index 288916b392..a3b1d8e6be 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -154,6 +154,18 @@ export type DAGResponse = {
readonly file_token: string;
};
+/**
+ * DAG Run Serializer for PATCH requests.
+ */
+export type DAGRunPatchBody = {
+ state: DAGRunPatchStates;
+};
+
+/**
+ * Enum for DAG Run states when updating a DAG Run.
+ */
+export type DAGRunPatchStates = "queued" | "success" | "failed";
+
/**
* DAG Run serializer for responses.
*/
@@ -680,6 +692,15 @@ export type DeleteDagRunData = {
export type DeleteDagRunResponse = void;
+export type PatchDagRunStateData = {
+ dagId: string;
+ dagRunId: string;
+ requestBody: DAGRunPatchBody;
+ updateMask?: Array<string> | null;
+};
+
+export type PatchDagRunStateResponse = DAGRunResponse;
+
export type GetHealthResponse = HealthInfoSchema;
export type DeletePoolData = {
@@ -1236,6 +1257,35 @@ export type $OpenApiTs = {
422: HTTPValidationError;
};
};
+ patch: {
+ req: PatchDagRunStateData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: DAGRunResponse;
+ /**
+ * Bad Request
+ */
+ 400: HTTPExceptionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
};
"/public/monitor/health": {
get: {
diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
index 554bc73eba..dfd48af2fa 100644
--- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py
+++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py
@@ -138,6 +138,56 @@ class TestGetDagRun:
assert body["detail"] == "The DagRun with dag_id: `test_dag1` and
run_id: `invalid` was not found"
+class TestPatchDagRun:
+ @pytest.mark.parametrize(
+ "dag_id, run_id, state, response_state",
+ [
+ (DAG1_ID, DAG1_RUN1_ID, DagRunState.FAILED, DagRunState.FAILED),
+ (DAG1_ID, DAG1_RUN2_ID, DagRunState.SUCCESS, DagRunState.SUCCESS),
+ (DAG2_ID, DAG2_RUN1_ID, DagRunState.QUEUED, DagRunState.QUEUED),
+ ],
+ )
+ def test_patch_dag_run(self, test_client, dag_id, run_id, state,
response_state):
+ response =
test_client.patch(f"/public/dags/{dag_id}/dagRuns/{run_id}", json={"state":
state})
+ assert response.status_code == 200
+ body = response.json()
+ assert body["dag_id"] == dag_id
+ assert body["run_id"] == run_id
+ assert body["state"] == response_state
+
+ @pytest.mark.parametrize(
+ "query_params,patch_body, expected_status_code",
+ [
+ ({"update_mask": ["state"]}, {"state": DagRunState.SUCCESS}, 200),
+ ({}, {"state": DagRunState.SUCCESS}, 200),
+ ({"update_mask": ["random"]}, {"state": DagRunState.SUCCESS}, 400),
+ ],
+ )
+ def test_patch_dag_run_with_update_mask(
+ self, test_client, query_params, patch_body, expected_status_code
+ ):
+ response = test_client.patch(
+ f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}",
params=query_params, json=patch_body
+ )
+ assert response.status_code == expected_status_code
+
+ def test_patch_dag_run_not_found(self, test_client):
+ response = test_client.patch(
+ f"/public/dags/{DAG1_ID}/dagRuns/invalid", json={"state":
DagRunState.SUCCESS}
+ )
+ assert response.status_code == 404
+ body = response.json()
+ assert body["detail"] == "The DagRun with dag_id: `test_dag1` and
run_id: `invalid` was not found"
+
+ def test_patch_dag_run_bad_request(self, test_client):
+ response = test_client.patch(
+ f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}", json={"state":
"running"}
+ )
+ assert response.status_code == 422
+ body = response.json()
+ assert body["detail"][0]["msg"] == "Input should be 'queued',
'success' or 'failed'"
+
+
class TestDeleteDagRun:
def test_delete_dag_run(self, test_client):
response =
test_client.delete(f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}")