This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new adb4d2adc2 AIP-84 Get Mapped Task Instance (#43548)
adb4d2adc2 is described below
commit adb4d2adc2aeff90b2f770a9a81c9c5e2b72aff4
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Fri Nov 1 04:02:42 2024 +0800
AIP-84 Get Mapped Task Instance (#43548)
---
.../endpoints/task_instance_endpoint.py | 1 +
.../api_fastapi/core_api/openapi/v1-generated.yaml | 63 +++++++++++++++++++++
.../core_api/routes/public/task_instances.py | 29 +++++++++-
airflow/ui/openapi-gen/queries/common.ts | 26 +++++++++
airflow/ui/openapi-gen/queries/prefetch.ts | 40 +++++++++++++
airflow/ui/openapi-gen/queries/queries.ts | 44 +++++++++++++++
airflow/ui/openapi-gen/queries/suspense.ts | 44 +++++++++++++++
airflow/ui/openapi-gen/requests/services.gen.ts | 34 +++++++++++
airflow/ui/openapi-gen/requests/types.gen.ts | 36 ++++++++++++
.../core_api/routes/public/test_task_instances.py | 66 ++++++++++++++++++++++
10 files changed, 382 insertions(+), 1 deletion(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index 2ba236b065..b862ed1469 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -104,6 +104,7 @@ def get_task_instance(
return task_instance_schema.dump(task_instance)
+@mark_fastapi_migration_done
@security.requires_access_dag("GET", DagAccessEntity.TASK_INSTANCE)
@provide_session
def get_mapped_task_instance(
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 4295211546..b7e7f62693 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -1491,6 +1491,69 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}:
+ get:
+ tags:
+ - Task Instance
+ summary: Get Mapped Task Instance
+ description: Get task instance.
+ operationId: get_mapped_task_instance
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Id
+ - name: dag_run_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Run Id
+ - name: task_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Task Id
+ - name: map_index
+ in: path
+ required: true
+ schema:
+ type: integer
+ title: Map Index
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/variables/{variable_key}:
delete:
tags:
diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index 7e0c6fa894..c9458e843a 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -44,7 +44,6 @@ async def get_task_instance(
.join(TI.dag_run)
.options(joinedload(TI.rendered_task_instance_fields))
)
-
task_instance = session.scalar(query)
if task_instance is None:
@@ -56,3 +55,31 @@ async def get_task_instance(
raise HTTPException(404, "Task instance is mapped, add the map_index
value to the URL")
return TaskInstanceResponse.model_validate(task_instance,
from_attributes=True)
+
+
@task_instances_router.get(
    "/{task_id}/{map_index}",
    responses=create_openapi_http_exception_doc([401, 403, 404]),
)
async def get_mapped_task_instance(
    dag_id: str,
    dag_run_id: str,
    task_id: str,
    map_index: int,
    session: Annotated[Session, Depends(get_session)],
) -> TaskInstanceResponse:
    """Get task instance."""
    # NOTE(review): this docstring appears verbatim as the operation `description`
    # in v1-generated.yaml, so editing it requires regenerating the spec/clients.
    lookup = (
        TI.dag_id == dag_id,
        TI.run_id == dag_run_id,
        TI.task_id == task_id,
        TI.map_index == map_index,
    )
    # Join the DAG run and eagerly load rendered fields so the response model
    # can read them without extra lazy loads.
    stmt = select(TI).where(*lookup).join(TI.dag_run).options(joinedload(TI.rendered_task_instance_fields))
    task_instance = session.scalar(stmt)

    if task_instance is not None:
        return TaskInstanceResponse.model_validate(task_instance, from_attributes=True)

    raise HTTPException(
        404,
        f"The Mapped Task Instance with dag_id: `{dag_id}`, run_id: `{dag_run_id}`, task_id: `{task_id}`, and map_index: `{map_index}` was not found",
    )
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index 875e2b87f3..07edb67a99 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -423,6 +423,32 @@ export const UseTaskInstanceServiceGetTaskInstanceKeyFn = (
useTaskInstanceServiceGetTaskInstanceKey,
...(queryKey ?? [{ dagId, dagRunId, taskId }]),
];
/**
 * Value produced once the `TaskInstanceService.getMappedTaskInstance` promise settles.
 */
export type TaskInstanceServiceGetMappedTaskInstanceDefaultResponse = Awaited<
  ReturnType<typeof TaskInstanceService.getMappedTaskInstance>
>;

/**
 * `useQuery` result for the "Get Mapped Task Instance" query.
 * `TData` defaults to the settled service response; `TError` defaults to `unknown`.
 */
export type TaskInstanceServiceGetMappedTaskInstanceQueryResult<
  TData = TaskInstanceServiceGetMappedTaskInstanceDefaultResponse,
  TError = unknown,
> = UseQueryResult<TData, TError>;
/** Leading segment of every query key for the "Get Mapped Task Instance" query. */
export const useTaskInstanceServiceGetMappedTaskInstanceKey =
  "TaskInstanceServiceGetMappedTaskInstance";

/**
 * Build the React Query key for the "Get Mapped Task Instance" query.
 *
 * The key is the constant prefix followed either by the caller-supplied
 * `queryKey` segments or, when none are given, by a single object holding
 * the request parameters. Note that a supplied `queryKey` replaces the
 * parameter object entirely.
 *
 * @param params The request parameters (dagId, dagRunId, mapIndex, taskId).
 * @param queryKey Optional segments that override the default parameter segment.
 * @returns The query key array.
 */
export const UseTaskInstanceServiceGetMappedTaskInstanceKeyFn = (
  params: {
    dagId: string;
    dagRunId: string;
    mapIndex: number;
    taskId: string;
  },
  queryKey?: Array<unknown>,
) => {
  const { dagId, dagRunId, mapIndex, taskId } = params;
  const keySuffix = queryKey ?? [{ dagId, dagRunId, mapIndex, taskId }];
  return [useTaskInstanceServiceGetMappedTaskInstanceKey, ...keySuffix];
};
export type VariableServiceGetVariableDefaultResponse = Awaited<
ReturnType<typeof VariableService.getVariable>
>;
diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts
b/airflow/ui/openapi-gen/queries/prefetch.ts
index 795c8770b8..db61369e19 100644
--- a/airflow/ui/openapi-gen/queries/prefetch.ts
+++ b/airflow/ui/openapi-gen/queries/prefetch.ts
@@ -530,6 +530,46 @@ export const prefetchUseTaskInstanceServiceGetTaskInstance
= (
queryFn: () =>
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }),
});
/**
 * Prefetch "Get Mapped Task Instance"
 * (`GET /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}`)
 * into the given query client's cache.
 * @param queryClient Client whose `prefetchQuery` is invoked.
 * @param data The data for the request.
 * @param data.dagId
 * @param data.dagRunId
 * @param data.taskId
 * @param data.mapIndex
 * @returns The promise returned by `queryClient.prefetchQuery`.
 */
export const prefetchUseTaskInstanceServiceGetMappedTaskInstance = (
  queryClient: QueryClient,
  {
    dagId,
    dagRunId,
    mapIndex,
    taskId,
  }: {
    dagId: string;
    dagRunId: string;
    mapIndex: number;
    taskId: string;
  },
) => {
  // The same parameters drive both the cache key and the request.
  const requestData = { dagId, dagRunId, mapIndex, taskId };
  return queryClient.prefetchQuery({
    queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn(requestData),
    queryFn: () => TaskInstanceService.getMappedTaskInstance(requestData),
  });
};
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index afb3fddaef..7820656799 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -680,6 +680,50 @@ export const useTaskInstanceServiceGetTaskInstance = <
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as
TData,
...options,
});
/**
 * React Query hook for "Get Mapped Task Instance"
 * (`GET /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}`).
 * @param data The data for the request.
 * @param data.dagId
 * @param data.dagRunId
 * @param data.taskId
 * @param data.mapIndex
 * @param queryKey Optional key segments; when given they replace the default parameter segment.
 * @param options Extra `useQuery` options (everything except `queryKey`/`queryFn`), spread last.
 * @returns TaskInstanceResponse Successful Response
 * @throws ApiError
 */
export const useTaskInstanceServiceGetMappedTaskInstance = <
  TData = Common.TaskInstanceServiceGetMappedTaskInstanceDefaultResponse,
  TError = unknown,
  TQueryKey extends Array<unknown> = unknown[],
>(
  {
    dagId,
    dagRunId,
    mapIndex,
    taskId,
  }: {
    dagId: string;
    dagRunId: string;
    mapIndex: number;
    taskId: string;
  },
  queryKey?: TQueryKey,
  options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
) => {
  const requestData = { dagId, dagRunId, mapIndex, taskId };
  const fetchMappedTaskInstance = () =>
    TaskInstanceService.getMappedTaskInstance(requestData) as TData;
  return useQuery<TData, TError>({
    queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn(
      requestData,
      queryKey,
    ),
    queryFn: fetchMappedTaskInstance,
    // Spread last so caller-provided options take precedence, as before.
    ...options,
  });
};
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/queries/suspense.ts
b/airflow/ui/openapi-gen/queries/suspense.ts
index ab8dfbabcc..2cb0841d71 100644
--- a/airflow/ui/openapi-gen/queries/suspense.ts
+++ b/airflow/ui/openapi-gen/queries/suspense.ts
@@ -668,6 +668,50 @@ export const useTaskInstanceServiceGetTaskInstanceSuspense
= <
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as
TData,
...options,
});
/**
 * Suspense-enabled React Query hook for "Get Mapped Task Instance"
 * (`GET /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}`).
 * @param data The data for the request.
 * @param data.dagId
 * @param data.dagRunId
 * @param data.taskId
 * @param data.mapIndex
 * @param queryKey Optional key segments; when given they replace the default parameter segment.
 * @param options Extra query options (everything except `queryKey`/`queryFn`), spread last.
 * @returns TaskInstanceResponse Successful Response
 * @throws ApiError
 */
