jason810496 commented on code in PR #56837:
URL: https://github.com/apache/airflow/pull/56837#discussion_r2443007450


##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -130,3 +139,148 @@ def _find_aggregates(
             **_get_aggs_for_node(details),
         }
         return
+
+
+def get_batch_ti_summaries(
+    dag_id: str,
+    run_ids: list[str],
+    session: Session,
+) -> dict:
+    """
+    Fetch task instance summaries for multiple runs in a single query.
+
+    This is much more efficient than the N+1 query pattern of calling
+    get_grid_ti_summaries multiple times.
+
+    Returns a dict with structure:
+    {
+        "dag_id": str,
+        "summaries": [GridTISummaries, ...]
+    }
+    """
+    if not run_ids:
+        return {"dag_id": dag_id, "summaries": []}
+
+    # Single query to fetch ALL task instances for ALL runs
+    tis_query = (
+        select(
+            TaskInstance.run_id,
+            TaskInstance.task_id,
+            TaskInstance.state,
+            TaskInstance.dag_version_id,
+            TaskInstance.start_date,
+            TaskInstance.end_date,
+        )
+        .where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(run_ids),
+        )
+        .order_by(TaskInstance.run_id, TaskInstance.task_id)
+    )
+
+    task_instances = list(session.execute(tis_query))
+
+    # Group by run_id and collect unique dag_version_ids
+    tis_by_run = collections.defaultdict(list)
+    dag_version_ids = set()
+    for ti in task_instances:
+        tis_by_run[ti.run_id].append(ti)
+        if ti.dag_version_id:
+            dag_version_ids.add(ti.dag_version_id)
+
+    # Fetch all needed serialized DAGs in one query
+    serdags_by_version = {}
+    if dag_version_ids:
+        serdags = session.scalars(
+            select(SerializedDagModel)
+            .join(DagVersion, SerializedDagModel.dag_version_id == 
DagVersion.id)
+            .where(DagVersion.id.in_(dag_version_ids))
+        )
+        for serdag in serdags:
+            if serdag.dag_version_id:
+                serdags_by_version[serdag.dag_version_id] = serdag

Review Comment:
   We may need to handle this case here as well, the same way the existing route does:
   
   
https://github.com/apache/airflow/blob/c54c1a97220756f5af4914333a4ab9941f32e8b1/airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py#L92-L101



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py:
##########
@@ -425,3 +427,52 @@ def get_node_sumaries():
         "dag_id": dag_id,
         "task_instances": filtered,
     }
+
+
+@grid_router.post(
+    "/ti_summaries_batch/{dag_id}",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+        ]
+    ),
+    dependencies=[
+        Depends(
+            requires_access_dag(
+                method="GET",
+                access_entity=DagAccessEntity.TASK_INSTANCE,
+            )
+        ),
+        Depends(
+            requires_access_dag(
+                method="GET",
+                access_entity=DagAccessEntity.RUN,
+            )
+        ),
+    ],
+)
+def get_grid_ti_summaries_batch(
+    dag_id: str,
+    run_ids: list[str],
+    session: SessionDep,
+) -> GridTISummariesBatch:
+    """
+    Get task instance summaries for multiple DAG runs in a single request.
+
+    This endpoint is much more efficient than calling 
/ti_summaries/{dag_id}/{run_id}
+    multiple times, as it fetches all task instances in a single database 
query.
+    """
+    if not run_ids:
+        raise HTTPException(
+            status.HTTP_400_BAD_REQUEST,
+            "run_ids must not be empty",
+        )
+
+    if len(run_ids) > 100:
+        raise HTTPException(
+            status.HTTP_400_BAD_REQUEST,
+            "Cannot fetch more than 100 runs at once",
+        )

Review Comment:
   Maybe we could leverage Pydantic to validate the length of `run_ids` instead of checking it manually here.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/ui/grid.py:
##########
@@ -425,3 +427,52 @@ def get_node_sumaries():
         "dag_id": dag_id,
         "task_instances": filtered,
     }
