This is an automated email from the ASF dual-hosted git repository.

pierrejeambrun pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b3d73af67c fix(rest-api): Add order_by query param to TI listing APIs (#41283) (#41307)
b3d73af67c is described below

commit b3d73af67ca9d8d1b52f28fba6b90d985b03b07d
Author: Omkar P <[email protected]>
AuthorDate: Wed Aug 21 19:37:33 2024 +0530

    fix(rest-api): Add order_by query param to TI listing APIs (#41283) (#41307)
    
    * fix(rest-api): Add order_by query param to TI listing APIs (#41283)
    
    This adds db-level sorting with order_by query param to the
    following TI listing APIs:
    1. List task instances - /api/v1/dags/~/dagRuns/~/taskInstances
    2. List task instances (batch) - /api/v1/dags/~/dagRuns/~/taskInstances/list
    
    order_by defaults to sorting by map_index (ascending) for the two
    APIs mentioned above. Please note that this does NOT change the
    default sorting param for the List mapped task instances API.
    
    This also adds corresponding unit tests.
    
    * Raise ValueError for unsupported order_by values
    
    * Update docs for order_by
    
    * Add TaskInstanceOrderBy for TI APIs
    
    * Fix comment indentation
    
    * Minor refinements
    
    ---------
    
    Co-authored-by: Tzu-ping Chung <[email protected]>
---
 .../endpoints/task_instance_endpoint.py            |  80 ++++++++++-----
 airflow/api_connexion/openapi/v1.yaml              |  28 ++++-
 .../api_connexion/schemas/task_instance_schema.py  |   1 +
 airflow/www/static/js/types/api-generated.ts       |  41 +++++++-
 .../endpoints/test_task_instance_endpoint.py       | 113 +++++++++++++++++++++
 5 files changed, 233 insertions(+), 30 deletions(-)

diff --git a/airflow/api_connexion/endpoints/task_instance_endpoint.py 
b/airflow/api_connexion/endpoints/task_instance_endpoint.py
index a9213d2c24..23e5f5a6d1 100644
--- a/airflow/api_connexion/endpoints/task_instance_endpoint.py
+++ b/airflow/api_connexion/endpoints/task_instance_endpoint.py
@@ -62,6 +62,7 @@ from airflow.www.extensions.init_auth_manager import 
get_auth_manager
 if TYPE_CHECKING:
     from sqlalchemy.orm import Session
     from sqlalchemy.sql import ClauseElement, Select
+    from sqlalchemy.sql.expression import ColumnOperators
 
     from airflow.api_connexion.types import APIResponse
     from airflow.auth.managers.models.batch_apis import IsAuthorizedDagRequest
@@ -245,28 +246,11 @@ def get_mapped_task_instances(
         .options(joinedload(TI.rendered_task_instance_fields))
     )
 
-    if order_by is None:
-        entry_query = entry_query.order_by(TI.map_index.asc())
-    elif order_by == "state":
-        entry_query = entry_query.order_by(TI.state.asc(), TI.map_index.asc())
-    elif order_by == "-state":
-        entry_query = entry_query.order_by(TI.state.desc(), TI.map_index.asc())
-    elif order_by == "duration":
-        entry_query = entry_query.order_by(TI.duration.asc(), 
TI.map_index.asc())
-    elif order_by == "-duration":
-        entry_query = entry_query.order_by(TI.duration.desc(), 
TI.map_index.asc())
-    elif order_by == "start_date":
-        entry_query = entry_query.order_by(TI.start_date.asc(), 
TI.map_index.asc())
-    elif order_by == "-start_date":
-        entry_query = entry_query.order_by(TI.start_date.desc(), 
TI.map_index.asc())
-    elif order_by == "end_date":
-        entry_query = entry_query.order_by(TI.end_date.asc(), 
TI.map_index.asc())
-    elif order_by == "-end_date":
-        entry_query = entry_query.order_by(TI.end_date.desc(), 
TI.map_index.asc())
-    elif order_by == "-map_index":
-        entry_query = entry_query.order_by(TI.map_index.desc())
-    else:
-        raise BadRequest(detail=f"Ordering with '{order_by}' is not supported")
+    try:
+        order_by_params = _get_order_by_params(order_by)
+        entry_query = entry_query.order_by(*order_by_params)
+    except _UnsupportedOrderBy as e:
+        raise BadRequest(detail=f"Ordering with {e.order_by!r} is not 
supported")
 
     # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
     task_instances = 
session.execute(entry_query.offset(offset).limit(limit)).all()
@@ -297,6 +281,39 @@ def _apply_range_filter(query: Select, key: ClauseElement, 
value_range: tuple[T,
     return query
 
 
+class _UnsupportedOrderBy(ValueError):
+    def __init__(self, order_by: str) -> None:
+        super().__init__(order_by)
+        self.order_by = order_by
+
+
+def _get_order_by_params(order_by: str | None = None) -> 
tuple[ColumnOperators, ...]:
+    """Return a tuple with the order by params to be used in the query."""
+    if order_by is None:
+        return (TI.map_index.asc(),)
+    if order_by == "state":
+        return (TI.state.asc(), TI.map_index.asc())
+    if order_by == "-state":
+        return (TI.state.desc(), TI.map_index.asc())
+    if order_by == "duration":
+        return (TI.duration.asc(), TI.map_index.asc())
+    if order_by == "-duration":
+        return (TI.duration.desc(), TI.map_index.asc())
+    if order_by == "start_date":
+        return (TI.start_date.asc(), TI.map_index.asc())
+    if order_by == "-start_date":
+        return (TI.start_date.desc(), TI.map_index.asc())
+    if order_by == "end_date":
+        return (TI.end_date.asc(), TI.map_index.asc())
+    if order_by == "-end_date":
+        return (TI.end_date.desc(), TI.map_index.asc())
+    if order_by == "map_index":
+        return (TI.map_index.asc(),)
+    if order_by == "-map_index":
+        return (TI.map_index.desc(),)
+    raise _UnsupportedOrderBy(order_by)
+
+
 @format_parameters(
     {
         "execution_date_gte": format_datetime,
@@ -331,6 +348,7 @@ def get_task_instances(
     queue: list[str] | None = None,
     executor: list[str] | None = None,
     offset: int | None = None,
+    order_by: str | None = None,
     session: Session = NEW_SESSION,
 ) -> APIResponse:
     """Get list of task instances."""
@@ -378,11 +396,16 @@ def get_task_instances(
         )
         .add_columns(SlaMiss)
         .options(joinedload(TI.rendered_task_instance_fields))
-        .offset(offset)
-        .limit(limit)
     )
+
+    try:
+        order_by_params = _get_order_by_params(order_by)
+        entry_query = entry_query.order_by(*order_by_params)
+    except _UnsupportedOrderBy as e:
+        raise BadRequest(detail=f"Ordering with {e.order_by!r} is not 
supported")
+
     # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
-    task_instances = session.execute(entry_query).all()
+    task_instances = 
session.execute(entry_query.offset(offset).limit(limit)).all()
     return task_instance_collection_schema.dump(
         TaskInstanceCollection(task_instances=task_instances, 
total_entries=total_entries)
     )
@@ -453,6 +476,13 @@ def get_task_instances_batch(session: Session = 
NEW_SESSION) -> APIResponse:
     ti_query = base_query.options(
         joinedload(TI.rendered_task_instance_fields), 
joinedload(TI.task_instance_note)
     )
+
+    try:
+        order_by_params = _get_order_by_params(data["order_by"])
+        ti_query = ti_query.order_by(*order_by_params)
+    except _UnsupportedOrderBy as e:
+        raise BadRequest(detail=f"Ordering with {e.order_by!r} is not 
supported")
+
     # using execute because we want the SlaMiss entity. Scalars don't return 
None for missing entities
     task_instances = session.execute(ti_query).all()
 
diff --git a/airflow/api_connexion/openapi/v1.yaml 
b/airflow/api_connexion/openapi/v1.yaml
index d0f5db43db..a0809f045d 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -1517,6 +1517,7 @@ paths:
       parameters:
         - $ref: "#/components/parameters/PageLimit"
         - $ref: "#/components/parameters/PageOffset"
+        - $ref: "#/components/parameters/TaskInstanceOrderBy"
       responses:
         "200":
           description: Success.
@@ -1675,7 +1676,7 @@ paths:
         - $ref: "#/components/parameters/FilterPool"
         - $ref: "#/components/parameters/FilterQueue"
         - $ref: "#/components/parameters/FilterExecutor"
-        - $ref: "#/components/parameters/OrderBy"
+        - $ref: "#/components/parameters/TaskInstanceOrderBy"
       responses:
         "200":
           description: Success.
@@ -5177,6 +5178,16 @@ components:
           items:
             type: string
           description: The value can be repeated to retrieve multiple matching 
values (OR condition).
+        order_by:
+          type: string
+          description: |
+            The name of the field to order the results by. Prefix a field name
+            with `-` to reverse the sort order. `order_by` defaults to
+            `map_index` when unspecified.
+            Supported field names: `state`, `duration`, `start_date`, 
`end_date`
+            and `map_index`.
+
+            *New in version 3.0.0*
 
     # Common data type
     ScheduleInterval:
@@ -5802,6 +5813,21 @@ components:
         type: integer
       description: Filter on try_number for task instance.
 
+    TaskInstanceOrderBy:
+      in: query
+      name: order_by
+      schema:
+        type: string
+      required: false
+      description: |
+        The name of the field to order the results by. Prefix a field name
+        with `-` to reverse the sort order. `order_by` defaults to
+        `map_index` when unspecified.
+        Supported field names: `state`, `duration`, `start_date`, `end_date`
+        and `map_index`.
+
+        *New in version 3.0.0*
+
     OrderBy:
       in: query
       name: order_by
diff --git a/airflow/api_connexion/schemas/task_instance_schema.py 
b/airflow/api_connexion/schemas/task_instance_schema.py
index 74cd0585dc..0c1daf6ce2 100644
--- a/airflow/api_connexion/schemas/task_instance_schema.py
+++ b/airflow/api_connexion/schemas/task_instance_schema.py
@@ -167,6 +167,7 @@ class TaskInstanceBatchFormSchema(Schema):
     pool = fields.List(fields.Str(), load_default=None)
     queue = fields.List(fields.Str(), load_default=None)
     executor = fields.List(fields.Str(), load_default=None)
+    order_by = fields.Str(load_default=None)
 
 
 class ClearTaskInstanceFormSchema(Schema):
diff --git a/airflow/www/static/js/types/api-generated.ts 
b/airflow/www/static/js/types/api-generated.ts
index d5c1e06b6e..1897fd90ce 100644
--- a/airflow/www/static/js/types/api-generated.ts
+++ b/airflow/www/static/js/types/api-generated.ts
@@ -2304,6 +2304,16 @@ export interface components {
       queue?: string[];
       /** @description The value can be repeated to retrieve multiple matching 
values (OR condition). */
       executor?: string[];
+      /**
+       * @description The name of the field to order the results by. Prefix a 
field name
+       * with `-` to reverse the sort order. `order_by` defaults to
+       * `map_index` when unspecified.
+       * Supported field names: `state`, `duration`, `start_date`, `end_date`
+       * and `map_index`.
+       *
+       * *New in version 3.0.0*
+       */
+      order_by?: string;
     };
     /**
      * @description Schedule interval. Defines how often DAG runs, this object 
gets added to your latest task instance's
@@ -2654,6 +2664,16 @@ export interface components {
     FilterMapIndex: number;
     /** @description Filter on try_number for task instance. */
     FilterTryNumber: number;
+    /**
+     * @description The name of the field to order the results by. Prefix a 
field name
+     * with `-` to reverse the sort order. `order_by` defaults to
+     * `map_index` when unspecified.
+     * Supported field names: `state`, `duration`, `start_date`, `end_date`
+     * and `map_index`.
+     *
+     * *New in version 3.0.0*
+     */
+    TaskInstanceOrderBy: string;
     /**
      * @description The name of the field to order the results by.
      * Prefix a field name with `-` to reverse the sort order.
@@ -4049,6 +4069,16 @@ export interface operations {
         limit?: components["parameters"]["PageLimit"];
         /** The number of items to skip before starting to collect the result 
set. */
         offset?: components["parameters"]["PageOffset"];
+        /**
+         * The name of the field to order the results by. Prefix a field name
+         * with `-` to reverse the sort order. `order_by` defaults to
+         * `map_index` when unspecified.
+         * Supported field names: `state`, `duration`, `start_date`, `end_date`
+         * and `map_index`.
+         *
+         * *New in version 3.0.0*
+         */
+        order_by?: components["parameters"]["TaskInstanceOrderBy"];
       };
     };
     responses: {
@@ -4276,12 +4306,15 @@ export interface operations {
         /** The value can be repeated to retrieve multiple matching values (OR 
condition). */
         executor?: components["parameters"]["FilterExecutor"];
         /**
-         * The name of the field to order the results by.
-         * Prefix a field name with `-` to reverse the sort order.
+         * The name of the field to order the results by. Prefix a field name
+         * with `-` to reverse the sort order. `order_by` defaults to
+         * `map_index` when unspecified.
+         * Supported field names: `state`, `duration`, `start_date`, `end_date`
+         * and `map_index`.
          *
-         * *New in version 2.1.0*
+         * *New in version 3.0.0*
          */
-        order_by?: components["parameters"]["OrderBy"];
+        order_by?: components["parameters"]["TaskInstanceOrderBy"];
       };
     };
     responses: {
diff --git a/tests/api_connexion/endpoints/test_task_instance_endpoint.py 
b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
index 4fcd66affe..25ded6c814 100644
--- a/tests/api_connexion/endpoints/test_task_instance_endpoint.py
+++ b/tests/api_connexion/endpoints/test_task_instance_endpoint.py
@@ -782,6 +782,80 @@ class TestGetTaskInstances(TestTaskInstanceEndpoint):
         assert count == response.json["total_entries"]
         assert count == len(response.json["task_instances"])
 
+    def test_should_respond_200_for_order_by(self, session):
+        dag_id = "example_python_operator"
+        self.create_task_instances(
+            session,
+            task_instances=[
+                {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 
1))} for i in range(10)
+            ],
+            dag_id=dag_id,
+        )
+
+        ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id == 
dag_id).count()
+
+        # Ascending order
+        response_asc = self.client.get(
+            "/api/v1/dags/~/dagRuns/~/taskInstances?order_by=start_date",
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response_asc.status_code == 200
+        assert response_asc.json["total_entries"] == ti_count
+        assert len(response_asc.json["task_instances"]) == ti_count
+
+        # Descending order
+        response_desc = self.client.get(
+            "/api/v1/dags/~/dagRuns/~/taskInstances?order_by=-start_date",
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response_desc.status_code == 200
+        assert response_desc.json["total_entries"] == ti_count
+        assert len(response_desc.json["task_instances"]) == ti_count
+
+        # Compare
+        start_dates_asc = [ti["start_date"] for ti in 
response_asc.json["task_instances"]]
+        assert len(start_dates_asc) == ti_count
+        start_dates_desc = [ti["start_date"] for ti in 
response_desc.json["task_instances"]]
+        assert len(start_dates_desc) == ti_count
+        assert start_dates_asc == list(reversed(start_dates_desc))
+
+    def test_should_respond_200_for_pagination(self, session):
+        dag_id = "example_python_operator"
+        self.create_task_instances(
+            session,
+            task_instances=[
+                {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 
1))} for i in range(10)
+            ],
+            dag_id=dag_id,
+        )
+
+        # First 5 items
+        response_batch1 = self.client.get(
+            "/api/v1/dags/~/dagRuns/~/taskInstances?limit=5&offset=0",
+            environ_overrides={"REMOTE_USER": "test"},
+        )
+        assert response_batch1.status_code == 200, response_batch1.json
+        num_entries_batch1 = len(response_batch1.json["task_instances"])
+        assert num_entries_batch1 == 5
+        assert len(response_batch1.json["task_instances"]) == 5
+
+        # 5 items after that
+        response_batch2 = self.client.get(
+            "/api/v1/dags/~/dagRuns/~/taskInstances?limit=5&offset=5",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={"limit": 5, "offset": 0, "dag_ids": [dag_id]},
+        )
+        assert response_batch2.status_code == 200, response_batch2.json
+        num_entries_batch2 = len(response_batch2.json["task_instances"])
+        assert num_entries_batch2 > 0
+        assert len(response_batch2.json["task_instances"]) > 0
+
+        # Match
+        ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id == 
dag_id).count()
+        assert response_batch1.json["total_entries"] == 
response_batch2.json["total_entries"] == ti_count
+        assert (num_entries_batch1 + num_entries_batch2) == ti_count
+        assert response_batch1 != response_batch2
+
     def test_should_raises_401_unauthenticated(self):
         response = self.client.get(
             "/api/v1/dags/example_python_operator/dagRuns/~/taskInstances",
@@ -971,6 +1045,45 @@ class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint):
         assert expected_ti_count == response.json["total_entries"]
         assert expected_ti_count == len(response.json["task_instances"])
 
+    def test_should_respond_200_for_order_by(self, session):
+        dag_id = "example_python_operator"
+        self.create_task_instances(
+            session,
+            task_instances=[
+                {"start_date": DEFAULT_DATETIME_1 + dt.timedelta(minutes=(i + 
1))} for i in range(10)
+            ],
+            dag_id=dag_id,
+        )
+
+        ti_count = session.query(TaskInstance).filter(TaskInstance.dag_id == 
dag_id).count()
+
+        # Ascending order
+        response_asc = self.client.post(
+            "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={"order_by": "start_date", "dag_ids": [dag_id]},
+        )
+        assert response_asc.status_code == 200, response_asc.json
+        assert response_asc.json["total_entries"] == ti_count
+        assert len(response_asc.json["task_instances"]) == ti_count
+
+        # Descending order
+        response_desc = self.client.post(
+            "/api/v1/dags/~/dagRuns/~/taskInstances/list",
+            environ_overrides={"REMOTE_USER": "test"},
+            json={"order_by": "-start_date", "dag_ids": [dag_id]},
+        )
+        assert response_desc.status_code == 200, response_desc.json
+        assert response_desc.json["total_entries"] == ti_count
+        assert len(response_desc.json["task_instances"]) == ti_count
+
+        # Compare
+        start_dates_asc = [ti["start_date"] for ti in 
response_asc.json["task_instances"]]
+        assert len(start_dates_asc) == ti_count
+        start_dates_desc = [ti["start_date"] for ti in 
response_desc.json["task_instances"]]
+        assert len(start_dates_desc) == ti_count
+        assert start_dates_asc == list(reversed(start_dates_desc))
+
     @pytest.mark.parametrize(
         "task_instances, payload, expected_ti_count",
         [

Reply via email to