This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 31b10011f6 AIP-84 List Mapped Task Instances (#43642)
31b10011f6 is described below
commit 31b10011f64520ae3f393816f56636276d2a1d03
Author: Pierre Jeambrun <[email protected]>
AuthorDate: Thu Nov 7 22:53:00 2024 +0800
AIP-84 List Mapped Task Instances (#43642)
* AIP-84 List Mapped Task Instances
* Update following code review
---
.../endpoints/task_instance_endpoint.py | 1 +
airflow/api_fastapi/common/parameters.py | 151 ++++++++-
.../api_fastapi/core_api/openapi/v1-generated.yaml | 214 +++++++++++++
.../core_api/routes/public/task_instances.py | 110 ++++++-
airflow/ui/openapi-gen/queries/common.ts | 81 +++++
airflow/ui/openapi-gen/queries/prefetch.ts | 120 ++++++++
airflow/ui/openapi-gen/queries/queries.ts | 129 ++++++++
airflow/ui/openapi-gen/queries/suspense.ts | 129 ++++++++
airflow/ui/openapi-gen/requests/schemas.gen.ts | 20 ++
airflow/ui/openapi-gen/requests/services.gen.ts | 68 ++++
airflow/ui/openapi-gen/requests/types.gen.ts | 60 ++++
.../core_api/routes/public/test_task_instances.py | 342 +++++++++++++++++++--
12 files changed, 1393 insertions(+), 32 deletions(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index b862ed1469..1064991e9b 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -130,6 +130,7 @@ def get_mapped_task_instance(
return task_instance_schema.dump(task_instance)
+@mark_fastapi_migration_done
@format_parameters(
{
"execution_date_gte": format_datetime,
diff --git a/airflow/api_fastapi/common/parameters.py
b/airflow/api_fastapi/common/parameters.py
index a318aced83..18630e473d 100644
--- a/airflow/api_fastapi/common/parameters.py
+++ b/airflow/api_fastapi/common/parameters.py
@@ -19,22 +19,24 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from datetime import datetime
-from typing import TYPE_CHECKING, Any, Callable, Generic, List, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Generic, List, Optional,
TypeVar
from fastapi import Depends, HTTPException, Query
from pendulum.parsing.exceptions import ParserError
-from pydantic import AfterValidator
+from pydantic import AfterValidator, BaseModel
from sqlalchemy import Column, case, or_
from sqlalchemy.inspection import inspect
from typing_extensions import Annotated, Self
+from airflow.api_connexion.endpoints.task_instance_endpoint import
_convert_ti_states
from airflow.models import Base, Connection
from airflow.models.dag import DagModel, DagTag
from airflow.models.dagrun import DagRun
from airflow.models.dagwarning import DagWarning, DagWarningType
from airflow.models.errors import ParseImportError
+from airflow.models.taskinstance import TaskInstance
from airflow.utils import timezone
-from airflow.utils.state import DagRunState
+from airflow.utils.state import DagRunState, TaskInstanceState
if TYPE_CHECKING:
from sqlalchemy.sql import ColumnElement, Select
@@ -45,8 +47,8 @@ T = TypeVar("T")
class BaseParam(Generic[T], ABC):
"""Base class for filters."""
- def __init__(self, skip_none: bool = True) -> None:
- self.value: T | None = None
+ def __init__(self, value: T | None = None, skip_none: bool = True) -> None:
+ self.value = value
self.attribute: ColumnElement | None = None
self.skip_none = skip_none
@@ -128,7 +130,7 @@ class _SearchParam(BaseParam[str]):
"""Search on attribute."""
def __init__(self, attribute: ColumnElement, skip_none: bool = True) ->
None:
- super().__init__(skip_none)
+ super().__init__(skip_none=skip_none)
self.attribute: ColumnElement = attribute
def to_orm(self, select: Select) -> Select:
@@ -227,8 +229,8 @@ class SortParam(BaseParam[str]):
def depends(self, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use dynamic_depends, depends not
implemented.")
- def dynamic_depends(self) -> Callable:
- def inner(order_by: str = self.get_primary_key_string()) -> SortParam:
+ def dynamic_depends(self, default: str | None = None) -> Callable:
+ def inner(order_by: str = default or self.get_primary_key_string()) ->
SortParam:
return self.set_value(self.get_primary_key_string() if order_by ==
"" else order_by)
return inner
@@ -268,6 +270,75 @@ class _OwnersFilter(BaseParam[List[str]]):
return self.set_value(owners)
+class _TIStateFilter(BaseParam[List[Optional[TaskInstanceState]]]):
+ """Filter on task instance state."""
+
+ def to_orm(self, select: Select) -> Select:
+ if self.skip_none is False:
+ raise ValueError(f"Cannot set 'skip_none' to False on a
{type(self)}")
+
+ if not self.value:
+ return select
+
+ conditions = [TaskInstance.state == state for state in self.value]
+ return select.where(or_(*conditions))
+
+ def depends(self, state: list[str] = Query(default_factory=list)) ->
_TIStateFilter:
+ states = _convert_ti_states(state)
+ return self.set_value(states)
+
+
+class _TIPoolFilter(BaseParam[List[str]]):
+ """Filter on task instance pool."""
+
+ def to_orm(self, select: Select) -> Select:
+ if self.skip_none is False:
+ raise ValueError(f"Cannot set 'skip_none' to False on a
{type(self)}")
+
+ if not self.value:
+ return select
+
+ conditions = [TaskInstance.pool == pool for pool in self.value]
+ return select.where(or_(*conditions))
+
+ def depends(self, pool: list[str] = Query(default_factory=list)) ->
_TIPoolFilter:
+ return self.set_value(pool)
+
+
+class _TIQueueFilter(BaseParam[List[str]]):
+ """Filter on task instance queue."""
+
+ def to_orm(self, select: Select) -> Select:
+ if self.skip_none is False:
+ raise ValueError(f"Cannot set 'skip_none' to False on a
{type(self)}")
+
+ if not self.value:
+ return select
+
+ conditions = [TaskInstance.queue == queue for queue in self.value]
+ return select.where(or_(*conditions))
+
+ def depends(self, queue: list[str] = Query(default_factory=list)) ->
_TIQueueFilter:
+ return self.set_value(queue)
+
+
+class _TIExecutorFilter(BaseParam[List[str]]):
+ """Filter on task instance executor."""
+
+ def to_orm(self, select: Select) -> Select:
+ if self.skip_none is False:
+ raise ValueError(f"Cannot set 'skip_none' to False on a
{type(self)}")
+
+ if not self.value:
+ return select
+
+ conditions = [TaskInstance.executor == executor for executor in
self.value]
+ return select.where(or_(*conditions))
+
+ def depends(self, executor: list[str] = Query(default_factory=list)) ->
_TIExecutorFilter:
+ return self.set_value(executor)
+
+
class _LastDagRunStateFilter(BaseParam[DagRunState]):
"""Filter on the state of the latest DagRun."""
@@ -323,7 +394,7 @@ class _DagIdFilter(BaseParam[str]):
"""Filter on dag_id."""
def __init__(self, attribute: ColumnElement, skip_none: bool = True) ->
None:
- super().__init__(skip_none)
+ super().__init__(skip_none=skip_none)
self.attribute = attribute
def to_orm(self, select: Select) -> Select:
@@ -335,6 +406,63 @@ class _DagIdFilter(BaseParam[str]):
return self.set_value(dag_id)
+class Range(BaseModel, Generic[T]):
+ """Range with a lower and upper bound."""
+
+ lower_bound: T | None
+ upper_bound: T | None
+
+
+class RangeFilter(BaseParam[Range]):
+ """Filter on range in between the lower and upper bound."""
+
+ def __init__(self, value: Range | None, attribute: ColumnElement) -> None:
+ super().__init__(value)
+ self.attribute: ColumnElement = attribute
+
+ def to_orm(self, select: Select) -> Select:
+ if self.skip_none is False:
+ raise ValueError(f"Cannot set 'skip_none' to False on a
{type(self)}")
+
+ if self.value and self.value.lower_bound:
+ select = select.where(self.attribute >= self.value.lower_bound)
+ if self.value and self.value.upper_bound:
+ select = select.where(self.attribute <= self.value.upper_bound)
+ return select
+
+ def depends(self, *args: Any, **kwargs: Any) -> Self:
+ raise NotImplementedError("Use the `range_filter_factory` function to
create the dependency")
+
+
+def datetime_range_filter_factory(
+ filter_name: str, model: Base, attribute_name: str | None = None
+) -> Callable[[datetime | None, datetime | None], RangeFilter]:
+ def depends_datetime(
+ lower_bound: datetime | None = Query(alias=f"{filter_name}_gte",
default=None),
+ upper_bound: datetime | None = Query(alias=f"{filter_name}_lte",
default=None),
+ ) -> RangeFilter:
+ return RangeFilter(
+ Range(lower_bound=lower_bound, upper_bound=upper_bound),
+ getattr(model, attribute_name or filter_name),
+ )
+
+ return depends_datetime
+
+
+def float_range_filter_factory(
+ filter_name: str, model: Base
+) -> Callable[[float | None, float | None], RangeFilter]:
+ def depends_float(
+ lower_bound: float | None = Query(alias=f"{filter_name}_gte",
default=None),
+ upper_bound: float | None = Query(alias=f"{filter_name}_lte",
default=None),
+ ) -> RangeFilter:
+ return RangeFilter(
+ Range(lower_bound=lower_bound, upper_bound=upper_bound),
getattr(model, filter_name)
+ )
+
+ return depends_float
+
+
# Common Safe DateTime
DateTimeQuery = Annotated[str, AfterValidator(_safe_parse_datetime)]
@@ -363,3 +491,8 @@ QueryWarningTypeFilter = Annotated[_WarningTypeFilter,
Depends(_WarningTypeFilte
# DAGTags
QueryDagTagPatternSearch = Annotated[_DagTagNamePatternSearch,
Depends(_DagTagNamePatternSearch().depends)]
+# TI
+QueryTIStateFilter = Annotated[_TIStateFilter,
Depends(_TIStateFilter().depends)]
+QueryTIPoolFilter = Annotated[_TIPoolFilter, Depends(_TIPoolFilter().depends)]
+QueryTIQueueFilter = Annotated[_TIQueueFilter,
Depends(_TIQueueFilter().depends)]
+QueryTIExecutorFilter = Annotated[_TIExecutorFilter,
Depends(_TIExecutorFilter().depends)]
diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
index 1426a6b4cd..65fd1cb3c8 100644
--- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
+++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml
@@ -2209,6 +2209,204 @@ paths:
application/json:
schema:
$ref: '#/components/schemas/HTTPValidationError'
+
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped:
+ get:
+ tags:
+ - Task Instance
+ summary: Get Mapped Task Instances
+ description: Get list of mapped task instances.
+ operationId: get_mapped_task_instances
+ parameters:
+ - name: dag_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Id
+ - name: dag_run_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Dag Run Id
+ - name: task_id
+ in: path
+ required: true
+ schema:
+ type: string
+ title: Task Id
+ - name: logical_date_gte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Logical Date Gte
+ - name: logical_date_lte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Logical Date Lte
+ - name: start_date_gte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Start Date Gte
+ - name: start_date_lte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Start Date Lte
+ - name: end_date_gte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: End Date Gte
+ - name: end_date_lte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: End Date Lte
+ - name: updated_at_gte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Updated At Gte
+ - name: updated_at_lte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: string
+ format: date-time
+ - type: 'null'
+ title: Updated At Lte
+ - name: duration_gte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: number
+ - type: 'null'
+ title: Duration Gte
+ - name: duration_lte
+ in: query
+ required: false
+ schema:
+ anyOf:
+ - type: number
+ - type: 'null'
+ title: Duration Lte
+ - name: state
+ in: query
+ required: false
+ schema:
+ type: array
+ items:
+ type: string
+ title: State
+ - name: pool
+ in: query
+ required: false
+ schema:
+ type: array
+ items:
+ type: string
+ title: Pool
+ - name: queue
+ in: query
+ required: false
+ schema:
+ type: array
+ items:
+ type: string
+ title: Queue
+ - name: executor
+ in: query
+ required: false
+ schema:
+ type: array
+ items:
+ type: string
+ title: Executor
+ - name: limit
+ in: query
+ required: false
+ schema:
+ type: integer
+ default: 100
+ title: Limit
+ - name: offset
+ in: query
+ required: false
+ schema:
+ type: integer
+ default: 0
+ title: Offset
+ - name: order_by
+ in: query
+ required: false
+ schema:
+ type: string
+ default: map_index
+ title: Order By
+ responses:
+ '200':
+ description: Successful Response
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/TaskInstanceCollectionResponse'
+ '401':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Unauthorized
+ '403':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Forbidden
+ '404':
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPExceptionResponse'
+ description: Not Found
+ '422':
+ description: Validation Error
+ content:
+ application/json:
+ schema:
+ $ref: '#/components/schemas/HTTPValidationError'
/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}:
get:
tags:
@@ -4187,6 +4385,22 @@ components:
- latest_scheduler_heartbeat
title: SchedulerInfoSchema
description: Schema for Scheduler info.
+ TaskInstanceCollectionResponse:
+ properties:
+ task_instances:
+ items:
+ $ref: '#/components/schemas/TaskInstanceResponse'
+ type: array
+ title: Task Instances
+ total_entries:
+ type: integer
+ title: Total Entries
+ type: object
+ required:
+ - task_instances
+ - total_entries
+ title: TaskInstanceCollectionResponse
+ description: Task Instance Collection serializer for responses.
TaskInstanceResponse:
properties:
id:
diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py
b/airflow/api_fastapi/core_api/routes/public/task_instances.py
index df16c0bc45..389b02718e 100644
--- a/airflow/api_fastapi/core_api/routes/public/task_instances.py
+++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py
@@ -17,16 +17,33 @@
from __future__ import annotations
-from fastapi import Depends, HTTPException, status
+from fastapi import Depends, HTTPException, Request, status
from sqlalchemy.orm import Session, joinedload
from sqlalchemy.sql import select
from typing_extensions import Annotated
-from airflow.api_fastapi.common.db.common import get_session
+from airflow.api_fastapi.common.db.common import get_session, paginated_select
+from airflow.api_fastapi.common.parameters import (
+ QueryLimit,
+ QueryOffset,
+ QueryTIExecutorFilter,
+ QueryTIPoolFilter,
+ QueryTIQueueFilter,
+ QueryTIStateFilter,
+ RangeFilter,
+ SortParam,
+ datetime_range_filter_factory,
+ float_range_filter_factory,
+)
from airflow.api_fastapi.common.router import AirflowRouter
from airflow.api_fastapi.core_api.openapi.exceptions import
create_openapi_http_exception_doc
-from airflow.api_fastapi.core_api.serializers.task_instances import
TaskInstanceResponse
+from airflow.api_fastapi.core_api.serializers.task_instances import (
+ TaskInstanceCollectionResponse,
+ TaskInstanceResponse,
+)
+from airflow.exceptions import TaskNotFound
from airflow.models.taskinstance import TaskInstance as TI
+from airflow.utils.db import get_query_count
task_instances_router = AirflowRouter(
tags=["Task Instance"],
prefix="/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances"
@@ -64,6 +81,93 @@ async def get_task_instance(
return TaskInstanceResponse.model_validate(task_instance,
from_attributes=True)
+@task_instances_router.get(
+ "/{task_id}/listMapped",
+ responses=create_openapi_http_exception_doc(
+ [status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN,
status.HTTP_404_NOT_FOUND]
+ ),
+)
+async def get_mapped_task_instances(
+ dag_id: str,
+ dag_run_id: str,
+ task_id: str,
+ request: Request,
+ logical_date_range: Annotated[
+ RangeFilter, Depends(datetime_range_filter_factory("logical_date", TI,
"execution_date"))
+ ],
+ start_date_range: Annotated[RangeFilter,
Depends(datetime_range_filter_factory("start_date", TI))],
+ end_date_range: Annotated[RangeFilter,
Depends(datetime_range_filter_factory("end_date", TI))],
+ update_at_range: Annotated[RangeFilter,
Depends(datetime_range_filter_factory("updated_at", TI))],
+ duration_range: Annotated[RangeFilter,
Depends(float_range_filter_factory("duration", TI))],
+ state: QueryTIStateFilter,
+ pool: QueryTIPoolFilter,
+ queue: QueryTIQueueFilter,
+ executor: QueryTIExecutorFilter,
+ limit: QueryLimit,
+ offset: QueryOffset,
+ order_by: Annotated[
+ SortParam,
+ Depends(
+ SortParam(
+ ["id", "state", "duration", "start_date", "end_date",
"map_index", "rendered_map_index"],
+ TI,
+ ).dynamic_depends(default="map_index")
+ ),
+ ],
+ session: Annotated[Session, Depends(get_session)],
+) -> TaskInstanceCollectionResponse:
+ """Get list of mapped task instances."""
+ base_query = (
+ select(TI)
+ .where(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id ==
task_id, TI.map_index >= 0)
+ .join(TI.dag_run)
+ )
+ # 0 can mean a mapped TI that expanded to an empty list, so it is not an
automatic 404
+ unfiltered_total_count = get_query_count(base_query, session=session)
+ if unfiltered_total_count == 0:
+ dag = request.app.state.dag_bag.get_dag(dag_id)
+ if not dag:
+ error_message = f"DAG {dag_id} not found"
+ raise HTTPException(404, error_message)
+ try:
+ task = dag.get_task(task_id)
+ except TaskNotFound:
+ error_message = f"Task id {task_id} not found"
+ raise HTTPException(404, error_message)
+ if not task.get_needs_expansion():
+ error_message = f"Task id {task_id} is not mapped"
+ raise HTTPException(404, error_message)
+
+ task_instance_select, total_entries = paginated_select(
+ base_query,
+ [
+ logical_date_range,
+ start_date_range,
+ end_date_range,
+ update_at_range,
+ duration_range,
+ state,
+ pool,
+ queue,
+ executor,
+ ],
+ order_by,
+ offset,
+ limit,
+ session,
+ )
+
+ task_instances = session.scalars(task_instance_select).all()
+
+ return TaskInstanceCollectionResponse(
+ task_instances=[
+ TaskInstanceResponse.model_validate(task_instance,
from_attributes=True)
+ for task_instance in task_instances
+ ],
+ total_entries=total_entries,
+ )
+
+
@task_instances_router.get(
"/{task_id}/{map_index}",
responses=create_openapi_http_exception_doc(
diff --git a/airflow/ui/openapi-gen/queries/common.ts
b/airflow/ui/openapi-gen/queries/common.ts
index f1cb514682..4c1e6cd383 100644
--- a/airflow/ui/openapi-gen/queries/common.ts
+++ b/airflow/ui/openapi-gen/queries/common.ts
@@ -603,6 +603,87 @@ export const UseTaskInstanceServiceGetTaskInstanceKeyFn = (
useTaskInstanceServiceGetTaskInstanceKey,
...(queryKey ?? [{ dagId, dagRunId, taskId }]),
];
+export type TaskInstanceServiceGetMappedTaskInstancesDefaultResponse = Awaited<
+ ReturnType<typeof TaskInstanceService.getMappedTaskInstances>
+>;
+export type TaskInstanceServiceGetMappedTaskInstancesQueryResult<
+ TData = TaskInstanceServiceGetMappedTaskInstancesDefaultResponse,
+ TError = unknown,
+> = UseQueryResult<TData, TError>;
+export const useTaskInstanceServiceGetMappedTaskInstancesKey =
+ "TaskInstanceServiceGetMappedTaskInstances";
+export const UseTaskInstanceServiceGetMappedTaskInstancesKeyFn = (
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }: {
+ dagId: string;
+ dagRunId: string;
+ durationGte?: number;
+ durationLte?: number;
+ endDateGte?: string;
+ endDateLte?: string;
+ executor?: string[];
+ limit?: number;
+ logicalDateGte?: string;
+ logicalDateLte?: string;
+ offset?: number;
+ orderBy?: string;
+ pool?: string[];
+ queue?: string[];
+ startDateGte?: string;
+ startDateLte?: string;
+ state?: string[];
+ taskId: string;
+ updatedAtGte?: string;
+ updatedAtLte?: string;
+ },
+ queryKey?: Array<unknown>,
+) => [
+ useTaskInstanceServiceGetMappedTaskInstancesKey,
+ ...(queryKey ?? [
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ },
+ ]),
+];
export type TaskInstanceServiceGetMappedTaskInstanceDefaultResponse = Awaited<
ReturnType<typeof TaskInstanceService.getMappedTaskInstance>
>;
diff --git a/airflow/ui/openapi-gen/queries/prefetch.ts
b/airflow/ui/openapi-gen/queries/prefetch.ts
index 04443427ac..02b4137579 100644
--- a/airflow/ui/openapi-gen/queries/prefetch.ts
+++ b/airflow/ui/openapi-gen/queries/prefetch.ts
@@ -775,6 +775,126 @@ export const
prefetchUseTaskInstanceServiceGetTaskInstance = (
queryFn: () =>
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }),
});
+/**
+ * Get Mapped Task Instances
+ * Get list of mapped task instances.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.logicalDateGte
+ * @param data.logicalDateLte
+ * @param data.startDateGte
+ * @param data.startDateLte
+ * @param data.endDateGte
+ * @param data.endDateLte
+ * @param data.updatedAtGte
+ * @param data.updatedAtLte
+ * @param data.durationGte
+ * @param data.durationLte
+ * @param data.state
+ * @param data.pool
+ * @param data.queue
+ * @param data.executor
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const prefetchUseTaskInstanceServiceGetMappedTaskInstances = (
+ queryClient: QueryClient,
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }: {
+ dagId: string;
+ dagRunId: string;
+ durationGte?: number;
+ durationLte?: number;
+ endDateGte?: string;
+ endDateLte?: string;
+ executor?: string[];
+ limit?: number;
+ logicalDateGte?: string;
+ logicalDateLte?: string;
+ offset?: number;
+ orderBy?: string;
+ pool?: string[];
+ queue?: string[];
+ startDateGte?: string;
+ startDateLte?: string;
+ state?: string[];
+ taskId: string;
+ updatedAtGte?: string;
+ updatedAtLte?: string;
+ },
+) =>
+ queryClient.prefetchQuery({
+ queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstancesKeyFn({
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }),
+ queryFn: () =>
+ TaskInstanceService.getMappedTaskInstances({
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }),
+ });
/**
* Get Mapped Task Instance
* Get task instance.
diff --git a/airflow/ui/openapi-gen/queries/queries.ts
b/airflow/ui/openapi-gen/queries/queries.ts
index 1ce766d3af..05e9e972fc 100644
--- a/airflow/ui/openapi-gen/queries/queries.ts
+++ b/airflow/ui/openapi-gen/queries/queries.ts
@@ -966,6 +966,135 @@ export const useTaskInstanceServiceGetTaskInstance = <
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as
TData,
...options,
});
+/**
+ * Get Mapped Task Instances
+ * Get list of mapped task instances.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.logicalDateGte
+ * @param data.logicalDateLte
+ * @param data.startDateGte
+ * @param data.startDateLte
+ * @param data.endDateGte
+ * @param data.endDateLte
+ * @param data.updatedAtGte
+ * @param data.updatedAtLte
+ * @param data.durationGte
+ * @param data.durationLte
+ * @param data.state
+ * @param data.pool
+ * @param data.queue
+ * @param data.executor
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServiceGetMappedTaskInstances = <
+ TData = Common.TaskInstanceServiceGetMappedTaskInstancesDefaultResponse,
+ TError = unknown,
+ TQueryKey extends Array<unknown> = unknown[],
+>(
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }: {
+ dagId: string;
+ dagRunId: string;
+ durationGte?: number;
+ durationLte?: number;
+ endDateGte?: string;
+ endDateLte?: string;
+ executor?: string[];
+ limit?: number;
+ logicalDateGte?: string;
+ logicalDateLte?: string;
+ offset?: number;
+ orderBy?: string;
+ pool?: string[];
+ queue?: string[];
+ startDateGte?: string;
+ startDateLte?: string;
+ state?: string[];
+ taskId: string;
+ updatedAtGte?: string;
+ updatedAtLte?: string;
+ },
+ queryKey?: TQueryKey,
+ options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
+) =>
+ useQuery<TData, TError>({
+ queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstancesKeyFn(
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ },
+ queryKey,
+ ),
+ queryFn: () =>
+ TaskInstanceService.getMappedTaskInstances({
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }) as TData,
+ ...options,
+ });
/**
* Get Mapped Task Instance
* Get task instance.
diff --git a/airflow/ui/openapi-gen/queries/suspense.ts
b/airflow/ui/openapi-gen/queries/suspense.ts
index eed1a0afe8..a3c722e812 100644
--- a/airflow/ui/openapi-gen/queries/suspense.ts
+++ b/airflow/ui/openapi-gen/queries/suspense.ts
@@ -951,6 +951,135 @@ export const
useTaskInstanceServiceGetTaskInstanceSuspense = <
TaskInstanceService.getTaskInstance({ dagId, dagRunId, taskId }) as
TData,
...options,
});
+/**
+ * Get Mapped Task Instances
+ * Get list of mapped task instances.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.logicalDateGte
+ * @param data.logicalDateLte
+ * @param data.startDateGte
+ * @param data.startDateLte
+ * @param data.endDateGte
+ * @param data.endDateLte
+ * @param data.updatedAtGte
+ * @param data.updatedAtLte
+ * @param data.durationGte
+ * @param data.durationLte
+ * @param data.state
+ * @param data.pool
+ * @param data.queue
+ * @param data.executor
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+export const useTaskInstanceServiceGetMappedTaskInstancesSuspense = <
+ TData = Common.TaskInstanceServiceGetMappedTaskInstancesDefaultResponse,
+ TError = unknown,
+ TQueryKey extends Array<unknown> = unknown[],
+>(
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }: {
+ dagId: string;
+ dagRunId: string;
+ durationGte?: number;
+ durationLte?: number;
+ endDateGte?: string;
+ endDateLte?: string;
+ executor?: string[];
+ limit?: number;
+ logicalDateGte?: string;
+ logicalDateLte?: string;
+ offset?: number;
+ orderBy?: string;
+ pool?: string[];
+ queue?: string[];
+ startDateGte?: string;
+ startDateLte?: string;
+ state?: string[];
+ taskId: string;
+ updatedAtGte?: string;
+ updatedAtLte?: string;
+ },
+ queryKey?: TQueryKey,
+ options?: Omit<UseQueryOptions<TData, TError>, "queryKey" | "queryFn">,
+) =>
+ useSuspenseQuery<TData, TError>({
+ queryKey: Common.UseTaskInstanceServiceGetMappedTaskInstancesKeyFn(
+ {
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ },
+ queryKey,
+ ),
+ queryFn: () =>
+ TaskInstanceService.getMappedTaskInstances({
+ dagId,
+ dagRunId,
+ durationGte,
+ durationLte,
+ endDateGte,
+ endDateLte,
+ executor,
+ limit,
+ logicalDateGte,
+ logicalDateLte,
+ offset,
+ orderBy,
+ pool,
+ queue,
+ startDateGte,
+ startDateLte,
+ state,
+ taskId,
+ updatedAtGte,
+ updatedAtLte,
+ }) as TData,
+ ...options,
+ });
/**
* Get Mapped Task Instance
* Get task instance.
diff --git a/airflow/ui/openapi-gen/requests/schemas.gen.ts
b/airflow/ui/openapi-gen/requests/schemas.gen.ts
index 5635188d4f..9acae85c69 100644
--- a/airflow/ui/openapi-gen/requests/schemas.gen.ts
+++ b/airflow/ui/openapi-gen/requests/schemas.gen.ts
@@ -2448,6 +2448,26 @@ export const $SchedulerInfoSchema = {
description: "Schema for Scheduler info.",
} as const;
+export const $TaskInstanceCollectionResponse = {
+ properties: {
+ task_instances: {
+ items: {
+ $ref: "#/components/schemas/TaskInstanceResponse",
+ },
+ type: "array",
+ title: "Task Instances",
+ },
+ total_entries: {
+ type: "integer",
+ title: "Total Entries",
+ },
+ },
+ type: "object",
+ required: ["task_instances", "total_entries"],
+ title: "TaskInstanceCollectionResponse",
+ description: "Task Instance Collection serializer for responses.",
+} as const;
+
export const $TaskInstanceResponse = {
properties: {
id: {
diff --git a/airflow/ui/openapi-gen/requests/services.gen.ts
b/airflow/ui/openapi-gen/requests/services.gen.ts
index 6450029f56..2f1bc01533 100644
--- a/airflow/ui/openapi-gen/requests/services.gen.ts
+++ b/airflow/ui/openapi-gen/requests/services.gen.ts
@@ -80,6 +80,8 @@ import type {
GetProvidersResponse,
GetTaskInstanceData,
GetTaskInstanceResponse,
+ GetMappedTaskInstancesData,
+ GetMappedTaskInstancesResponse,
GetMappedTaskInstanceData,
GetMappedTaskInstanceResponse,
DeleteVariableData,
@@ -1265,6 +1267,72 @@ export class TaskInstanceService {
});
}
+ /**
+ * Get Mapped Task Instances
+ * Get list of mapped task instances.
+ * @param data The data for the request.
+ * @param data.dagId
+ * @param data.dagRunId
+ * @param data.taskId
+ * @param data.logicalDateGte
+ * @param data.logicalDateLte
+ * @param data.startDateGte
+ * @param data.startDateLte
+ * @param data.endDateGte
+ * @param data.endDateLte
+ * @param data.updatedAtGte
+ * @param data.updatedAtLte
+ * @param data.durationGte
+ * @param data.durationLte
+ * @param data.state
+ * @param data.pool
+ * @param data.queue
+ * @param data.executor
+ * @param data.limit
+ * @param data.offset
+ * @param data.orderBy
+ * @returns TaskInstanceCollectionResponse Successful Response
+ * @throws ApiError
+ */
+ public static getMappedTaskInstances(
+ data: GetMappedTaskInstancesData,
+ ): CancelablePromise<GetMappedTaskInstancesResponse> {
+ return __request(OpenAPI, {
+ method: "GET",
+ url: "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped",
+ path: {
+ dag_id: data.dagId,
+ dag_run_id: data.dagRunId,
+ task_id: data.taskId,
+ },
+ query: {
+ logical_date_gte: data.logicalDateGte,
+ logical_date_lte: data.logicalDateLte,
+ start_date_gte: data.startDateGte,
+ start_date_lte: data.startDateLte,
+ end_date_gte: data.endDateGte,
+ end_date_lte: data.endDateLte,
+ updated_at_gte: data.updatedAtGte,
+ updated_at_lte: data.updatedAtLte,
+ duration_gte: data.durationGte,
+ duration_lte: data.durationLte,
+ state: data.state,
+ pool: data.pool,
+ queue: data.queue,
+ executor: data.executor,
+ limit: data.limit,
+ offset: data.offset,
+ order_by: data.orderBy,
+ },
+ errors: {
+ 401: "Unauthorized",
+ 403: "Forbidden",
+ 404: "Not Found",
+ 422: "Validation Error",
+ },
+ });
+ }
+
/**
* Get Mapped Task Instance
* Get task instance.
diff --git a/airflow/ui/openapi-gen/requests/types.gen.ts b/airflow/ui/openapi-gen/requests/types.gen.ts
index f5d47e0e08..419e60325f 100644
--- a/airflow/ui/openapi-gen/requests/types.gen.ts
+++ b/airflow/ui/openapi-gen/requests/types.gen.ts
@@ -592,6 +592,14 @@ export type SchedulerInfoSchema = {
latest_scheduler_heartbeat: string | null;
};
+/**
+ * Task Instance Collection serializer for responses.
+ */
+export type TaskInstanceCollectionResponse = {
+ task_instances: Array<TaskInstanceResponse>;
+ total_entries: number;
+};
+
/**
* TaskInstance serializer for responses.
*/
@@ -1032,6 +1040,31 @@ export type GetTaskInstanceData = {
export type GetTaskInstanceResponse = TaskInstanceResponse;
+export type GetMappedTaskInstancesData = {
+ dagId: string;
+ dagRunId: string;
+ durationGte?: number | null;
+ durationLte?: number | null;
+ endDateGte?: string | null;
+ endDateLte?: string | null;
+ executor?: Array<string>;
+ limit?: number;
+ logicalDateGte?: string | null;
+ logicalDateLte?: string | null;
+ offset?: number;
+ orderBy?: string;
+ pool?: Array<string>;
+ queue?: Array<string>;
+ startDateGte?: string | null;
+ startDateLte?: string | null;
+ state?: Array<string>;
+ taskId: string;
+ updatedAtGte?: string | null;
+ updatedAtLte?: string | null;
+};
+
+export type GetMappedTaskInstancesResponse = TaskInstanceCollectionResponse;
+
export type GetMappedTaskInstanceData = {
dagId: string;
dagRunId: string;
@@ -2070,6 +2103,33 @@ export type $OpenApiTs = {
};
};
};
+ "/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/listMapped": {
+ get: {
+ req: GetMappedTaskInstancesData;
+ res: {
+ /**
+ * Successful Response
+ */
+ 200: TaskInstanceCollectionResponse;
+ /**
+ * Unauthorized
+ */
+ 401: HTTPExceptionResponse;
+ /**
+ * Forbidden
+ */
+ 403: HTTPExceptionResponse;
+ /**
+ * Not Found
+ */
+ 404: HTTPExceptionResponse;
+ /**
+ * Validation Error
+ */
+ 422: HTTPValidationError;
+ };
+ };
+ };
"/public/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/{map_index}": {
get: {
req: GetMappedTaskInstanceData;
diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
index 0f27abd567..856dadccaf 100644
--- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py
+++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -18,7 +18,8 @@
from __future__ import annotations
import datetime as dt
-import urllib
+import itertools
+import os
from unittest import mock
import pendulum
@@ -27,25 +28,28 @@ import pytest
from airflow.jobs.job import Job
from airflow.jobs.triggerer_job_runner import TriggererJobRunner
from airflow.models import DagRun, TaskInstance
+from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
+from airflow.models.taskmap import TaskMap
from airflow.models.trigger import Trigger
from airflow.utils.platform import getuser
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
from airflow.utils.timezone import datetime
from airflow.utils.types import DagRunType
from tests_common.test_utils.db import clear_db_runs, clear_rendered_ti_fields
+from tests_common.test_utils.mock_operators import MockOperator
pytestmark = pytest.mark.db_test
-DEFAULT_DATETIME_1 = datetime(2020, 1, 1, tzinfo=dt.timezone.utc)
+DEFAULT = datetime(2020, 1, 1)
DEFAULT_DATETIME_STR_1 = "2020-01-01T00:00:00+00:00"
DEFAULT_DATETIME_STR_2 = "2020-01-02T00:00:00+00:00"
-QUOTED_DEFAULT_DATETIME_STR_1 = urllib.parse.quote(DEFAULT_DATETIME_STR_1)
-QUOTED_DEFAULT_DATETIME_STR_2 = urllib.parse.quote(DEFAULT_DATETIME_STR_2)
+DEFAULT_DATETIME_1 = dt.datetime.fromisoformat(DEFAULT_DATETIME_STR_1)
+DEFAULT_DATETIME_2 = dt.datetime.fromisoformat(DEFAULT_DATETIME_STR_2)
class TestTaskInstanceEndpoint:
@@ -57,7 +61,7 @@ class TestTaskInstanceEndpoint:
@pytest.fixture(autouse=True)
def setup_attrs(self, session) -> None:
- self.default_time = DEFAULT_DATETIME_1
+ self.default_time = DEFAULT
self.ti_init = {
"execution_date": self.default_time,
"state": State.RUNNING,
@@ -143,20 +147,6 @@ class TestTaskInstanceEndpoint:
session.commit()
return tis
- session.commit()
- if with_ti_history:
- for ti in tis:
- ti.try_number = 1
- session.merge(ti)
- session.commit()
- dag.clear()
- for ti in tis:
- ti.try_number = 2
- ti.queue = "default_queue"
- session.merge(ti)
- session.commit()
- return tis
-
class TestGetTaskInstance(TestTaskInstanceEndpoint):
def test_should_respond_200(self, test_client, session):
@@ -465,3 +455,315 @@ class TestGetMappedTaskInstance(TestTaskInstanceEndpoint):
assert response.json() == {
"detail": "The Mapped Task Instance with dag_id: `example_python_operator`, run_id: `TEST_DAG_RUN_ID`, task_id: `print_the_context`, and map_index: `10` was not found"
}
+
+
+class TestGetMappedTaskInstances:
+ @pytest.fixture(autouse=True)
+ def setup_attrs(self) -> None:
+ self.default_time = DEFAULT_DATETIME_1
+ self.ti_init = {
+ "execution_date": self.default_time,
+ "state": State.RUNNING,
+ }
+ self.ti_extras = {
+ "start_date": self.default_time + dt.timedelta(days=1),
+ "end_date": self.default_time + dt.timedelta(days=2),
+ "pid": 100,
+ "duration": 10000,
+ "pool": "default_pool",
+ "queue": "default_queue",
+ }
+ clear_db_runs()
+ clear_rendered_ti_fields()
+
+ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags=None):
+ for dag_id, dag in (dags or {}).items():
+ count = dag["success"] + dag["running"]
+ with dag_maker(session=session, dag_id=dag_id, start_date=DEFAULT_DATETIME_1):
+ task1 = BaseOperator(task_id="op1")
+ mapped = MockOperator.partial(task_id="task_2", executor="default").expand(arg2=task1.output)
+
+ dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=count,
+ keys=None,
+ )
+ )
+
+ if count:
+ # Remove the map_index=-1 TI when we're creating other TIs
+ session.query(TaskInstance).filter(
+ TaskInstance.dag_id == mapped.dag_id,
+ TaskInstance.task_id == mapped.task_id,
+ TaskInstance.run_id == dr.run_id,
+ ).delete()
+
+ for index, state in enumerate(
+ itertools.chain(
+ itertools.repeat(TaskInstanceState.SUCCESS, dag["success"]),
+ itertools.repeat(TaskInstanceState.FAILED, dag["failed"]),
+ itertools.repeat(TaskInstanceState.RUNNING, dag["running"]),
+ )
+ ):
+ ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=state)
+ setattr(ti, "start_date", DEFAULT_DATETIME_1)
+ session.add(ti)
+
+ dagbag = DagBag(os.devnull, include_examples=False)
+ dagbag.dags = {dag_id: dag_maker.dag}
+ dagbag.sync_to_db()
+ session.flush()
+
+ mapped.expand_mapped_task(dr.run_id, session=session)
+
+ @pytest.fixture
+ def one_task_with_mapped_tis(self, dag_maker, session):
+ self.create_dag_runs_with_mapped_tasks(
+ dag_maker,
+ session,
+ dags={
+ "mapped_tis": {
+ "success": 3,
+ "failed": 0,
+ "running": 0,
+ },
+ },
+ )
+
+ @pytest.fixture
+ def one_task_with_single_mapped_ti(self, dag_maker, session):
+ self.create_dag_runs_with_mapped_tasks(
+ dag_maker,
+ session,
+ dags={
+ "mapped_tis": {
+ "success": 1,
+ "failed": 0,
+ "running": 0,
+ },
+ },
+ )
+
+ @pytest.fixture
+ def one_task_with_many_mapped_tis(self, dag_maker, session):
+ self.create_dag_runs_with_mapped_tasks(
+ dag_maker,
+ session,
+ dags={
+ "mapped_tis": {
+ "success": 5,
+ "failed": 20,
+ "running": 85,
+ },
+ },
+ )
+
+ @pytest.fixture
+ def one_task_with_zero_mapped_tis(self, dag_maker, session):
+ self.create_dag_runs_with_mapped_tasks(
+ dag_maker,
+ session,
+ dags={
+ "mapped_tis": {
+ "success": 0,
+ "failed": 0,
+ "running": 0,
+ },
+ },
+ )
+
+ def test_should_respond_404(self, test_client):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
+ assert response.status_code == 404
+ assert response.json() == {"detail": "DAG mapped_tis not found"}
+
+ def test_should_respond_200(self, one_task_with_many_mapped_tis, test_client):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
+
+ assert response.status_code == 200
+ assert response.json()["total_entries"] == 110
+ assert len(response.json()["task_instances"]) == 100
+
+ def test_offset_limit(self, test_client, one_task_with_many_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"offset": 4, "limit": 10},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 10
+ assert list(range(4, 14)) == [ti["map_index"] for ti in body["task_instances"]]
+
+ def test_order(self, test_client, one_task_with_many_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 100
+ assert list(range(100)) == [ti["map_index"] for ti in body["task_instances"]]
+
+ def test_mapped_task_instances_reverse_order(self, test_client, one_task_with_many_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"order_by": "-map_index"},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 100
+ assert list(range(109, 9, -1)) == [ti["map_index"] for ti in body["task_instances"]]
+
+ def test_state_order(self, test_client, one_task_with_many_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"order_by": "-state"},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 100
+ assert list(range(5)[::-1]) + list(range(25, 110)[::-1]) + list(range(15, 25)[::-1]) == [
+ ti["map_index"] for ti in body["task_instances"]
+ ]
+ # State ascending
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"order_by": "state", "limit": 108},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 108
+ assert list(range(5, 25)) + list(range(25, 110)) + list(range(3)) == [
+ ti["map_index"] for ti in body["task_instances"]
+ ]
+
+ def test_rendered_map_index_order(self, test_client, session, one_task_with_many_mapped_tis):
+ ti = (
+ session.query(TaskInstance)
+ .where(TaskInstance.task_id == "task_2", TaskInstance.map_index == 0)
+ .first()
+ )
+
+ ti.rendered_map_index = "a"
+
+ session.commit()
+
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"order_by": "-rendered_map_index"},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 100
+ assert [0] + list(range(11, 110)[::-1]) == [ti["map_index"] for ti in body["task_instances"]]
+ # rendered_map_index ascending
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"order_by": "rendered_map_index", "limit": 108},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 110
+ assert len(body["task_instances"]) == 108
+ assert [0] + list(range(1, 108)) == [ti["map_index"] for ti in body["task_instances"]]
+
+ def test_with_date(self, test_client, one_task_with_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"start_date_gte": DEFAULT_DATETIME_1},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 3
+ assert len(body["task_instances"]) == 3
+
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"start_date_gte": DEFAULT_DATETIME_2},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 0
+ assert body["task_instances"] == []
+
+ def test_with_logical_date(self, test_client, one_task_with_mapped_tis):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"logical_date_gte": DEFAULT_DATETIME_1},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 3
+ assert len(body["task_instances"]) == 3
+
+ response = test_client.get(
+
"/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params={"logical_date_gte": DEFAULT_DATETIME_2},
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 0
+ assert body["task_instances"] == []
+
+ @pytest.mark.parametrize(
+ "query_params, expected_total_entries, expected_task_instance_count",
+ [
+ ({"state": "success"}, 3, 3),
+ ({"state": "running"}, 0, 0),
+ ({"pool": "default_pool"}, 3, 3),
+ ({"pool": "test_pool"}, 0, 0),
+ ({"queue": "default"}, 3, 3),
+ ({"queue": "test_queue"}, 0, 0),
+ ({"executor": "default"}, 3, 3),
+ ({"executor": "no_exec"}, 0, 0),
+ ],
+ )
+ def test_mapped_task_instances_filters(
+ self,
+ test_client,
+ one_task_with_mapped_tis,
+ query_params,
+ expected_total_entries,
+ expected_task_instance_count,
+ ):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ params=query_params,
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == expected_total_entries
+ assert len(body["task_instances"]) == expected_task_instance_count
+
+ def test_with_zero_mapped(self, test_client, one_task_with_zero_mapped_tis, session):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped",
+ )
+ assert response.status_code == 200
+ body = response.json()
+ assert body["total_entries"] == 0
+ assert body["task_instances"] == []
+
+ def test_should_raise_404_not_found_for_nonexistent_task(
+ self, one_task_with_zero_mapped_tis, test_client
+ ):
+ response = test_client.get(
+ "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/nonexistent_task/listMapped",
+ )
+ assert response.status_code == 404
+ assert response.json()["detail"] == "Task id nonexistent_task not found"