This is an automated email from the ASF dual-hosted git repository.
pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b3d73af67c fix(rest-api): Add order_by query param to TI listing APIs
(#41283) (#41307)
b3d73af67c is described below
commit b3d73af67ca9d8d1b52f28fba6b90d985b03b07d
Author: Omkar P <[email protected]>
AuthorDate: Wed Aug 21 19:37:33 2024 +0530
fix(rest-api): Add order_by query param to TI listing APIs (#41283) (#41307)
* fix(rest-api): Add order_by query param to TI listing APIs (#41283)
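This adds database-level sorting via an order_by parameter (a query
parameter for the GET endpoint, a request-body field for the batch
endpoint) to the following task instance (TI) listing APIs: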
1. List task instances - /api/v1/dags/~/dagRuns/~/taskInstances
2. List task instances (batch) - /api/v1/dags/~/dagRuns/~/taskInstances/list
When order_by is not specified, both of the above APIs sort by map_index
(ascending), the same default used by the List mapped task instances API.
Please note that this does NOT change the default sorting of the List
mapped task instances API.
This also adds corresponding unit tests.
* Raise ValueError for unsupported order_by values
* Update docs for order_by
* Add TaskInstanceOrderBy for TI APIs
* Fix comment indentation
* Minor refinements
---------
Co-authored-by: Tzu-ping Chung <[email protected]>
---
.../endpoints/task_instance_endpoint.py | 80 ++++++++++-----
airflow/api_connexion/openapi/v1.yaml | 28 ++++-
.../api_connexion/schemas/task_instance_schema.py | 1 +
airflow/www/static/js/types/api-generated.ts | 41 +++++++-
.../endpoints/test_task_instance_endpoint.py | 113 +++++++++++++++++++++
5 files changed, 233 insertions(+), 30 deletions(-)
diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index a9213d2c24..23e5f5a6d1 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -62,6 +62,7 @@ from airflow.www.extensions.init_auth_manager import
get_auth_manager
if TYPE_CHECKING:
from sqlalchemy.orm import Session
from sqlalchemy.sql import ClauseElement, Select
+ from sqlalchemy.sql.expression import ColumnOperators
from airflow.api_connexion.types import APIResponse
from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
@@ -245,28 +246,11 @@ def get_mapped_task_instances(
.options(joinedload(TI.rendered_task_instance_fields))
)
- if order_by is None:
- entry_query = entry_query.order_by(TI.map_index.asc())
- elif order_by == "state":
- entry_query = entry_query.order_by(TI.state.asc(), TI.map_index.asc())
- elif order_by == "-state":
- entry_query = entry_query.order_by(TI.state.desc(), TI.map_index.asc())
- elif order_by == "duration":
- entry_query = entry_query.order_by(TI.duration.asc(),
TI.map_index.asc())
- elif order_by == "-duration":
- entry_query = entry_query.order_by(TI.duration.desc(),
TI.map_index.asc())
- elif order_by == "start_date":
- entry_query = entry_query.order_by(TI.start_date.asc(),
TI.map_index.asc())
- elif order_by == "-start_date":
- entry_query = entry_query.order_by(TI.start_date.desc(),
TI.map_index.asc())
- elif order_by == "end_date":
- entry_query = entry_query.order_by(TI.end_date.asc(),
TI.map_index.asc())
- elif order_by == "-end_date":
- entry_query = entry_query.order_by(TI.end_date.desc(),
TI.map_index.asc())
- elif order_by == "-map_index":
- entry_query = entry_query.order_by(TI.map_index.desc())
- else:
- raise BadRequest(detail=f"Ordering with '{order_by}' is not supported")
+ try:
+ order_by_params = _get_order_by_params(order_by)
+ entry_query = entry_query.order_by(*order_by_params)
+ except _UnsupportedOrderBy as e:
+ raise BadRequest(detail=f"Ordering with {e.order_by!r} is not
supported")
# using execute because we want the SlaMiss entity. Scalars don't return
None for missing entities
task_instances =
session.execute(entry_query.offset(offset).limit(limit)).all()
@@ -297,6 +281,39 @@ def _apply_range_filter(query: Select, key: ClauseElement,
value_range: tuple[T,
return query
+class _UnsupportedOrderBy(ValueError):
+ def __init__(self, order_by: str) -> None:
+ super().__init__(order_by)
+ self.order_by = order_by
+
+
+def _get_order_by_params(order_by: str | None = None) ->
tuple[ColumnOperators, ...]:
+ """Return a tuple with the order by params to be used in the query."""
+ if order_by is None:
+ return (TI.map_index.asc(),)
+ if order_by == "state":
+ return (TI.state.asc(), TI.map_index.asc())
+ if order_by == "-state":
+ return (TI.state.desc(), TI.map_index.asc())
+ if order_by == "duration":
+ return (TI.duration.asc(), TI.map_index.asc())
+ if order_by == "-duration":
+ return (TI.duration.desc(), TI.map_index.asc())
+ if order_by == "start_date":
+ return (TI.start_date.asc(), TI.map_index.asc())
+ if order_by == "-start_date":
+ return (TI.start_date.desc(), TI.map_index.asc())
+ if order_by == "end_date":
+ return (TI.end_date.asc(), TI.map_index.asc())
+ if order_by == "-end_date":
+ return (TI.end_date.desc(), TI.map_index.asc())
+ if order_by == "map_index":
+ return (TI.map_index.asc(),)
+ if order_by == "-map_index":
+ return (TI.map_index.desc(),)
+ raise _UnsupportedOrderBy(order_by)
+
+
@format_parameters(
{
"execution_date_gte": format_datetime,
@@ -331,6 +348,7 @@ def get_task_instances(
queue: list[str] | None = None,
executor: list[str] | None = None,
offset: int | None = None,
+ order_by: str | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get list of task instances."""
@@ -378,11 +396,16 @@ def get_task_instances(
)
.add_columns(SlaMiss)
.options(joinedload(TI.rendered_task_instance_fields))
- .offset(offset)
- .limit(limit)
)
+
+ try:
+ order_by_params = _get_order_by_params(order_by)
+ entry_query = entry_query.order_by(*order_by_params)
+ except _UnsupportedOrderBy as e:
+ raise BadRequest(detail=f"Ordering with {e.order_by!r} is not
supported")
+
# using execute because we want the SlaMiss entity. Scalars don't return
None for missing entities
- task_instances = session.execute(entry_query).all()
+ task_instances =
session.execute(entry_query.offset(offset).limit(limit)).all()
return task_instance_collection_schema.dump(
TaskInstanceCollection(task_instances=task_instances,
total_entries=total_entries)
)
@@ -453,6 +476,13 @@ def get_task_instances_batch(session: Session =
NEW_SESSION) -> APIResponse:
ti_query = base_query.options(
joinedload(TI.rendered_task_instance_fields),
joinedload(TI.task_instance_note)
)
+
+ try:
+ order_by_params = _get_order_by_params(data["order_by"])
+ ti_query = ti_query.order_by(*order_by_params)
+ except _UnsupportedOrderBy as e:
+ raise BadRequest(detail=f"Ordering with {e.order_by!r} is not
supported")
+
# using execute because we want the SlaMiss entity. Scalars don't return
None for missing entities
task_instances = session.execute(ti_query).all()
diff --git a/airflow/api_connexion/openapi/v1.yaml
b/airflow/api_connexion/openapi/v1.yaml
index d0f5db43db..a0809f045d 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1517,6 +1517,7 @@ paths:
parameters:
- $ref: "#/components/parameters/PageLimit"
- $ref: "#/components/parameters/PageOffset"
+ - $ref: "#/components/parameters/TaskInstanceOrderBy"
responses:
"200":
description: Success.
@@ -1675,7 +1676,7 @@ paths:
- $ref: "#/components/parameters/FilterPool"
- $ref: "#/components/parameters/FilterQueue"
- $ref: "#/components/parameters/FilterExecutor"
- - $ref: "#/components/parameters/OrderBy"
+ - $ref: "#/components/parameters/TaskInstanceOrderBy"
responses:
"200":
description: Success.
@@ -5177,6 +5178,16 @@ components:
items:
type: string
description: The value can be repeated to retrieve multiple matching
values (OR condition).
+ order_by:
+ type: string
+ description: |
+ The name of the field to order the results by. Prefix a field name
+ with `-` to reverse the sort order. `order_by` defaults to
+ `map_index` when unspecified.
+ Supported field names: `state`, `duration`, `start_date`,
`end_date`
+ and `map_index`.
+
+ *New in version 3.0.0*
# Common data type
ScheduleInterval:
@@ -5802,6 +5813,21 @@ components:
type: integer
description: Filter on try_number for task instance.
+ TaskInstanceOrderBy:
+ in: query
+ name: order_by
+ schema:
+ type: string
+ required: false
+ description: |
+ The name of the field to order the results by. Prefix a field name
+ with `-` to reverse the sort order. `order_by` defaults to
+ `map_index` when unspecified.
+ Supported field names: `state`, `duration`, `start_date`, `end_date`
+ and `map_index`.
+
+ *New in version 3.0.0*
+
OrderBy:
in: query
name: order_by
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py
b/airflow/api_connexion/schemas/task_instance_schema.py
index 74cd0585dc..0c1daf6ce2 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -167,6 +167,7 @@ class TaskInstanceBatchFormSchema(Schema):
pool = fields.List(fields.Str(), load_default=None)
queue = fields.List(fields.Str(), load_default=None)
executor = fields.List(fields.Str(), load_default=None)
+ order_by = fields.Str(load_default=None)
class ClearTaskInstanceFormSchema(Schema):
diff --git a/airflow/www/static/js/types/api-generated.ts
b/airflow/www/static/js/types/api-generated.ts
index d5c1e06b6e..1897fd90ce 100644
--- a/airflow/www/static/js/types/api-generated.ts
+++ b/airflow/www/static/js/types/api-generated.ts
@@ -2304,6 +2304,16 @@ export interface components {
queue?: string[];
/** @description The value can be repeated to retrieve multiple matching
values (OR condition). */
executor?: string[];
+ /**
+ * @description The name of the field to order the results by. Prefix a
field name
+ * with `-` to reverse the sort order. `order_by` defaults to
+ * `map_index` when unspecified.
+ * Supported field names: `state`, `duration`, `start_date`, `end_date`
+ * and `map_index`.
+ *
+ * *New in version 3.0.0*
+ */
+ order_by?: string;
};
/**
* @description Schedule interval. Defines how often DAG runs, this object
gets added to your latest task instance's
@@ -2654,6 +2664,16 @@ export interface components {
FilterMapIndex: number;
/** @description Filter on try_number for task instance. */
FilterTryNumber: number;
+ /**
+ * @description The name of the field to order the results by. Prefix a
field name
+ * with `-` to reverse the sort order. `order_by` defaults to
+ * `map_index` when unspecified.
+ * Supported field names: `state`, `duration`, `start_date`, `end_date`
+ * and `map_index`.
+ *
+ * *New in version 3.0.0*
+ */
+ TaskInstanceOrderBy: string;
/**
* @description The name of the field to order the results by.
* Prefix a field name with `-` to reverse the sort order.
@@ -4049,6 +4069,16 @@ export interface operations {
limit?: components["parameters"]["PageLimit"];
/** The number of items to skip before starting to collect the result
set. */
offset?: components["parameters"]["PageOffset"];
+ /**
+ * The name of the field to order the results by. Prefix a field name
+ * with `-` to reverse the sort order. `order_by` defaults to
+ * `map_index` when unspecified.
+ * Supported field names: `state`, `duration`, `start_date`, `end_date`
+ * and `map_index`.
+ *
+ * *New in version 3.0.0*
+ */
+ order_by?: components["parameters"]["TaskInstanceOrderBy"];
};
};
responses: {
@@ -4276,12 +4306,15 @@ export interface operations {
/** The value can be repeated to retrieve multiple matching values (OR
condition). */
executor?: components["parameters"]["FilterExecutor"];
/**
- * The name of the field to order the results by.
- * Prefix a field name with `-` to reverse the sort order.
+ * The name of the field to order the results by. Prefix a field name
+ * with `-` to reverse the sort order. `order_by` defaults to
+ * `map_index` when unspecified.
+ * Supported field names: `state`, `duration`, `start_date`, `end_date`
+ * and `map_index`.
*
- * *New in version 2.1.0*
+ * *New in version 3.0.0*
*/
- order_by?: components["parameters"]["OrderBy"];
+ order_by?: components["parameters"]["TaskInstanceOrderBy"];
};
};
responses: {
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 4fcd66affe..25ded6c814 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -782,6 +782,80 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
assert count == response.json["total_entries"]
assert count == len(response.json["task_instances"])
+ def test_should_respond_200_for_order_by(self, session):
+ dag_id = "example_python_operator"
+ self.create_task_instances(
+ session,
+ task_instances=[
+ {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i +
1))} for i in range(10)
+ ],
+ dag_id=dag_id,
+ )
+
+ ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id ==
dag_id).count()
+
+ # Ascending order
+ response_asc = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances?order_by=start_date",
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response_asc.status_code == 200
+ assert response_asc.json["total_entries"] == ti_count
+ assert len(response_asc.json["task_instances"]) == ti_count
+
+ # Descending order
+ response_desc = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances?order_by=-start_date",
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response_desc.status_code == 200
+ assert response_desc.json["total_entries"] == ti_count
+ assert len(response_desc.json["task_instances"]) == ti_count
+
+ # Compare
+ start_dates_asc = [ti["start_date"] for ti in
response_asc.json["task_instances"]]
+ assert len(start_dates_asc) == ti_count
+ start_dates_desc = [ti["start_date"] for ti in
response_desc.json["task_instances"]]
+ assert len(start_dates_desc) == ti_count
+ assert start_dates_asc == list(reversed(start_dates_desc))
+
+ def test_should_respond_200_for_pagination(self, session):
+ dag_id = "example_python_operator"
+ self.create_task_instances(
+ session,
+ task_instances=[
+ {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i +
1))} for i in range(10)
+ ],
+ dag_id=dag_id,
+ )
+
+ # First 5 items
+ response_batch1 = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances?limit=5&offset=0",
+ environ_overrides={"REMOTE_USER": "test"},
+ )
+ assert response_batch1.status_code == 200, response_batch1.json
+ num_entries_batch1 = len(response_batch1.json["task_instances"])
+ assert num_entries_batch1 == 5
+ assert len(response_batch1.json["task_instances"]) == 5
+
+ # 5 items after that
+ response_batch2 = self.client.get(
+ "/api/v1/dags/~/dagRuns/~/taskInstances?limit=5&offset=5",
+ environ_overrides={"REMOTE_USER": "test"},
+ json={"limit": 5, "offset": 0, "dag_ids": [dag_id]},
+ )
+ assert response_batch2.status_code == 200, response_batch2.json
+ num_entries_batch2 = len(response_batch2.json["task_instances"])
+ assert num_entries_batch2 > 0
+ assert len(response_batch2.json["task_instances"]) > 0
+
+ # Match
+ ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id ==
dag_id).count()
+ assert response_batch1.json["total_entries"] ==
response_batch2.json["total_entries"] == ti_count
+ assert (num_entries_batch1 + num_entries_batch2) == ti_count
+ assert response_batch1 != response_batch2
+
def test_should_raises_401_unauthenticated(self):
response = self.client.get(
"/api/v1/dags/example_python_operator/dagRuns/~/taskInstances",
@@ -971,6 +1045,45 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
assert expected_ti_count == response.json["total_entries"]
assert expected_ti_count == len(response.json["task_instances"])
+ def test_should_respond_200_for_order_by(self, session):
+ dag_id = "example_python_operator"
+ self.create_task_instances(
+ session,
+ task_instances=[
+ {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i +
1))} for i in range(10)
+ ],
+ dag_id=dag_id,
+ )
+
+ ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id ==
dag_id).count()
+
+ # Ascending order
+ response_asc = self.client.post(
+ "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+ environ_overrides={"REMOTE_USER": "test"},
+ json={"order_by": "start_date", "dag_ids": [dag_id]},
+ )
+ assert response_asc.status_code == 200, response_asc.json
+ assert response_asc.json["total_entries"] == ti_count
+ assert len(response_asc.json["task_instances"]) == ti_count
+
+ # Descending order
+ response_desc = self.client.post(
+ "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+ environ_overrides={"REMOTE_USER": "test"},
+ json={"order_by": "-start_date", "dag_ids": [dag_id]},
+ )
+ assert response_desc.status_code == 200, response_desc.json
+ assert response_desc.json["total_entries"] == ti_count
+ assert len(response_desc.json["task_instances"]) == ti_count
+
+ # Compare
+ start_dates_asc = [ti["start_date"] for ti in
response_asc.json["task_instances"]]
+ assert len(start_dates_asc) == ti_count
+ start_dates_desc = [ti["start_date"] for ti in
response_desc.json["task_instances"]]
+ assert len(start_dates_desc) == ti_count
+ assert start_dates_asc == list(reversed(start_dates_desc))
+
@pytest.mark.parametrize(
"task_instances, payload, expected_ti_count",
[