This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 825df6db600 AIP-84 Add patch task_instance dry_run endpoint (#46018)
825df6db600 is described below

commit 825df6db600fd3ba81f708e60c12be8e2b515a2f
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Tue Jan 28 11:22:37 2025 +0100

    AIP-84 Add patch task_instance dry_run endpoint (#46018)
---
 .../core_api/datamodels/task_instances.py          |   3 +-
 .../api_fastapi/core_api/openapi/v1-generated.yaml | 194 ++++++++-
 .../core_api/routes/public/task_instances.py       | 145 +++++--
 airflow/ui/openapi-gen/queries/common.ts           |   6 +
 airflow/ui/openapi-gen/queries/queries.ts          | 122 +++++-
 airflow/ui/openapi-gen/requests/schemas.gen.ts     |   8 +-
 airflow/ui/openapi-gen/requests/services.gen.ts    |  90 ++++-
 airflow/ui/openapi-gen/requests/types.gen.ts       |  95 ++++-
 .../core_api/routes/public/test_task_instances.py  | 443 +++++++++++++++++----
 9 files changed, 967 insertions(+), 139 deletions(-)

diff --git a/airflow/api_fastapi/core_api/datamodels/task_instances.py b/airflow/api_fastapi/core_api/datamodels/task_instances.py
index 4754e67f2d3..7cecb96ca42 100644
--- a/airflow/api_fastapi/core_api/datamodels/task_instances.py
+++ b/airflow/api_fastapi/core_api/datamodels/task_instances.py
@@ -198,8 +198,7 @@ class ClearTaskInstancesBody(BaseModel):
 class PatchTaskInstanceBody(BaseModel):
     """Request body for Clear Task Instances endpoint."""
 
-    dry_run: bool = True
-    new_state: str | None = None
+    new_state: TaskInstanceState | None = None
     note: Annotated[str, StringConstraints(max_length=1000)] | None = None
     include_upstream: bool = False
     include_downstream: bool = False
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index b40538e0244..62df27b568e 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -4518,7 +4518,7 @@ paths:
       tags:
       - Task Instance
       summary: Patch Task Instance
-      description: Update the state of a task instance.
+      description: Update a task instance.
       operationId: patch_task_instance
       parameters:
       - name: dag_id
@@ -5125,7 +5125,7 @@ paths:
       tags:
       - Task Instance
       summary: Patch Task Instance
-      description: Update the state of a task instance.
+      description: Update a task instance.
       operationId: patch_task_instance
       parameters:
       - name: dag_id
@@ -5675,6 +5675,189 @@ paths:
             application/json:
               schema:
                 $ref: '#/components/schemas/HTTPValidationError'
+  /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run:
+    patch:
+      tags:
+      - Task Instance
+      summary: Patch Task Instance Dry Run
+      description: Update a task instance dry_run mode.
+      operationId: patch_task_instance_dry_run
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: dag_run_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Run Id
+      - name: task_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Task Id
+      - name: map_index
+        in: path
+        required: true
+        schema:
+          type: integer
+          title: Map Index
+      - name: update_mask
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: array
+            items:
+              type: string
+          - type: 'null'
+          title: Update Mask
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/PatchTaskInstanceBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/TaskInstanceCollectionResponse'
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '409':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Conflict
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
+  /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run:
+    patch:
+      tags:
+      - Task Instance
+      summary: Patch Task Instance Dry Run
+      description: Update a task instance dry_run mode.
+      operationId: patch_task_instance_dry_run
+      parameters:
+      - name: dag_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Id
+      - name: dag_run_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Dag Run Id
+      - name: task_id
+        in: path
+        required: true
+        schema:
+          type: string
+          title: Task Id
+      - name: map_index
+        in: query
+        required: false
+        schema:
+          type: integer
+          default: -1
+          title: Map Index
+      - name: update_mask
+        in: query
+        required: false
+        schema:
+          anyOf:
+          - type: array
+            items:
+              type: string
+          - type: 'null'
+          title: Update Mask
+      requestBody:
+        required: true
+        content:
+          application/json:
+            schema:
+              $ref: '#/components/schemas/PatchTaskInstanceBody'
+      responses:
+        '200':
+          description: Successful Response
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/TaskInstanceCollectionResponse'
+        '401':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Unauthorized
+        '403':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Forbidden
+        '400':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Bad Request
+        '404':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Not Found
+        '409':
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPExceptionResponse'
+          description: Conflict
+        '422':
+          description: Validation Error
+          content:
+            application/json:
+              schema:
+                $ref: '#/components/schemas/HTTPValidationError'
   /public/dags/{dag_id}/tasks:
     get:
       tags:
@@ -8931,15 +9114,10 @@ components:
       description: Node serializer for responses.
     PatchTaskInstanceBody:
       properties:
-        dry_run:
-          type: boolean
-          title: Dry Run
-          default: true
         new_state:
           anyOf:
-          - type: string
+          - $ref: '#/components/schemas/TaskInstanceState'
           - type: 'null'
-          title: New State
         note:
           anyOf:
           - type: string
diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 1ecc18a6513..c97e190fb2c 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -62,6 +62,7 @@ from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_
 from airflow.exceptions import TaskNotFound
 from airflow.jobs.scheduler_job_runner import DR
 from airflow.models import Base, DagRun
+from airflow.models.dag import DAG
 from airflow.models.taskinstance import TaskInstance as TI, clear_task_instances
 from airflow.models.taskinstancehistory import TaskInstanceHistory as TIH
 from airflow.ti_deps.dep_context import DepContext
@@ -661,19 +662,7 @@ def post_clear_task_instances(
     )
 
 
-@task_instances_router.patch(
-    task_instances_prefix + "/{task_id}",
-    responses=create_openapi_http_exception_doc(
-        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
-    ),
-)
-@task_instances_router.patch(
-    task_instances_prefix + "/{task_id}/{map_index}",
-    responses=create_openapi_http_exception_doc(
-        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
-    ),
-)
-def patch_task_instance(
+def _patch_ti_validate_request(
     dag_id: str,
     dag_run_id: str,
     task_id: str,
@@ -682,8 +671,7 @@ def patch_task_instance(
     session: SessionDep,
     map_index: int = -1,
     update_mask: list[str] | None = Query(None),
-) -> TaskInstanceResponse:
-    """Update the state of a task instance."""
+) -> tuple[DAG, TI, dict]:
     dag = request.app.state.dag_bag.get_dag(dag_id)
     if not dag:
         raise HTTPException(status.HTTP_404_NOT_FOUND, f"DAG {dag_id} not found")
@@ -717,34 +705,123 @@ def patch_task_instance(
     fields_to_update = body.model_fields_set
     if update_mask:
         fields_to_update = fields_to_update.intersection(update_mask)
-        data = body.model_dump(include=fields_to_update, by_alias=True)
     else:
         try:
             PatchTaskInstanceBody.model_validate(body)
         except ValidationError as e:
             raise RequestValidationError(errors=e.errors())
-        data = body.model_dump(by_alias=True)
+
+    return dag, ti, body.model_dump(include=fields_to_update, by_alias=True)
+
+
+@task_instances_router.patch(
+    task_instances_prefix + "/{task_id}/dry_run",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+)
+@task_instances_router.patch(
+    task_instances_prefix + "/{task_id}/{map_index}/dry_run",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+)
+def patch_task_instance_dry_run(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    request: Request,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    map_index: int = -1,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceCollectionResponse:
+    """Update a task instance dry_run mode."""
+    dag, ti, data = _patch_ti_validate_request(
+        dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask
+    )
+
+    tis: list[TI] = []
+
+    if data.get("new_state"):
+        tis = dag.set_task_instance_state(
+            task_id=task_id,
+            run_id=dag_run_id,
+            map_indexes=[map_index],
+            state=data["new_state"],
+            upstream=body.include_upstream,
+            downstream=body.include_downstream,
+            future=body.include_future,
+            past=body.include_past,
+            commit=False,
+            session=session,
+        )
+
+        if not tis:
+            raise HTTPException(
+                status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
+            )
+    elif "note" in data:
+        tis = [ti]
+
+    return TaskInstanceCollectionResponse(
+        task_instances=[
+            TaskInstanceResponse.model_validate(
+                ti,
+                from_attributes=True,
+            )
+            for ti in tis
+        ],
+        total_entries=len(tis),
+    )
+
+
+@task_instances_router.patch(
+    task_instances_prefix + "/{task_id}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+)
+@task_instances_router.patch(
+    task_instances_prefix + "/{task_id}/{map_index}",
+    responses=create_openapi_http_exception_doc(
+        [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT],
+    ),
+)
+def patch_task_instance(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    request: Request,
+    body: PatchTaskInstanceBody,
+    session: SessionDep,
+    map_index: int = -1,
+    update_mask: list[str] | None = Query(None),
+) -> TaskInstanceResponse:
+    """Update a task instance."""
+    dag, ti, data = _patch_ti_validate_request(
+        dag_id, dag_run_id, task_id, request, body, session, map_index, update_mask
+    )
 
     for key, _ in data.items():
         if key == "new_state":
-            if not body.dry_run:
-                tis: list[TI] = dag.set_task_instance_state(
-                    task_id=task_id,
-                    run_id=dag_run_id,
-                    map_indexes=[map_index],
-                    state=body.new_state,
-                    upstream=body.include_upstream,
-                    downstream=body.include_downstream,
-                    future=body.include_future,
-                    past=body.include_past,
-                    commit=True,
-                    session=session,
+            tis: list[TI] = dag.set_task_instance_state(
+                task_id=task_id,
+                run_id=dag_run_id,
+                map_indexes=[map_index],
+                state=data["new_state"],
+                upstream=body.include_upstream,
+                downstream=body.include_downstream,
+                future=body.include_future,
+                past=body.include_past,
+                commit=True,
+                session=session,
+            )
+            if not tis:
+                raise HTTPException(
+                    status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
                 )
-                if not tis:
-                    raise HTTPException(
-                        status.HTTP_409_CONFLICT, f"Task id {task_id} is already in {data['new_state']} state"
-                    )
-                ti = tis[0] if isinstance(tis, list) else tis
+            ti = tis[0] if isinstance(tis, list) else tis
         elif key == "note":
             if update_mask or body.note is not None:
                 # @TODO: replace None passed for user_id with actual user id when
diff --git a/airflow/ui/openapi-gen/queries/common.ts b/airflow/ui/openapi-gen/queries/common.ts
index 73236c9b2d6..0e137620a4e 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -1690,6 +1690,12 @@ export type TaskInstanceServicePatchTaskInstanceMutationResult = Awaited<
 export type TaskInstanceServicePatchTaskInstance1MutationResult = Awaited<
   ReturnType<typeof TaskInstanceService.patchTaskInstance1>
 >;
+export type TaskInstanceServicePatchTaskInstanceDryRunMutationResult = Awaited<
+  ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun>
+>;
+export type TaskInstanceServicePatchTaskInstanceDryRun1MutationResult = Awaited<
+  ReturnType<typeof TaskInstanceService.patchTaskInstanceDryRun1>
+>;
 export type PoolServicePatchPoolMutationResult = Awaited<ReturnType<typeof PoolService.patchPool>>;
 export type PoolServiceBulkPoolsMutationResult = Awaited<ReturnType<typeof PoolService.bulkPools>>;
 export type VariableServicePatchVariableMutationResult = Awaited<
diff --git a/airflow/ui/openapi-gen/queries/queries.ts b/airflow/ui/openapi-gen/queries/queries.ts
index 0fe9e425b70..c4d201767e1 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -3596,7 +3596,7 @@ export const useDagServicePatchDag = <
   });
 /**
  * Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
  * @param data The data for the request.
  * @param data.dagId
  * @param data.dagRunId
@@ -3655,7 +3655,7 @@ export const useTaskInstanceServicePatchTaskInstance = <
   });
 /**
  * Patch Task Instance
- * Update the state of a task instance.
+ * Update a task instance.
  * @param data The data for the request.
  * @param data.dagId
  * @param data.dagRunId
@@ -3712,6 +3712,124 @@ export const useTaskInstanceServicePatchTaskInstance1 = <
       }) as unknown as Promise<TData>,
     ...options,
   });
+/**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.mapIndex
+ * @param data.requestBody
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServicePatchTaskInstanceDryRun = <
+  TData = Common.TaskInstanceServicePatchTaskInstanceDryRunMutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: string;
+        dagRunId: string;
+        mapIndex: number;
+        requestBody: PatchTaskInstanceBody;
+        taskId: string;
+        updateMask?: string[];
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: string;
+      dagRunId: string;
+      mapIndex: number;
+      requestBody: PatchTaskInstanceBody;
+      taskId: string;
+      updateMask?: string[];
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask }) =>
+      TaskInstanceService.patchTaskInstanceDryRun({
+        dagId,
+        dagRunId,
+        mapIndex,
+        requestBody,
+        taskId,
+        updateMask,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
+/**
+ * Patch Task Instance Dry Run
+ * Update a task instance dry_run mode.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.requestBody
+ * @param data.mapIndex
+ * @param data.updateMask
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServicePatchTaskInstanceDryRun1 = <
+  TData = Common.TaskInstanceServicePatchTaskInstanceDryRun1MutationResult,
+  TError = unknown,
+  TContext = unknown,
+>(
+  options?: Omit<
+    UseMutationOptions<
+      TData,
+      TError,
+      {
+        dagId: string;
+        dagRunId: string;
+        mapIndex?: number;
+        requestBody: PatchTaskInstanceBody;
+        taskId: string;
+        updateMask?: string[];
+      },
+      TContext
+    >,
+    "mutationFn"
+  >,
+) =>
+  useMutation<
+    TData,
+    TError,
+    {
+      dagId: string;
+      dagRunId: string;
+      mapIndex?: number;
+      requestBody: PatchTaskInstanceBody;
+      taskId: string;
+      updateMask?: string[];
+    },
+    TContext
+  >({
+    mutationFn: ({ dagId, dagRunId, mapIndex, requestBody, taskId, updateMask }) =>
+      TaskInstanceService.patchTaskInstanceDryRun1({
+        dagId,
+        dagRunId,
+        mapIndex,
+        requestBody,
+        taskId,
+        updateMask,
+      }) as unknown as Promise<TData>,
+    ...options,
+  });
 /**
  * Patch Pool
  * Update a Pool.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 3def7c5825c..170b977beb0 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -3913,21 +3913,15 @@ export const $NodeResponse = {
 
 export const $PatchTaskInstanceBody = {
   properties: {
-    dry_run: {
-      type: "boolean",
-      title: "Dry Run",
-      default: true,
-    },
     new_state: {
       anyOf: [
         {
-          type: "string",
+          $ref: "#/components/schemas/TaskInstanceState",
         },
         {
           type: "null",
         },
       ],
-      title: "New State",
     },
     note: {
       anyOf: [
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts b/airflow/ui/openapi-gen/requests/services.gen.ts
index 91edabfc3dc..bf6e528da11 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -145,6 +145,10 @@ import type {
   GetMappedTaskInstanceTryDetailsResponse,
   PostClearTaskInstancesData,
   PostClearTaskInstancesResponse,
+  PatchTaskInstanceDryRunData,
+  PatchTaskInstanceDryRunResponse,
+  PatchTaskInstanceDryRun1Data,
+  PatchTaskInstanceDryRun1Response,
   GetLogData,
   GetLogResponse,
   GetImportErrorData,
@@ -1981,7 +1985,7 @@ export class TaskInstanceService {
 
   /**
    * Patch Task Instance
-   * Update the state of a task instance.
+   * Update a task instance.
    * @param data The data for the request.
    * @param data.dagId
    * @param data.dagRunId
@@ -2249,7 +2253,7 @@ export class TaskInstanceService {
 
   /**
    * Patch Task Instance
-   * Update the state of a task instance.
+   * Update a task instance.
    * @param data The data for the request.
    * @param data.dagId
    * @param data.dagRunId
@@ -2486,6 +2490,88 @@ export class TaskInstanceService {
     });
   }
 
+  /**
+   * Patch Task Instance Dry Run
+   * Update a task instance dry_run mode.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.dagRunId
+   * @param data.taskId
+   * @param data.mapIndex
+   * @param data.requestBody
+   * @param data.updateMask
+   * @returns TaskInstanceCollectionResponse Successful Response
+   * @throws ApiError
+   */
+  public static patchTaskInstanceDryRun(
+    data: PatchTaskInstanceDryRunData,
+  ): CancelablePromise<PatchTaskInstanceDryRunResponse> {
+    return __request(OpenAPI, {
+      method: "PATCH",
+      url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run",
+      path: {
+        dag_id: data.dagId,
+        dag_run_id: data.dagRunId,
+        task_id: data.taskId,
+        map_index: data.mapIndex,
+      },
+      query: {
+        update_mask: data.updateMask,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        400: "Bad Request",
+        401: "Unauthorized",
+        403: "Forbidden",
+        404: "Not Found",
+        409: "Conflict",
+        422: "Validation Error",
+      },
+    });
+  }
+
+  /**
+   * Patch Task Instance Dry Run
+   * Update a task instance dry_run mode.
+   * @param data The data for the request.
+   * @param data.dagId
+   * @param data.dagRunId
+   * @param data.taskId
+   * @param data.requestBody
+   * @param data.mapIndex
+   * @param data.updateMask
+   * @returns TaskInstanceCollectionResponse Successful Response
+   * @throws ApiError
+   */
+  public static patchTaskInstanceDryRun1(
+    data: PatchTaskInstanceDryRun1Data,
+  ): CancelablePromise<PatchTaskInstanceDryRun1Response> {
+    return __request(OpenAPI, {
+      method: "PATCH",
+      url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run",
+      path: {
+        dag_id: data.dagId,
+        dag_run_id: data.dagRunId,
+        task_id: data.taskId,
+      },
+      query: {
+        map_index: data.mapIndex,
+        update_mask: data.updateMask,
+      },
+      body: data.requestBody,
+      mediaType: "application/json",
+      errors: {
+        400: "Bad Request",
+        401: "Unauthorized",
+        403: "Forbidden",
+        404: "Not Found",
+        409: "Conflict",
+        422: "Validation Error",
+      },
+    });
+  }
+
   /**
    * Get Log
    * Get logs for a specific task instance.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index cf1cacdf115..6ff1a083ec6 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -1024,8 +1024,7 @@ export type type =
  * Request body for Clear Task Instances endpoint.
  */
 export type PatchTaskInstanceBody = {
-  dry_run?: boolean;
-  new_state?: string | null;
+  new_state?: TaskInstanceState | null;
   note?: string | null;
   include_upstream?: boolean;
   include_downstream?: boolean;
@@ -2159,6 +2158,28 @@ export type PostClearTaskInstancesData = {
 
 export type PostClearTaskInstancesResponse = TaskInstanceCollectionResponse;
 
+export type PatchTaskInstanceDryRunData = {
+  dagId: string;
+  dagRunId: string;
+  mapIndex: number;
+  requestBody: PatchTaskInstanceBody;
+  taskId: string;
+  updateMask?: Array<string> | null;
+};
+
+export type PatchTaskInstanceDryRunResponse = TaskInstanceCollectionResponse;
+
+export type PatchTaskInstanceDryRun1Data = {
+  dagId: string;
+  dagRunId: string;
+  mapIndex?: number;
+  requestBody: PatchTaskInstanceBody;
+  taskId: string;
+  updateMask?: Array<string> | null;
+};
+
+export type PatchTaskInstanceDryRun1Response = TaskInstanceCollectionResponse;
+
 export type GetLogData = {
   accept?: "application/json" | "text/plain" | "*/*";
   dagId: string;
@@ -4261,6 +4282,76 @@ export type $OpenApiTs = {
       };
     };
   };
+  "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}/dry_run": {
+    patch: {
+      req: PatchTaskInstanceDryRunData;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: TaskInstanceCollectionResponse;
+        /**
+         * Bad Request
+         */
+        400: HTTPExceptionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Not Found
+         */
+        404: HTTPExceptionResponse;
+        /**
+         * Conflict
+         */
+        409: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
+  };
+  "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/dry_run": {
+    patch: {
+      req: PatchTaskInstanceDryRun1Data;
+      res: {
+        /**
+         * Successful Response
+         */
+        200: TaskInstanceCollectionResponse;
+        /**
+         * Bad Request
+         */
+        400: HTTPExceptionResponse;
+        /**
+         * Unauthorized
+         */
+        401: HTTPExceptionResponse;
+        /**
+         * Forbidden
+         */
+        403: HTTPExceptionResponse;
+        /**
+         * Not Found
+         */
+        404: HTTPExceptionResponse;
+        /**
+         * Conflict
+         */
+        409: HTTPExceptionResponse;
+        /**
+         * Validation Error
+         */
+        422: HTTPValidationError;
+      };
+    };
+  };
   "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/logs/{try_number}": {
     get: {
       req: GetLogData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index bc2cef7c03e..ebf718340d3 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -2693,7 +2693,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dry_run": False,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2743,68 +2742,12 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
             task_id=self.TASK_ID,
         )
 
-    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
-    def test_should_not_call_mocked_api_for_dry_run(self, mock_set_task_instance_state, test_client, session):
-        self.create_task_instances(session)
-
-        mock_set_task_instance_state.return_value = session.scalars(
-            select(TaskInstance).where(
-                TaskInstance.dag_id == self.DAG_ID,
-                TaskInstance.task_id == self.TASK_ID,
-                TaskInstance.run_id == self.RUN_ID,
-                TaskInstance.map_index == -1,
-            )
-        ).one_or_none()
-
-        response = test_client.patch(
-            self.ENDPOINT_URL,
-            json={
-                "dry_run": True,
-                "new_state": self.NEW_STATE,
-            },
-        )
-        assert response.status_code == 200
-        assert response.json() == {
-            "dag_id": self.DAG_ID,
-            "dag_run_id": self.RUN_ID,
-            "logical_date": "2020-01-01T00:00:00Z",
-            "task_id": self.TASK_ID,
-            "duration": 10000.0,
-            "end_date": "2020-01-03T00:00:00Z",
-            "executor": None,
-            "executor_config": "{}",
-            "hostname": "",
-            "id": mock.ANY,
-            "map_index": -1,
-            "max_tries": 0,
-            "note": "placeholder-note",
-            "operator": "PythonOperator",
-            "pid": 100,
-            "pool": "default_pool",
-            "pool_slots": 1,
-            "priority_weight": 9,
-            "queue": "default_queue",
-            "queued_when": None,
-            "start_date": "2020-01-02T00:00:00Z",
-            "state": "running",
-            "task_display_name": self.TASK_ID,
-            "try_number": 0,
-            "unixname": getuser(),
-            "rendered_fields": {},
-            "rendered_map_index": None,
-            "trigger": None,
-            "triggerer_job": None,
-        }
-
-        mock_set_task_instance_state.assert_not_called()
-
     def test_should_update_task_instance_state(self, test_client, session):
         self.create_task_instances(session)
 
         test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dry_run": False,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2813,20 +2756,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         assert response2.status_code == 200
         assert response2.json()["state"] == self.NEW_STATE
 
-    def test_should_update_task_instance_state_default_dry_run_to_true(self, test_client, session):
-        self.create_task_instances(session)
-
-        test_client.patch(
-            self.ENDPOINT_URL,
-            json={
-                "new_state": self.NEW_STATE,
-            },
-        )
-
-        response2 = test_client.get(self.ENDPOINT_URL)
-        assert response2.status_code == 200
-        assert response2.json()["state"] == "running"  # no change in state
-
     def test_should_update_mapped_task_instance_state(self, test_client, session):
         map_index = 1
         tis = self.create_task_instances(session)
@@ -2838,7 +2767,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             f"{self.ENDPOINT_URL}/{map_index}",
             json={
-                "dry_run": False,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2858,7 +2786,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
                 ),
                 404,
                 {
-                    "dry_run": True,
                     "new_state": "failed",
                 },
             ]
@@ -2877,7 +2804,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dryrun": True,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2887,7 +2813,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             
"/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context",
             json={
-                "dry_run": False,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2898,7 +2823,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task",
             json={
-                "dry_run": False,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2911,7 +2835,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dry_run": True,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2921,7 +2844,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dry_run": True,
                 "new_state": self.NEW_STATE,
             },
         )
@@ -2932,14 +2854,12 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         [
             (
                 {
-                    "dry_run": True,
                     "new_state": "failede",
                 },
                 f"'failede' is not one of ['{State.SUCCESS}', 
'{State.FAILED}', '{State.SKIPPED}']",
             ),
             (
                 {
-                    "dry_run": True,
                     "new_state": "queued",
                 },
                 f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', 
'{State.SKIPPED}']",
@@ -3048,7 +2968,6 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
             self.ENDPOINT_URL,
             params={"update_mask": "new_state"},
             json={
-                "dry_run": False,
                 "new_state": new_state,
             },
         )
@@ -3222,7 +3141,367 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         response = test_client.patch(
             self.ENDPOINT_URL,
             json={
-                "dry_run": False,
+                "new_state": "success",
+            },
+        )
+        assert response.status_code == 409
+        assert "Task id print_the_context is already in success state" in 
response.text
+
+
+class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
+    ENDPOINT_URL = (
+        
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
+    )
+    NEW_STATE = "failed"
+    DAG_ID = "example_python_operator"
+    TASK_ID = "print_the_context"
+    RUN_ID = "TEST_DAG_RUN_ID"
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_call_mocked_api(self, mock_set_ti_state, test_client, 
session):
+        self.create_task_instances(session)
+
+        mock_set_ti_state.return_value = [
+            session.scalars(
+                select(TaskInstance).where(
+                    TaskInstance.dag_id == self.DAG_ID,
+                    TaskInstance.task_id == self.TASK_ID,
+                    TaskInstance.run_id == self.RUN_ID,
+                    TaskInstance.map_index == -1,
+                )
+            ).one_or_none()
+        ]
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+        assert response.json() == {
+            "task_instances": [
+                {
+                    "dag_id": self.DAG_ID,
+                    "dag_run_id": self.RUN_ID,
+                    "logical_date": "2020-01-01T00:00:00Z",
+                    "task_id": self.TASK_ID,
+                    "duration": 10000.0,
+                    "end_date": "2020-01-03T00:00:00Z",
+                    "executor": None,
+                    "executor_config": "{}",
+                    "hostname": "",
+                    "id": mock.ANY,
+                    "map_index": -1,
+                    "max_tries": 0,
+                    "note": "placeholder-note",
+                    "operator": "PythonOperator",
+                    "pid": 100,
+                    "pool": "default_pool",
+                    "pool_slots": 1,
+                    "priority_weight": 9,
+                    "queue": "default_queue",
+                    "queued_when": None,
+                    "start_date": "2020-01-02T00:00:00Z",
+                    "state": "running",
+                    "task_display_name": self.TASK_ID,
+                    "try_number": 0,
+                    "unixname": getuser(),
+                    "rendered_fields": {},
+                    "rendered_map_index": None,
+                    "trigger": None,
+                    "triggerer_job": None,
+                }
+            ],
+            "total_entries": 1,
+        }
+
+        mock_set_ti_state.assert_called_once_with(
+            commit=False,
+            downstream=False,
+            upstream=False,
+            future=False,
+            map_indexes=[-1],
+            past=False,
+            run_id=self.RUN_ID,
+            session=mock.ANY,
+            state=self.NEW_STATE,
+            task_id=self.TASK_ID,
+        )
+
+    @pytest.mark.parametrize(
+        "payload",
+        [
+            {
+                "new_state": "success",
+            },
+            {
+                "note": "something",
+            },
+            {
+                "new_state": "success",
+                "note": "something",
+            },
+        ],
+    )
+    def test_should_not_update(self, test_client, session, payload):
+        self.create_task_instances(session)
+
+        task_before = test_client.get(self.ENDPOINT_URL).json()
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json=payload,
+        )
+
+        assert response.status_code == 200
+        assert [ti["task_id"] for ti in response.json()["task_instances"]] == 
["print_the_context"]
+
+        task_after = test_client.get(self.ENDPOINT_URL).json()
+
+        assert task_before == task_after
+
+    def test_should_not_update_mapped_task_instance(self, test_client, 
session):
+        map_index = 1
+        tis = self.create_task_instances(session)
+        ti = TaskInstance(task=tis[0].task, run_id=tis[0].run_id, 
map_index=map_index)
+        ti.rendered_task_instance_fields = RTIF(ti, render_templates=False)
+        session.add(ti)
+        session.commit()
+
+        task_before = 
test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json()
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/{map_index}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+
+        assert response.status_code == 200
+        assert [ti["task_id"] for ti in response.json()["task_instances"]] == 
["print_the_context"]
+
+        task_after = test_client.get(f"{self.ENDPOINT_URL}/{map_index}").json()
+
+        assert task_before == task_after
+
+    @pytest.mark.parametrize(
+        "error, code, payload",
+        [
+            [
+                (
+                    "Task Instance not found for 
dag_id=example_python_operator"
+                    ", run_id=TEST_DAG_RUN_ID, task_id=print_the_context"
+                ),
+                404,
+                {
+                    "new_state": "failed",
+                },
+            ]
+        ],
+    )
+    def test_should_handle_errors(self, error, code, payload, test_client, 
session):
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json=payload,
+        )
+        assert response.status_code == code
+        assert response.json()["detail"] == error
+
+    def test_should_200_for_unknown_fields(self, test_client, session):
+        self.create_task_instances(session)
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 200
+
+    def test_should_raise_404_for_non_existent_dag(self, test_client):
+        response = test_client.patch(
+            
"/public/dags/non-existent-dag/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 404
+        assert response.json() == {"detail": "DAG non-existent-dag not found"}
+
+    def test_should_raise_404_for_non_existent_task_in_dag(self, test_client):
+        response = test_client.patch(
+            
"/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/non_existent_task/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 404
+        assert response.json() == {
+            "detail": "Task 'non_existent_task' not found in DAG 
'example_python_operator'"
+        }
+
+    def test_should_raise_404_not_found_dag(self, test_client):
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 404
+
+    def test_should_raise_404_not_found_task(self, test_client):
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
+                "new_state": self.NEW_STATE,
+            },
+        )
+        assert response.status_code == 404
+
+    @pytest.mark.parametrize(
+        "payload, expected",
+        [
+            (
+                {
+                    "new_state": "failede",
+                },
+                f"'failede' is not one of ['{State.SUCCESS}', 
'{State.FAILED}', '{State.SKIPPED}']",
+            ),
+            (
+                {
+                    "new_state": "queued",
+                },
+                f"'queued' is not one of ['{State.SUCCESS}', '{State.FAILED}', 
'{State.SKIPPED}']",
+            ),
+        ],
+    )
+    def test_should_raise_422_for_invalid_task_instance_state(self, payload, 
expected, test_client, session):
+        self.create_task_instances(session)
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json=payload,
+        )
+        assert response.status_code == 422
+        assert response.json() == {
+            "detail": [
+                {
+                    "type": "value_error",
+                    "loc": ["body", "new_state"],
+                    "msg": f"Value error, {expected}",
+                    "input": payload["new_state"],
+                    "ctx": {"error": {}},
+                }
+            ]
+        }
+
+    @pytest.mark.parametrize(
+        "new_state,expected_status_code,expected_json,set_ti_state_call_count",
+        [
+            (
+                "failed",
+                200,
+                {
+                    "task_instances": [
+                        {
+                            "dag_id": "example_python_operator",
+                            "dag_run_id": "TEST_DAG_RUN_ID",
+                            "logical_date": "2020-01-01T00:00:00Z",
+                            "task_id": "print_the_context",
+                            "duration": 10000.0,
+                            "end_date": "2020-01-03T00:00:00Z",
+                            "executor": None,
+                            "executor_config": "{}",
+                            "hostname": "",
+                            "id": mock.ANY,
+                            "map_index": -1,
+                            "max_tries": 0,
+                            "note": "placeholder-note",
+                            "operator": "PythonOperator",
+                            "pid": 100,
+                            "pool": "default_pool",
+                            "pool_slots": 1,
+                            "priority_weight": 9,
+                            "queue": "default_queue",
+                            "queued_when": None,
+                            "start_date": "2020-01-02T00:00:00Z",
+                            "state": "running",
+                            "task_display_name": "print_the_context",
+                            "try_number": 0,
+                            "unixname": getuser(),
+                            "rendered_fields": {},
+                            "rendered_map_index": None,
+                            "trigger": None,
+                            "triggerer_job": None,
+                        }
+                    ],
+                    "total_entries": 1,
+                },
+                1,
+            ),
+            (
+                None,
+                422,
+                {
+                    "detail": [
+                        {
+                            "type": "value_error",
+                            "loc": ["body", "new_state"],
+                            "msg": "Value error, 'new_state' should not be 
empty",
+                            "input": None,
+                            "ctx": {"error": {}},
+                        }
+                    ]
+                },
+                0,
+            ),
+        ],
+    )
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_update_mask_should_call_mocked_api(
+        self,
+        mock_set_ti_state,
+        test_client,
+        session,
+        new_state,
+        expected_status_code,
+        expected_json,
+        set_ti_state_call_count,
+    ):
+        self.create_task_instances(session)
+
+        mock_set_ti_state.return_value = [
+            session.scalars(
+                select(TaskInstance).where(
+                    TaskInstance.dag_id == self.DAG_ID,
+                    TaskInstance.task_id == self.TASK_ID,
+                    TaskInstance.run_id == self.RUN_ID,
+                    TaskInstance.map_index == -1,
+                )
+            ).one_or_none()
+        ]
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            params={"update_mask": "new_state"},
+            json={
+                "new_state": new_state,
+            },
+        )
+        assert response.status_code == expected_status_code
+        assert response.json() == expected_json
+        assert mock_set_ti_state.call_count == set_ti_state_call_count
+
+    @mock.patch("airflow.models.dag.DAG.set_task_instance_state")
+    def test_should_raise_409_for_updating_same_task_instance_state(
+        self, mock_set_ti_state, test_client, session
+    ):
+        self.create_task_instances(session)
+
+        mock_set_ti_state.return_value = None
+
+        response = test_client.patch(
+            f"{self.ENDPOINT_URL}/dry_run",
+            json={
                 "new_state": "success",
             },
         )


Reply via email to