jason810496 commented on code in PR #56022:
URL: https://github.com/apache/airflow/pull/56022#discussion_r2438365018


##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -626,9 +663,35 @@ def get_list_dag_runs_batch(
         session=session,
     )
 
-    dag_runs = session.scalars(dag_runs_select)
-
-    return DAGRunCollectionResponse(
-        dag_runs=dag_runs,
-        total_entries=total_entries,
-    )
+    dag_runs = list(session.scalars(dag_runs_select))
+

Review Comment:
   It would be nice to modularize the post-processing logic that binds 
`dr.consumed_assets` and `dr.produced_assets` by moving it into 
`airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py`



##########
providers/fab/src/airflow/providers/fab/auth_manager/security_manager/override.py:
##########
@@ -127,7 +127,7 @@ def _iter_dags() -> Iterable[DAG | SerializedDAG]:
     from airflow.models.dagbag import DagBag  # type: ignore[attr-defined, 
no-redef]
 
     def _iter_dags() -> Iterable[DAG | SerializedDAG]:
-        dagbag = DagBag(read_dags_from_db=True)  # type: ignore[call-arg]
+        dagbag = DagBag(read_dags_from_db=True)  # type: ignore[misc, call-arg]

Review Comment:
   This seems to be an unrelated change.



##########
airflow-core/src/airflow/api_fastapi/core_api/datamodels/dag_run.py:
##########
@@ -61,6 +68,15 @@ class DAGRunClearBody(StrictBaseModel):
     )
 
 
+class AssetSummary(PydanticBaseModel):

Review Comment:
   May I ask why we need `PydanticBaseModel` instead of `BaseModel` here?



##########
airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_run.py:
##########
@@ -1091,6 +1149,53 @@ def test_invalid_state(self, test_client, post_body, 
expected_response):
         assert response.status_code == 422
         assert response.json()["detail"] == expected_response
 
+    @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
+    def test_list_includes_asset_summaries_when_present(self, test_client, 
session):
+        prod_asset = Asset(name="batch_prod_asset", uri="file:///batch_prod")
+        prod_model = AssetModel.from_public(prod_asset)
+        session.add(prod_model)
+        session.flush()
+        prod_event = AssetEvent(
+            asset_id=prod_model.id,
+            source_task_id="task_1",
+            source_dag_id=DAG1_ID,
+            source_run_id=DAG1_RUN1_ID,
+        )
+        session.add(prod_event)
+        session.flush()
+
+        cons_asset = Asset(name="batch_cons_asset", uri="file:///batch_cons")
+        cons_model = AssetModel.from_public(cons_asset)
+        session.add(cons_model)
+        session.flush()
+        cons_event = AssetEvent(
+            asset_id=cons_model.id,
+            source_task_id="task_2",
+            source_dag_id=DAG1_ID,
+            source_run_id=DAG1_RUN2_ID,
+        )
+        session.add(cons_event)
+        session.flush()
+        dr2 = session.scalar(select(DagRun).where(DagRun.dag_id == DAG1_ID, 
DagRun.run_id == DAG1_RUN2_ID))
+        dr2.consumed_asset_events.append(cons_event)
+        session.commit()
+
+        resp = test_client.post("/dags/~/dagRuns/list", json={})
+        assert resp.status_code == 200
+        body = resp.json()
+        runs_by_id = {each["dag_run_id"]: each for each in 
body.get("dag_runs", [])}
+
+        if "produced_assets" not in runs_by_id.get(DAG1_RUN1_ID, {}):
+            pytest.xfail("Batch list endpoint does not currently expose 
produced/consumed asset summaries")
+
+        prod_list = runs_by_id[DAG1_RUN1_ID].get("produced_assets", [])
+        assert isinstance(prod_list, list)
+        assert any(item.get("id") == prod_model.id for item in prod_list)
+
+        cons_list = runs_by_id[DAG1_RUN2_ID].get("consumed_assets", [])
+        assert isinstance(cons_list, list)
+        assert any(item.get("id") == cons_model.id for item in cons_list)

Review Comment:
   Would it be better to validate what the exact dict looks like for 
AssetSummary here?
   
   Something like
   ```python
   assert consumed_assets == [
       {
           "id": ...,
           "name": ...,
   ...
   ```
   which might be more readable for the `consumed_assets` and `produced_assets` 
response.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -392,18 +401,46 @@ def get_dag_runs(
             run_id_pattern,
             triggering_user_name_pattern,
             dag_id_pattern,
+            consuming_asset,
+            producing_asset,
         ],
         order_by=order_by,
         offset=offset,
         limit=limit,
         session=session,
     )