export const useTaskInstanceServiceGetMappedTaskInstanceSuspense = <
  TData = Common.TaskInstanceServiceGetMappedTaskInstanceDefaultResponse,
  TError = unknown,
  TQueryKey extends Array<unknown> = unknown[],
>(
  {
    dagId,
    dagRunId,
    mapIndex,
    taskId,
  }: {
    dagId: string;
    dagRunId: string;
    mapIndex: number;
    taskId: string;
  },
  queryKey?: TQueryKey,
  options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
) => {
  const requestData = { dagId, dagRunId, mapIndex, taskId };
  const fetchMappedTaskInstance = () =>
    TaskInstanceService.getMappedTaskInstance(requestData) as TData;
  return useSuspenseQuery<TData, TError>({
    queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstanceKeyFn(
      requestData,
      queryKey,
    ),
    queryFn: fetchMappedTaskInstance,
    // Spread last so caller-provided options take precedence, as before.
    ...options,
  });
};
/**
* Get Variable
* Get a variable entry.
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index fd38b2ec31..486e04b056 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -56,6 +56,8 @@ import type {
GetProvidersResponse,
GetTaskInstanceData,
GetTaskInstanceResponse,
+ GetMappedTaskInstanceData,
+ GetMappedTaskInstanceResponse,
DeleteVariableData,
DeleteVariableResponse,
GetVariableData,
@@ -874,6 +876,38 @@ export class TaskInstanceService {
},
});
}
+
/**
 * Get Mapped Task Instance
 * Fetch a single task instance addressed by DAG, DAG run, task and map index.
 * @param data The data for the request.
 * @param data.dagId
 * @param data.dagRunId
 * @param data.taskId
 * @param data.mapIndex
 * @returns TaskInstanceResponse Successful Response
 * @throws ApiError
 */
