dheerajturaga commented on code in PR #56837:
URL: https://github.com/apache/airflow/pull/56837#discussion_r2443539526


##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -130,3 +139,148 @@ def _find_aggregates(
             **_get_aggs_for_node(details),
         }
         return
+
+
+def get_batch_ti_summaries(
+    dag_id: str,
+    run_ids: list[str],
+    session: Session,
+) -> dict:
+    """
+    Fetch task instance summaries for multiple runs in a single query.
+
+    This is much more efficient than the N+1 query pattern of calling
+    get_grid_ti_summaries multiple times.
+
+    Returns a dict with structure:
+    {
+        "dag_id": str,
+        "summaries": [GridTISummaries, ...]
+    }
+    """
+    if not run_ids:
+        return {"dag_id": dag_id, "summaries": []}
+
+    # Single query to fetch ALL task instances for ALL runs
+    tis_query = (
+        select(
+            TaskInstance.run_id,
+            TaskInstance.task_id,
+            TaskInstance.state,
+            TaskInstance.dag_version_id,
+            TaskInstance.start_date,
+            TaskInstance.end_date,
+        )
+        .where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(run_ids),
+        )
+        .order_by(TaskInstance.run_id, TaskInstance.task_id)
+    )
+
+    task_instances = list(session.execute(tis_query))
+
+    # Group by run_id and collect unique dag_version_ids
+    tis_by_run = collections.defaultdict(list)
+    dag_version_ids = set()
+    for ti in task_instances:
+        tis_by_run[ti.run_id].append(ti)
+        if ti.dag_version_id:
+            dag_version_ids.add(ti.dag_version_id)
+
+    # Fetch all needed serialized DAGs in one query
+    serdags_by_version = {}
+    if dag_version_ids:
+        serdags = session.scalars(
+            select(SerializedDagModel)
+            .join(DagVersion, SerializedDagModel.dag_version_id == DagVersion.id)
+            .where(DagVersion.id.in_(dag_version_ids))
+        )
+        for serdag in serdags:
+            if serdag.dag_version_id:
+                serdags_by_version[serdag.dag_version_id] = serdag
+
+    # Process each run
+    summaries = []
+    for run_id in run_ids:
+        tis = tis_by_run.get(run_id, [])
+        if not tis:
+            continue
+
+        # Build ti_details structure
+        ti_details = collections.defaultdict(list)
+        for ti in tis:
+            ti_details[ti.task_id].append(
+                {
+                    "state": ti.state,
+                    "start_date": ti.start_date,
+                    "end_date": ti.end_date,
+                }
+            )
+
+        # Get the appropriate serdag
+        dag_version_id = tis[0].dag_version_id if tis else None
+        serdag = serdags_by_version.get(dag_version_id) if dag_version_id else None
+
+        if not serdag:
+            log.warning(
+                "No serialized dag found for run",
+                dag_id=dag_id,
+                run_id=run_id,
+                dag_version_id=dag_version_id,
+            )
+            continue
+
+        # Helper function to generate node summaries
+        def get_node_summaries():

Review Comment:
   Done!



##########
airflow-core/src/airflow/api_fastapi/core_api/services/ui/grid.py:
##########
@@ -130,3 +139,148 @@ def _find_aggregates(
             **_get_aggs_for_node(details),
         }
         return
