This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 825df6db600 AIP-84 Add patch task_instance dry_run endpoint (#46018)
825df6db600 is described below
commit 825df6db600fd3ba81f708e60c12be8e2b515a2f
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Tue Jan 28 11:22:37 2025 +0100
AIP-84 Add patch task_instance dry_run endpoint (#46018)
---
.../core_api/datamodels/task_instances.py | 3 +-
.../api_fastapi/core_api/openapi/v1-generated.yaml | 194 ++++++++-
.../core_api/routes/public/task_instances.py | 145 +++++--
airflow/ui/openapi-gen/queries/common.ts | 6 +
airflow/ui/openapi-gen/queries/queries.ts | 122 +++++-
airflow/ui/openapi-gen/requests/schemas.gen.ts | 8 +-
airflow/ui/openapi-gen/requests/services.gen.ts | 90 ++++-
airflow/ui/openapi-gen/requests/types.gen.ts | 95 ++++-
.../core_api/routes/public/test_task_instances.py | 443 +++++++++++++++++----
9 files changed, 967 insertions(+), 139 deletions(-)
diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py
b/airflow/api_fastapi/core_api/datamodels/task_instances.py
index 4754e67f2d3..7cecb96ca42 100644
--- a/airflow/api_fastapi/core_api/datamodels/task_instances.py
+++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py
@@ -198,8 +198,7 @@ class ClearTaskInstancesBody(BaseModel):
class PatchTaskInstanceBody(BaseModel):
"""Request body for Clear Task Instances endpoint."""
- dry_run: bool = True
- new_state: str | None = None
+ new_state: TaskInstanceState | None = None
note: Annotated[str, StringConstraints(max_length=1000)] | None = None
include_upstream: bool = False
include_downstream: bool = False
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index b40538e0244..62df27b568e 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -4518,7 +4518,7 @@ paths:
tags:
- Task Instance
summary: Patch Task Instance
- description: Update the state of a task instance.
+ description: Update a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
@@ -5125,7 +5125,7 @@ paths:
tags:
- Task Instance
summary: Patch Task Instance
- description: Update the state of a task instance.
+ description: Update a task instance.
operationId: patch_task_instance
parameters:
- name: dag_id
@@ -5675,6 +5675,189 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run:
+ patch:
+ tags:
+ - Task Instance
+ summary: Patch Task Instance Dry Run
+ description: Update a task instance dry_run mode.
+ operationId: patch_task_instance_dry_run
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Id
+ - name: dag_run_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Run Id
+ - name: task_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Task Id
+ - name: map_index
+ in: path
+ required: true
+ schema:
+ type: integer
+ title: Map Index
+ - name: update_mask
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: array
+ items:
+ type: string
+ - type: 'null'
+ title: Update Mask
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/PatchTaskInstanceBody'
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceCollectionResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '409':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Conflict
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
+ /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run:
+ patch:
+ tags:
+ - Task Instance
+ summary: Patch Task Instance Dry Run
+ description: Update a task instance dry_run mode.
+ operationId: patch_task_instance_dry_run
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Id
+ - name: dag_run_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Run Id
+ - name: task_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Task Id
+ - name: map_index
+ in: query
+ required: false
+ schema:
+ type: integer
+ default: -1
+ title: Map Index
+ - name: update_mask
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: array
+ items:
+ type: string
+ - type: 'null'
+ title: Update Mask
+ requestBody:
+ required: true
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/PatchTaskInstanceBody'
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceCollectionResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '400':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Bad Request
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '409':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Conflict
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/tasks:
get:
tags:
@@ -8931,15 +9114,10 @@ components:
description: Node serializer for responses.
PatchTaskInstanceBody:
properties:
- dry_run:
- type: boolean
- title: Dry Run
- default: true
new_state:
anyOf:
- - type: string
+ - $ref: '#/components/schemas/TaskInstanceState'
- type: 'null'
- title: New State
note:
anyOf:
- type: string
diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 1ecc18a6513..c97e190fb2c 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -62,6 +62,7 @@ from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_
from airflow.exceptions import TaskNotFound
from airflow.jobs.scheduler_job_runner import DR
from airflow.models import Base, DagRun
+from airflow.models.dag import DAG
from airflow.models.taskinstance import TaskInstance as TI,
clear_task_instances
from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
from airflow.ti_deps.dep_context import DepContext
@@ -661,19 +662,7 @@ def post_clear_task_instances(
)
-@task_instances_router.patch(
- task_instances_prefix + "/{task_id}",
- responses=create_openapi_http_exception_doc(
- [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
- ),
-)
-@task_instances_router.patch(
- task_instances_prefix + "/{task_id}/{map_index}",
- responses=create_openapi_http_exception_doc(
- [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
- ),
-)
-def patch_task_instance(
+def _patch_ti_validate_request(
dag_id: str,
dag_run_id: str,
task_id: str,
@@ -682,8 +671,7 @@ def patch_task_instance(
session: SessionDep,
map_index: int = -1,
update_mask: list[str] | None = Query(None),
-) -> TaskInstanceResponse:
- """Update the state of a task instance."""
+) -> tuple[DAG, TI, dict]:
dag = request.app.state.dag_bag.get_dag(dag_id)
if not dag:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not
found")
@@ -717,34 +705,123 @@ def patch_task_instance(
fields_to_update = body.model_fields_set
if update_mask:
fields_to_update = fields_to_update.intersection(update_mask)
- data = body.model_dump(include=fields_to_update, by_alias=True)
else:
try:
PatchTaskInstanceBody.model_validate(body)
except ValidationError as e:
raise RequestValidationError(errors=e.errors())
- data = body.model_dump(by_alias=True)
+
+ return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
+
+
+@task_instances_router.patch(
+ task_instances_prefix + "/{task_id}/dry_run",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+)
+@task_instances_router.patch(
+ task_instances_prefix + "/{task_id}/{map_index}/dry_run",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+)
+def patch_task_instance_dry_run(
+ dag_id: str,
+ dag_run_id: str,
+ task_id: str,
+ request: Request,
+ body: PatchTaskInstanceBody,
+ session: SessionDep,
+ map_index: int = -1,
+ update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+ """Update a task instance dry_run mode."""
+ dag, ti, data = _patch_ti_validate_request(
+ dag_id, dag_run_id, task_id, request, body, session, map_index,
update_mask
+ )
+
+ tis: list[TI] = []
+
+ if data.get("new_state"):
+ tis = dag.set_task_instance_state(
+ task_id=task_id,
+ run_id=dag_run_id,
+ map_indexes=[map_index],
+ state=data["new_state"],
+ upstream=body.include_upstream,
+ downstream=body.include_downstream,
+ future=body.include_future,
+ past=body.include_past,
+ commit=False,
+ session=session,
+ )
+
+ if not tis:
+ raise HTTPException(
+ status.HTTP_409_CONFLICT, f"Task id {task_id} is already in
{data['new_state']} state"
+ )
+ elif "note" in data:
+ tis = [ti]
+
+ return TaskInstanceCollectionResponse(
+ task_instances=[
+ TaskInstanceResponse.model_validate(
+ ti,
+ from_attributes=True,
+ )
+ for ti in tis
+ ],
+ total_entries=len(tis),
+ )
+
+
+@task_instances_router.patch(
+ task_instances_prefix + "/{task_id}",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+)
+@task_instances_router.patch(
+ task_instances_prefix + "/{task_id}/{map_index}",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST,
status.HTTP_409_CONFLICT],
+ ),
+)
+def patch_task_instance(
+ dag_id: str,
+ dag_run_id: str,
+ task_id: str,
+ request: Request,
+ body: PatchTaskInstanceBody,
+ session: SessionDep,
+ map_index: int = -1,
+ update_mask: list[str] | None = Query(None),
+) -> TaskInstanceResponse:
+ """Update a task instance."""
+ dag, ti, data = _patch_ti_validate_request(
+ dag_id, dag_run_id, task_id, request, body, session, map_index,
update_mask
+ )
for key, _ in data.items():
if key == "new_state":
- if not body.dry_run:
- tis: list[TI] = dag.set_task_instance_state(
- task_id=task_id,
- run_id=dag_run_id,
- map_indexes=[map_index],
- state=body.new_state,
- upstream=body.include_upstream,
- downstream=body.include_downstream,
- future=body.include_future,
- past=body.include_past,
- commit=True,
- session=session,
+ tis: list[TI] = dag.set_task_instance_state(
+ task_id=task_id,
+ run_id=dag_run_id,
+ map_indexes=[map_index],
+ state=data["new_state"],
+ upstream=body.include_upstream,
+ downstream=body.include_downstream,
+ future=body.include_future,
+ past=body.include_past,
+ commit=True,
+ session=session,
+ )
+ if not tis:
+ raise HTTPException(
+ status.HTTP_409_CONFLICT, f"Task id {task_id} is already
in {data['new_state']} state"
)
- if not tis:
- raise HTTPException(
- status.HTTP_409_CONFLICT, f"Task id {task_id} is
already in {data['new_state']} state"
- )
- ti = tis[0] if isinstance(tis, list) else tis
+ ti = tis[0] if isinstance(tis, list) else tis
elif key == "note":
if update_mask or body.note is not None:
# @TODO: replace None passed for user_id with actual user id
when
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index 73236c9b2d6..0e137620a4e 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1690,6 +1690,12 @@ export type
TaskInstanceServicePatchTaskInstanceMutationResult = Awaited<
export type TaskInstanceServicePatchTaskInstance1MutationResult = Awaited<
ReturnType<typeof TaskInstanceService.patchTaskInstance1>
>;
+export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<
+ ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun>
+>;
+export type TaskInstanceServicePatchTaskInstanceDryRun1MutationResult =
Awaited<
+ ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun1>
+>;
export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof
PoolService.patchPool>>;
export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof
PoolService.bulkPools>>;
export type VariableServicePatchVariableMutationResult = Awaited<
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index 0fe9e425b70..c4d201767e1 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -3596,7 +3596,7 @@ export const useDagServicePatchDag = <
});
/**
* Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
* @param data The data for the request.
* @param data.dagId
* @param data.dagRunId
@@ -3655,7 +3655,7 @@ export const useTaskInstanceServicePatchTaskInstance = <
});
/**
* Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
* @param data The data for the request.
* @param data.dagId
* @param data.dagRunId
@@ -3712,6 +3712,124 @@ export const useTaskInstanceServicePatchTaskInstance1 =
<
}) as unknown as Promise<TData>,
...options,
});
+/**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.mapIndex
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServicePatchTaskInstanceDryRun = <
+ TData = Common.TaskInstanceServicePatchTaskInstanceDryRunMutationResult,
+ TError = unknown,
+ TContext = unknown,
+>(
+ options?: Omit<
+ UseMutationOptions<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ mapIndex: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: string[];
+ },
+ TContext
+ >,
+ "mutationFn"
+ >,
+) =>
+ useMutation<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ mapIndex: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: string[];
+ },
+ TContext
+ >({
+ mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask
}) =>
+ TaskInstanceService.patchTaskInstanceDryRun({
+ dagId,
+ dagRunId,
+ mapIndex,
+ requestBody,
+ taskId,
+ updateMask,
+ }) as unknown as Promise<TData>,
+ ...options,
+ });
+/**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.requestBody
+ * @param data.mapIndex
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServicePatchTaskInstanceDryRun1 = <
+ TData = Common.TaskInstanceServicePatchTaskInstanceDryRun1MutationResult,
+ TError = unknown,
+ TContext = unknown,
+>(
+ options?: Omit<
+ UseMutationOptions<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ mapIndex?: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: string[];
+ },
+ TContext
+ >,
+ "mutationFn"
+ >,
+) =>
+ useMutation<
+ TData,
+ TError,
+ {
+ dagId: string;
+ dagRunId: string;
+ mapIndex?: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: string[];
+ },
+ TContext
+ >({
+ mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask
}) =>
+ TaskInstanceService.patchTaskInstanceDryRun1({
+ dagId,
+ dagRunId,
+ mapIndex,
+ requestBody,
+ taskId,
+ updateMask,
+ }) as unknown as Promise<TData>,
+ ...options,
+ });
/**
* Patch Pool
* Update a Pool.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 3def7c5825c..170b977beb0 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -3913,21 +3913,15 @@ export const $NodeResponse = {
export const $PatchTaskInstanceBody = {
properties: {
- dry_run: {
- type: "boolean",
- title: "Dry Run",
- default: true,
- },
new_state: {
anyOf: [
{
- type: "string",
+ $ref: "#/components/schemas/TaskInstanceState",
},
{
type: "null",
},
],
- title: "New State",
},
note: {
anyOf: [
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index 91edabfc3dc..bf6e528da11 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -145,6 +145,10 @@ import type {
GetMappedTaskInstanceTryDetailsResponse,
PostClearTaskInstancesData,
PostClearTaskInstancesResponse,
+ PatchTaskInstanceDryRunData,
+ PatchTaskInstanceDryRunResponse,
+ PatchTaskInstanceDryRun1Data,
+ PatchTaskInstanceDryRun1Response,
GetLogData,
GetLogResponse,
GetImportErrorData,
@@ -1981,7 +1985,7 @@ export class TaskInstanceService {
/**
* Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
* @param data The data for the request.
* @param data.dagId
* @param data.dagRunId
@@ -2249,7 +2253,7 @@ export class TaskInstanceService {
/**
* Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
* @param data The data for the request.
* @param data.dagId
* @param data.dagRunId
@@ -2486,6 +2490,88 @@ export class TaskInstanceService {
});
}
+ /**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.mapIndex
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+ public static patchTaskInstanceDryRun(
+ data: PatchTaskInstanceDryRunData,
+ ): CancelablePromise<PatchTaskInstanceDryRunResponse> {
+ return __request(OpenAPI, {
+ method: "PATCH",
+ url:
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run",
+ path: {
+ dag_id: data.dagId,
+ dag_run_id: data.dagRunId,
+ task_id: data.taskId,
+ map_index: data.mapIndex,
+ },
+ query: {
+ update_mask: data.updateMask,
+ },
+ body: data.requestBody,
+ mediaType: "application/json",
+ errors: {
+ 400: "Bad Request",
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 409: "Conflict",
+ 422: "Validation Error",
+ },
+ });
+ }
+
+ /**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.requestBody
+ * @param data.mapIndex
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+ public static patchTaskInstanceDryRun1(
+ data: PatchTaskInstanceDryRun1Data,
+ ): CancelablePromise<PatchTaskInstanceDryRun1Response> {
+ return __request(OpenAPI, {
+ method: "PATCH",
+ url:
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run",
+ path: {
+ dag_id: data.dagId,
+ dag_run_id: data.dagRunId,
+ task_id: data.taskId,
+ },
+ query: {
+ map_index: data.mapIndex,
+ update_mask: data.updateMask,
+ },
+ body: data.requestBody,
+ mediaType: "application/json",
+ errors: {
+ 400: "Bad Request",
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 409: "Conflict",
+ 422: "Validation Error",
+ },
+ });
+ }
+
/**
* Get Log
* Get logs for a specific task instance.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index cf1cacdf115..6ff1a083ec6 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1024,8 +1024,7 @@ export type type =
* Request body for Clear Task Instances endpoint.
*/
export type PatchTaskInstanceBody = {
- dry_run?: boolean;
- new_state?: string | null;
+ new_state?: TaskInstanceState | null;
note?: string | null;
include_upstream?: boolean;
include_downstream?: boolean;
@@ -2159,6 +2158,28 @@ export type PostClearTaskInstancesData = {
export type PostClearTaskInstancesResponse = TaskInstanceCollectionResponse;
+export type PatchTaskInstanceDryRunData = {
+ dagId: string;
+ dagRunId: string;
+ mapIndex: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: Array<string> | null;
+};
+
+export type PatchTaskInstanceDryRunResponse = TaskInstanceCollectionResponse;
+
+export type PatchTaskInstanceDryRun1Data = {
+ dagId: string;
+ dagRunId: string;
+ mapIndex?: number;
+ requestBody: PatchTaskInstanceBody;
+ taskId: string;
+ updateMask?: Array<string> | null;
+};
+
+export type PatchTaskInstanceDryRun1Response = TaskInstanceCollectionResponse;
+
export type GetLogData = {
accept?: "application/json" | "text/plain" | "*/*";
dagId: string;
@@ -4261,6 +4282,76 @@ export type $OpenApiTs = {
};
};
};
+
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run":
{
+ patch: {
+ req: PatchTaskInstanceDryRunData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: TaskInstanceCollectionResponse;
+ /**
+ * Bad Request
+ */
+ 400: HTTPExceptionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Conflict
+ */
+ 409: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
+
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run": {
+ patch: {
+ req: PatchTaskInstanceDryRun1Data;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: TaskInstanceCollectionResponse;
+ /**
+ * Bad Request
+ */
+ 400: HTTPExceptionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Conflict
+ */
+ 409: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{try_number}":
{
get: {
req: GetLogData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index bc2cef7c03e..ebf718340d3 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -2693,7 +2693,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
self.ENDPOINT_URL,
json={
- "dry_run": False,
"new_state": self.NEW_STATE,
},
)
@@ -2743,68 +2742,12 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
task_id=self.TASK_ID,
)
- @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
- def test_should_not_call_mocked_api_for_dry_run(self,
mock_set_task_instance_state, test_client, session):
- self.create_task_instances(session)
-
- mock_set_task_instance_state.return_value = session.scalars(
- select(TaskInstance).where(
- TaskInstance.dag_id == self.DAG_ID,
- TaskInstance.task_id == self.TASK_ID,
- TaskInstance.run_id == self.RUN_ID,
- TaskInstance.map_index == -1,
- )
- ).one_or_none()
-
- response = test_client.patch(
- self.ENDPOINT_URL,
- json={
- "dry_run": True,
- "new_state": self.NEW_STATE,
- },
- )
- assert response.status_code == 200
- assert response.json() == {
- "dag_id": self.DAG_ID,
- "dag_run_id": self.RUN_ID,
- "logical_date": "2020-01-01T00:00:00Z",
- "task_id": self.TASK_ID,
- "duration": 10000.0,
- "end_date": "2020-01-03T00:00:00Z",
- "executor": None,
- "executor_config": "{}",
- "hostname": "",
- "id": mock.ANY,
- "map_index": -1,
- "max_tries": 0,
- "note": "placeholder-note",
- "operator": "PythonOperator",
- "pid": 100,
- "pool": "default_pool",
- "pool_slots": 1,
- "priority_weight": 9,
- "queue": "default_queue",
- "queued_when": None,
- "start_date": "2020-01-02T00:00:00Z",
- "state": "running",
- "task_display_name": self.TASK_ID,
- "try_number": 0,
- "unixname": getuser(),
- "rendered_fields": {},
- "rendered_map_index": None,
- "trigger": None,
- "triggerer_job": None,
- }
-
- mock_set_task_instance_state.assert_not_called()
-
def test_should_update_task_instance_state(self, test_client, session):
self.create_task_instances(session)
test_client.patch(
self.ENDPOINT_URL,
json={
- "dry_run": False,
"new_state": self.NEW_STATE,
},
)
@@ -2813,20 +2756,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
assert response2.status_code == 200
assert response2.json()["state"] == self.NEW_STATE
- def test_should_update_task_instance_state_default_dry_run_to_true(self,
test_client, session):
- self.create_task_instances(session)
-
- test_client.patch(
- self.ENDPOINT_URL,
- json={
- "new_state": self.NEW_STATE,
- },
- )
-
- response2 = test_client.get(self.ENDPOINT_URL)
- assert response2.status_code == 200
- assert response2.json()["state"] == "running" # no change in state
-
def test_should_update_mapped_task_instance_state(self, test_client,
session):
map_index = 1
tis = self.create_task_instances(session)
@@ -2838,7 +2767,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
f"{self.ENDPOINT_URL}/{map_index}",
json={
- "dry_run": False,
"new_state": self.NEW_STATE,
},
)
@@ -2858,7 +2786,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
),
404,
{
- "dry_run": True,
"new_state": "failed",
},
]
@@ -2877,7 +2804,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
self.ENDPOINT_URL,
json={
- "dryrun": True,
"new_state": self.NEW_STATE,
},
)
@@ -2887,7 +2813,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
"/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
json={
- "dry_run": False,
"new_state": self.NEW_STATE,
},
)
@@ -2898,7 +2823,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task",
json={
- "dry_run": False,
"new_state": self.NEW_STATE,
},
)
@@ -2911,7 +2835,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
self.ENDPOINT_URL,
json={
- "dry_run": True,
"new_state": self.NEW_STATE,
},
)
@@ -2921,7 +2844,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
self.ENDPOINT_URL,
json={
- "dry_run": True,
"new_state": self.NEW_STATE,
},
)
@@ -2932,14 +2854,12 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
[
(
{
- "dry_run": True,
"new_state": "failede",
},
f"'failede' is not one of ['{State.SUCCESS}',
'{State.FAILED}', '{State.SKIPPED}']",
),
(
{
- "dry_run": True,
"new_state": "queued",
},
f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}',
'{State.SKIPPED}']",
@@ -3048,7 +2968,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
self.ENDPOINT_URL,
params={"update_mask": "new_state"},
json={
- "dry_run": False,
"new_state": new_state,
},
)
@@ -3222,7 +3141,367 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
response = test_client.patch(
self.ENDPOINT_URL,
json={
- "dry_run": False,
+ "new_state": "success",
+ },
+ )
+ assert response.status_code == 409
+ assert "Task id print_the_context is already in success state" in
response.text
+
+
+class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
+ ENDPOINT_URL = (
+
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+ )
+ NEW_STATE = "failed"
+ DAG_ID = "example_python_operator"
+ TASK_ID = "print_the_context"
+ RUN_ID = "TEST_DAG_RUN_ID"
+
+ @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+ def test_should_call_mocked_api(self, mock_set_ti_state, test_client,
session):
+ self.create_task_instances(session)
+
+ mock_set_ti_state.return_value = [
+ session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.task_id == self.TASK_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.map_index == -1,
+ )
+ ).one_or_none()
+ ]
+
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+ assert response.json() == {
+ "task_instances": [
+ {
+ "dag_id": self.DAG_ID,
+ "dag_run_id": self.RUN_ID,
+ "logical_date": "2020-01-01T00:00:00Z",
+ "task_id": self.TASK_ID,
+ "duration": 10000.0,
+ "end_date": "2020-01-03T00:00:00Z",
+ "executor": None,
+ "executor_config": "{}",
+ "hostname": "",
+ "id": mock.ANY,
+ "map_index": -1,
+ "max_tries": 0,
+ "note": "placeholder-note",
+ "operator": "PythonOperator",
+ "pid": 100,
+ "pool": "default_pool",
+ "pool_slots": 1,
+ "priority_weight": 9,
+ "queue": "default_queue",
+ "queued_when": None,
+ "start_date": "2020-01-02T00:00:00Z",
+ "state": "running",
+ "task_display_name": self.TASK_ID,
+ "try_number": 0,
+ "unixname": getuser(),
+ "rendered_fields": {},
+ "rendered_map_index": None,
+ "trigger": None,
+ "triggerer_job": None,
+ }
+ ],
+ "total_entries": 1,
+ }
+
+ mock_set_ti_state.assert_called_once_with(
+ commit=False,
+ downstream=False,
+ upstream=False,
+ future=False,
+ map_indexes=[-1],
+ past=False,
+ run_id=self.RUN_ID,
+ session=mock.ANY,
+ state=self.NEW_STATE,
+ task_id=self.TASK_ID,
+ )
+
+ @pytest.mark.parametrize(
+ "payload",
+ [
+ {
+ "new_state": "success",
+ },
+ {
+ "note": "something",
+ },
+ {
+ "new_state": "success",
+ "note": "something",
+ },
+ ],
+ )
+ def test_should_not_update(self, test_client, session, payload):
+ self.create_task_instances(session)
+
+ task_before = test_client.get(self.ENDPOINT_URL).json()
+
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json=payload,
+ )
+
+ assert response.status_code == 200
+ assert [ti["task_id"] for ti in response.json()["task_instances"]] ==
["print_the_context"]
+
+ task_after = test_client.get(self.ENDPOINT_URL).json()
+
+ assert task_before == task_after
+
+ def test_should_not_update_mapped_task_instance(self, test_client,
session):
+ map_index = 1
+ tis = self.create_task_instances(session)
+ ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id,
map_index=map_index)
+ ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+ session.add(ti)
+ session.commit()
+
+ task_before =
test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json()
+
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/{map_index}/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+
+ assert response.status_code == 200
+ assert [ti["task_id"] for ti in response.json()["task_instances"]] ==
["print_the_context"]
+
+ task_after = test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json()
+
+ assert task_before == task_after
+
+ @pytest.mark.parametrize(
+ "error, code, payload",
+ [
+ [
+ (
+ "Task Instance not found for
dag_id=example_python_operator"
+ ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context"
+ ),
+ 404,
+ {
+ "new_state": "failed",
+ },
+ ]
+ ],
+ )
+ def test_should_handle_errors(self, error, code, payload, test_client,
session):
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json=payload,
+ )
+ assert response.status_code == code
+ assert response.json()["detail"] == error
+
+ def test_should_200_for_unknown_fields(self, test_client, session):
+ self.create_task_instances(session)
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 200
+
+ def test_should_raise_404_for_non_existent_dag(self, test_client):
+ response = test_client.patch(
+
"/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 404
+ assert response.json() == {"detail": "DAG non-existent-dag not found"}
+
+ def test_should_raise_404_for_non_existent_task_in_dag(self, test_client):
+ response = test_client.patch(
+
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 404
+ assert response.json() == {
+ "detail": "Task 'non_existent_task' not found in DAG
'example_python_operator'"
+ }
+
+ def test_should_raise_404_not_found_dag(self, test_client):
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 404
+
+ def test_should_raise_404_not_found_task(self, test_client):
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json={
+ "new_state": self.NEW_STATE,
+ },
+ )
+ assert response.status_code == 404
+
+ @pytest.mark.parametrize(
+ "payload, expected",
+ [
+ (
+ {
+ "new_state": "failede",
+ },
+ f"'failede' is not one of ['{State.SUCCESS}',
'{State.FAILED}', '{State.SKIPPED}']",
+ ),
+ (
+ {
+ "new_state": "queued",
+ },
+ f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}',
'{State.SKIPPED}']",
+ ),
+ ],
+ )
+ def test_should_raise_422_for_invalid_task_instance_state(self, payload,
expected, test_client, session):
+ self.create_task_instances(session)
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json=payload,
+ )
+ assert response.status_code == 422
+ assert response.json() == {
+ "detail": [
+ {
+ "type": "value_error",
+ "loc": ["body", "new_state"],
+ "msg": f"Value error, {expected}",
+ "input": payload["new_state"],
+ "ctx": {"error": {}},
+ }
+ ]
+ }
+
+ @pytest.mark.parametrize(
+ "new_state,expected_status_code,expected_json,set_ti_state_call_count",
+ [
+ (
+ "failed",
+ 200,
+ {
+ "task_instances": [
+ {
+ "dag_id": "example_python_operator",
+ "dag_run_id": "TEST_DAG_RUN_ID",
+ "logical_date": "2020-01-01T00:00:00Z",
+ "task_id": "print_the_context",
+ "duration": 10000.0,
+ "end_date": "2020-01-03T00:00:00Z",
+ "executor": None,
+ "executor_config": "{}",
+ "hostname": "",
+ "id": mock.ANY,
+ "map_index": -1,
+ "max_tries": 0,
+ "note": "placeholder-note",
+ "operator": "PythonOperator",
+ "pid": 100,
+ "pool": "default_pool",
+ "pool_slots": 1,
+ "priority_weight": 9,
+ "queue": "default_queue",
+ "queued_when": None,
+ "start_date": "2020-01-02T00:00:00Z",
+ "state": "running",
+ "task_display_name": "print_the_context",
+ "try_number": 0,
+ "unixname": getuser(),
+ "rendered_fields": {},
+ "rendered_map_index": None,
+ "trigger": None,
+ "triggerer_job": None,
+ }
+ ],
+ "total_entries": 1,
+ },
+ 1,
+ ),
+ (
+ None,
+ 422,
+ {
+ "detail": [
+ {
+ "type": "value_error",
+ "loc": ["body", "new_state"],
+ "msg": "Value error, 'new_state' should not be
empty",
+ "input": None,
+ "ctx": {"error": {}},
+ }
+ ]
+ },
+ 0,
+ ),
+ ],
+ )
+ @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+ def test_update_mask_should_call_mocked_api(
+ self,
+ mock_set_ti_state,
+ test_client,
+ session,
+ new_state,
+ expected_status_code,
+ expected_json,
+ set_ti_state_call_count,
+ ):
+ self.create_task_instances(session)
+
+ mock_set_ti_state.return_value = [
+ session.scalars(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.DAG_ID,
+ TaskInstance.task_id == self.TASK_ID,
+ TaskInstance.run_id == self.RUN_ID,
+ TaskInstance.map_index == -1,
+ )
+ ).one_or_none()
+ ]
+
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ params={"update_mask": "new_state"},
+ json={
+ "new_state": new_state,
+ },
+ )
+ assert response.status_code == expected_status_code
+ assert response.json() == expected_json
+ assert mock_set_ti_state.call_count == set_ti_state_call_count
+
+ @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+ def test_should_raise_409_for_updating_same_task_instance_state(
+ self, mock_set_ti_state, test_client, session
+ ):
+ self.create_task_instances(session)
+
+ mock_set_ti_state.return_value = None
+
+ response = test_client.patch(
+ f"{self.ENDPOINT_URL}/dry_run",
+ json={
"new_state": "success",
},
)