public static getMappedTaskInstance(
  data: GetMappedTaskInstanceData,
): CancelablePromise<GetMappedTaskInstanceResponse> {
  // Keys match the `{...}` placeholders in the URL template below
  // (presumably substituted by `__request`).
  const pathParams = {
    dag_id: data.dagId,
    dag_run_id: data.dagRunId,
    task_id: data.taskId,
    map_index: data.mapIndex,
  };
  // Human-readable descriptions for the documented error status codes.
  const errorDescriptions = {
    401: "Unauthorized",
    403: "Forbidden",
    404: "Not Found",
    422: "Validation Error",
  };
  return __request(OpenAPI, {
    method: "GET",
    url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}",
    path: pathParams,
    errors: errorDescriptions,
  });
}
}
export class VariableService {
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts
b/airflow/ui/openapi-gen/requests/types.gen.ts
index 064163f417..0580694ba7 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -827,6 +827,15 @@ export type GetTaskInstanceData = {
export type GetTaskInstanceResponse = TaskInstanceResponse;
/**
 * Request parameters for `TaskInstanceService.getMappedTaskInstance`,
 * listed in the order they appear in the URL path.
 */
export type GetMappedTaskInstanceData = {
  dagId: string;
  dagRunId: string;
  taskId: string;
  mapIndex: number;
};

/** Payload returned with HTTP 200 by the mapped task instance endpoint. */
export type GetMappedTaskInstanceResponse = TaskInstanceResponse;
+
export type DeleteVariableData = {
variableKey: string;
};
@@ -1528,6 +1537,33 @@ export type $OpenApiTs = {
};
};
};
+
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}":
{
+ get: {
+ req: GetMappedTaskInstanceData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: TaskInstanceResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
"/public/variables/{variable_key}": {
delete: {
req: DeleteVariableData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index 85b4639d6c..fa9cc0b161 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -394,3 +394,69 @@ class TestGetTaskInstance(TestTaskInstanceEndpoint):
)
assert response.status_code == 404
assert response.json() == {"detail": "Task instance is mapped, add the
map_index value to the URL"}
+
+
class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
    """Tests for GET /public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}."""

    def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, session):
        """Verify we don't duplicate rows through join to RTIF"""
        source_ti = self.create_task_instances(session)[0]
        copied_attrs = ("duration", "end_date", "pid", "start_date", "state", "queue", "note")

        # Clone the first TI as map_index 1 and 2, each with its own RTIF row.
        for idx in (1, 2):
            mapped_ti = TaskInstance(task=source_ti.task, run_id=source_ti.run_id, map_index=idx)
            mapped_ti.rendered_task_instance_fields = RTIF(mapped_ti, render_templates=False)
            for attr in copied_attrs:
                setattr(mapped_ti, attr, getattr(source_ti, attr))
            session.add(mapped_ti)
        session.commit()

        base_url = "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances"
        # in each loop, we should get the right mapped TI back
        for map_index in (1, 2):
            response = test_client.get(f"{base_url}/print_the_context/{map_index}")
            assert response.status_code == 200

            assert response.json() == {
                "dag_id": "example_python_operator",
                "duration": 10000.0,
                "end_date": "2020-01-03T00:00:00Z",
                "logical_date": "2020-01-01T00:00:00Z",
                "executor": None,
                "executor_config": "{}",
                "hostname": "",
                "map_index": map_index,
                "max_tries": 0,
                "note": "placeholder-note",
                "operator": "PythonOperator",
                "pid": 100,
                "pool": "default_pool",
                "pool_slots": 1,
                "priority_weight": 9,
                "queue": "default_queue",
                "queued_when": None,
                "start_date": "2020-01-02T00:00:00Z",
                "state": "running",
                "task_id": "print_the_context",
                "task_display_name": "print_the_context",
                "try_number": 0,
                "unixname": getuser(),
                "dag_run_id": "TEST_DAG_RUN_ID",
                "rendered_fields": {"op_args": [], "op_kwargs": {}, "templates_dict": None},
                "rendered_map_index": None,
                "trigger": None,
                "triggerer_job": None,
            }

    def test_should_respond_404_wrong_map_index(self, test_client, session):
        self.create_task_instances(session)

        response = test_client.get(
            "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/10"
        )
        assert response.status_code == 404

        expected_detail = (
            "The Mapped Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, "
            "task_id: `print_the_context`, and map_index: `10` was not found"
        )
        assert response.json() == {"detail": expected_detail}