+
+
+@grid_router.post(
+    "/ti_summaries_batch/{dag_id}",
+    responses=create_openapi_http_exception_doc(
+        [
+            status.HTTP_400_BAD_REQUEST,
+            status.HTTP_404_NOT_FOUND,
+        ]
+    ),
+    dependencies=[
+        Depends(
+            requires_access_dag(
+                method="GET",
+                access_entity=DagAccessEntity.TASK_INSTANCE,
+            )
+        ),
+        Depends(
+            requires_access_dag(
+                method="GET",
+                access_entity=DagAccessEntity.RUN,
+            )
+        ),
+    ],
+)
+def get_grid_ti_summaries_batch(
+    dag_id: str,
+    run_ids: list[str],

Review Comment:
   ```suggestion
       run_ids: list[str] = Field(min_length=1, max_length=100),
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -130,3 +139,148 @@ def _find_aggregates(
             **_get_aggs_for_node(details),
         }
         return
+
+
+def get_batch_ti_summaries(
+    dag_id: str,
+    run_ids: list[str],
+    session: Session,
+) -> dict:
+    """
+    Fetch task instance summaries for multiple runs in a single query.
+
+    This is much more efficient than the N+1 query pattern of calling
+    get_grid_ti_summaries multiple times.
+
+    Returns a dict with structure:
+    {
+        "dag_id": str,
+        "summaries": [GridTISummaries, ...]
+    }
+    """
+    if not run_ids:
+        return {"dag_id": dag_id, "summaries": []}
+
+    # Single query to fetch ALL task instances for ALL runs
+    tis_query = (
+        select(
+            TaskInstance.run_id,
+            TaskInstance.task_id,
+            TaskInstance.state,
+            TaskInstance.dag_version_id,
+            TaskInstance.start_date,
+            TaskInstance.end_date,
+        )
+        .where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(run_ids),
+        )
+        .order_by(TaskInstance.run_id, TaskInstance.task_id)
+    )
+
+    task_instances = list(session.execute(tis_query))
+
+    # Group by run_id and collect unique dag_version_ids
+    tis_by_run = collections.defaultdict(list)
+    dag_version_ids = set()
+    for ti in task_instances:
+        tis_by_run[ti.run_id].append(ti)
+        if ti.dag_version_id:
+            dag_version_ids.add(ti.dag_version_id)
+
+    # Fetch all needed serialized DAGs in one query
+    serdags_by_version = {}
+    if dag_version_ids:
+        serdags = session.scalars(
+            select(SerializedDagModel)
+            .join(DagVersion, SerializedDagModel.dag_version_id == 
DagVersion.id)
+            .where(DagVersion.id.in_(dag_version_ids))
+        )
+        for serdag in serdags:
+            if serdag.dag_version_id:
+                serdags_by_version[serdag.dag_version_id] = serdag
+
+    # Process each run
+    summaries = []
+    for run_id in run_ids:
+        tis = tis_by_run.get(run_id, [])
+        if not tis:
+            continue
+
+        # Build ti_details structure
+        ti_details = collections.defaultdict(list)
+        for ti in tis:
+            ti_details[ti.task_id].append(
+                {
+                    "state": ti.state,
+                    "start_date": ti.start_date,
+                    "end_date": ti.end_date,
+                }
+            )
+
+        # Get the appropriate serdag
+        dag_version_id = tis[0].dag_version_id if tis else None
+        serdag = serdags_by_version.get(dag_version_id) if dag_version_id else 
None
+
+        if not serdag:
+            log.warning(
+                "No serialized dag found for run",
+                dag_id=dag_id,
+                run_id=run_id,
+                dag_version_id=dag_version_id,
+            )
+            continue
+
+        # Helper function to generate node summaries
+        def get_node_summaries():
+            yielded_task_ids: set[str] = set()
+
+            # Yield all nodes discoverable from the serialized DAG structure
+            for node in _find_aggregates(
+                node=serdag.dag.task_group,
+                parent_node=None,
+                ti_details=ti_details,
+            ):
+                if node["type"] in {"task", "mapped_task"}:
+                    yielded_task_ids.add(node["task_id"])
+                    if node["type"] == "task":
+                        node["child_states"] = None
+                        node["min_start_date"] = None
+                        node["max_end_date"] = None
+                yield node
+
+            # For good history: add synthetic leaf nodes for task_ids that 
have TIs in this run
+            # but are not present in the current DAG structure (e.g. removed 
tasks)
+            missing_task_ids = set(ti_details.keys()) - yielded_task_ids
+            for task_id in sorted(missing_task_ids):
+                detail = ti_details[task_id]
+                # Create a leaf task node with aggregated state from its TIs
+                agg = _get_aggs_for_node(detail)
+                yield {
+                    "task_id": task_id,
+                    "type": "task",
+                    "parent_id": None,
+                    **agg,
+                    # Align with leaf behavior
+                    "child_states": None,
+                    "min_start_date": None,
+                    "max_end_date": None,
+                }
+
+        task_instances_list = list(get_node_summaries())
+
+        # If a group id and a task id collide, prefer the group record
+        group_ids = {n.get("task_id") for n in task_instances_list if 
n.get("type") == "group"}
+        filtered = [
+            n for n in task_instances_list if not (n.get("type") == "task" and 
n.get("task_id") in group_ids)
+        ]
+
+        summaries.append(
+            {
+                "run_id": run_id,
+                "dag_id": dag_id,
+                "task_instances": filtered,
+            }
+        )
+
+    return {"dag_id": dag_id, "summaries": summaries}

Review Comment:
   Same here. The logic is exactly the same as in the `get_grid_ti_summaries` route, so maybe we can extract it into a shared utility function.



##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -130,3 +139,148 @@ def _find_aggregates(
             **_get_aggs_for_node(details),
         }
         return
+
+
+def get_batch_ti_summaries(
+    dag_id: str,
+    run_ids: list[str],
+    session: Session,
+) -> dict:
+    """
+    Fetch task instance summaries for multiple runs in a single query.
+
+    This is much more efficient than the N+1 query pattern of calling
+    get_grid_ti_summaries multiple times.
+
+    Returns a dict with structure:
+    {
+        "dag_id": str,
+        "summaries": [GridTISummaries, ...]
+    }
+    """
+    if not run_ids:
+        return {"dag_id": dag_id, "summaries": []}
+
+    # Single query to fetch ALL task instances for ALL runs
+    tis_query = (
+        select(
+            TaskInstance.run_id,
+            TaskInstance.task_id,
+            TaskInstance.state,
+            TaskInstance.dag_version_id,
+            TaskInstance.start_date,
+            TaskInstance.end_date,
+        )
+        .where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(run_ids),
+        )
+        .order_by(TaskInstance.run_id, TaskInstance.task_id)
+    )
+
+    task_instances = list(session.execute(tis_query))
+
+    # Group by run_id and collect unique dag_version_ids
+    tis_by_run = collections.defaultdict(list)
+    dag_version_ids = set()
+    for ti in task_instances:
+        tis_by_run[ti.run_id].append(ti)
+        if ti.dag_version_id:
+            dag_version_ids.add(ti.dag_version_id)
+
+    # Fetch all needed serialized DAGs in one query
+    serdags_by_version = {}
+    if dag_version_ids:
+        serdags = session.scalars(
+            select(SerializedDagModel)
+            .join(DagVersion, SerializedDagModel.dag_version_id == 
DagVersion.id)
+            .where(DagVersion.id.in_(dag_version_ids))
+        )
+        for serdag in serdags:
+            if serdag.dag_version_id:
+                serdags_by_version[serdag.dag_version_id] = serdag
+
+    # Process each run
+    summaries = []
+    for run_id in run_ids:
+        tis = tis_by_run.get(run_id, [])
+        if not tis:
+            continue
+
+        # Build ti_details structure
+        ti_details = collections.defaultdict(list)
+        for ti in tis:
+            ti_details[ti.task_id].append(
+                {
+                    "state": ti.state,
+                    "start_date": ti.start_date,
+                    "end_date": ti.end_date,
+                }
+            )
+
+        # Get the appropriate serdag
+        dag_version_id = tis[0].dag_version_id if tis else None
+        serdag = serdags_by_version.get(dag_version_id) if dag_version_id else 
None
+
+        if not serdag:
+            log.warning(
+                "No serialized dag found for run",
+                dag_id=dag_id,
+                run_id=run_id,
+                dag_version_id=dag_version_id,
+            )
+            continue
+
+        # Helper function to generate node summaries
+        def get_node_summaries():

Review Comment:
   It seems the `get_node_summaries` helper function could be moved to the services module as well, since the `get_grid_ti_summaries` route uses the same implementation.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to