+
+
+def get_batch_ti_summaries(
+    dag_id: str,
+    run_ids: list[str],
+    session: Session,
+) -> dict:
+    """
+    Fetch task instance summaries for multiple runs in a single query.
+
+    This is much more efficient than the N+1 query pattern of calling
+    get_grid_ti_summaries multiple times.
+
+    Returns a dict with structure:
+    {
+        "dag_id": str,
+        "summaries": [GridTISummaries, ...]
+    }
+    """
+    if not run_ids:
+        return {"dag_id": dag_id, "summaries": []}
+
+    # Single query to fetch ALL task instances for ALL runs
+    tis_query = (
+        select(
+            TaskInstance.run_id,
+            TaskInstance.task_id,
+            TaskInstance.state,
+            TaskInstance.dag_version_id,
+            TaskInstance.start_date,
+            TaskInstance.end_date,
+        )
+        .where(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.run_id.in_(run_ids),
+        )
+        .order_by(TaskInstance.run_id, TaskInstance.task_id)
+    )
+
+    task_instances = list(session.execute(tis_query))
+
+    # Group by run_id and collect unique dag_version_ids
+    tis_by_run = collections.defaultdict(list)
+    dag_version_ids = set()
+    for ti in task_instances:
+        tis_by_run[ti.run_id].append(ti)
+        if ti.dag_version_id:
+            dag_version_ids.add(ti.dag_version_id)
+
+    # Fetch all needed serialized DAGs in one query
+    serdags_by_version = {}
+    if dag_version_ids:
+        serdags = session.scalars(
+            select(SerializedDagModel)
+            .join(DagVersion, SerializedDagModel.dag_version_id == DagVersion.id)
+            .where(DagVersion.id.in_(dag_version_ids))
+        )
+        for serdag in serdags:
+            if serdag.dag_version_id:
+                serdags_by_version[serdag.dag_version_id] = serdag
+
+    # Process each run
+    summaries = []
+    for run_id in run_ids:
+        tis = tis_by_run.get(run_id, [])
+        if not tis:
+            continue
+
+        # Build ti_details structure
+        ti_details = collections.defaultdict(list)
+        for ti in tis:
+            ti_details[ti.task_id].append(
+                {
+                    "state": ti.state,
+                    "start_date": ti.start_date,
+                    "end_date": ti.end_date,
+                }
+            )
+
+        # Get the appropriate serdag
+        dag_version_id = tis[0].dag_version_id if tis else None
+        serdag = serdags_by_version.get(dag_version_id) if dag_version_id else None
+
+        if not serdag:
+            log.warning(
+                "No serialized dag found for run",
+                dag_id=dag_id,
+                run_id=run_id,
+                dag_version_id=dag_version_id,
+            )
+            continue
+
+        # Helper function to generate node summaries
+        def get_node_summaries():
+            yielded_task_ids: set[str] = set()
+
+            # Yield all nodes discoverable from the serialized DAG structure
+            for node in _find_aggregates(
+                node=serdag.dag.task_group,
+                parent_node=None,
+                ti_details=ti_details,
+            ):
+                if node["type"] in {"task", "mapped_task"}:
+                    yielded_task_ids.add(node["task_id"])
+                    if node["type"] == "task":
+                        node["child_states"] = None
+                        node["min_start_date"] = None
+                        node["max_end_date"] = None
+                yield node
+
+            # For good history: add synthetic leaf nodes for task_ids that have TIs in this run
+            # but are not present in the current DAG structure (e.g. removed tasks)
+            missing_task_ids = set(ti_details.keys()) - yielded_task_ids
+            for task_id in sorted(missing_task_ids):
+                detail = ti_details[task_id]
+                # Create a leaf task node with aggregated state from its TIs
+                agg = _get_aggs_for_node(detail)
+                yield {
+                    "task_id": task_id,
+                    "type": "task",
+                    "parent_id": None,
+                    **agg,
+                    # Align with leaf behavior
+                    "child_states": None,
+                    "min_start_date": None,
+                    "max_end_date": None,
+                }
+
+        task_instances_list = list(get_node_summaries())
+
+        # If a group id and a task id collide, prefer the group record
+        group_ids = {n.get("task_id") for n in task_instances_list if n.get("type") == "group"}
+        filtered = [
+            n for n in task_instances_list if not (n.get("type") == "task" and n.get("task_id") in group_ids)
+        ]
+
+        summaries.append(
+            {
+                "run_id": run_id,
+                "dag_id": dag_id,
+                "task_instances": filtered,
+            }
+        )
+
+    return {"dag_id": dag_id, "summaries": summaries}

Review Comment:
   Done!



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to