-    dag_runs = session.scalars(dag_run_select)
-
-    return DAGRunCollectionResponse(
-        dag_runs=dag_runs,
-        total_entries=total_entries,
-    )
+    dag_runs = list(session.scalars(dag_run_select))
+
+    source_keys = {(dr.dag_id, dr.run_id) for dr in dag_runs}
+    produced_map: dict[tuple[str, str], list[AssetSummary]] = {}
+    if source_keys:
+        produced_events = session.scalars(
+            select(AssetEvent)
+            .join(AssetEvent.asset)
+            .where(
+                or_(
+                    *[
+                        and_(AssetEvent.source_dag_id == dag_id, 
AssetEvent.source_run_id == run_id)
+                        for dag_id, run_id in source_keys
+                    ]
+                )
+            )
+        ).all()
+        for ev in produced_events:
+            key = (ev.source_dag_id, ev.source_run_id)
+            produced_map.setdefault(key, []).append(
+                AssetSummary(id=ev.asset_id, name=ev.name, uri=ev.uri, 
group=ev.group)
+            )
+
+    for dr in dag_runs:
+        consumed_list: list[AssetSummary] = []
+        for ev in getattr(dr, "consumed_asset_events", []) or []:
+            consumed_list.append(AssetSummary(id=ev.asset_id, name=ev.name, 
uri=ev.uri, group=ev.group))
+        dr.consumed_assets = consumed_list
+        produced_list = produced_map.get((dr.dag_id, dr.run_id)) or []

Review Comment:
   ```suggestion
           produced_list = produced_map.get((dr.dag_id, dr.run_id), [])
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -392,18 +401,46 @@ def get_dag_runs(
             run_id_pattern,
             triggering_user_name_pattern,
             dag_id_pattern,
+            consuming_asset,
+            producing_asset,
         ],
         order_by=order_by,
         offset=offset,
         limit=limit,
         session=session,
     )
-    dag_runs = session.scalars(dag_run_select)
-
-    return DAGRunCollectionResponse(
-        dag_runs=dag_runs,
-        total_entries=total_entries,
-    )
+    dag_runs = list(session.scalars(dag_run_select))
+
+    source_keys = {(dr.dag_id, dr.run_id) for dr in dag_runs}
+    produced_map: dict[tuple[str, str], list[AssetSummary]] = {}
+    if source_keys:
+        produced_events = session.scalars(
+            select(AssetEvent)
+            .join(AssetEvent.asset)
+            .where(
+                or_(
+                    *[
+                        and_(AssetEvent.source_dag_id == dag_id, 
AssetEvent.source_run_id == run_id)
+                        for dag_id, run_id in source_keys
+                    ]
+                )
+            )
+        ).all()
+        for ev in produced_events:
+            key = (ev.source_dag_id, ev.source_run_id)
+            produced_map.setdefault(key, []).append(
+                AssetSummary(id=ev.asset_id, name=ev.name, uri=ev.uri, 
group=ev.group)
+            )
+
+    for dr in dag_runs:
+        consumed_list: list[AssetSummary] = []
+        for ev in getattr(dr, "consumed_asset_events", []) or []:

Review Comment:
   ```suggestion
           for ev in getattr(dr, "consumed_asset_events", []):
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -359,13 +363,18 @@ def get_dag_runs(
         Depends(search_param_factory(DagRun.triggering_user_name, 
"triggering_user_name_pattern")),
     ],
     dag_id_pattern: Annotated[_SearchParam, 
Depends(search_param_factory(DagRun.dag_id, "dag_id_pattern"))],
+    consuming_asset: QueryDagRunConsumingAssetFilter,
+    producing_asset: QueryDagRunProducingAssetFilter,
 ) -> DAGRunCollectionResponse:
     """
     Get all DAG Runs.
 
     This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for 
all DAGs.
     """
-    query = select(DagRun)
+    query = select(DagRun).options(
+        
selectinload(DagRun.consumed_asset_events).selectinload(AssetEvent.asset),
+        joinedload(DagRun.dag_model),
+    )

Review Comment:
   It seems we don't need these `selectinload` or `joinedload` statements here?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to