pierrejeambrun commented on code in PR #46484:
URL: https://github.com/apache/airflow/pull/46484#discussion_r1946280413


##########
airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -63,15 +65,85 @@
 from airflow.api_fastapi.logging.decorators import action_logging
 from airflow.exceptions import ParamValidationError
 from airflow.listeners.listener import get_listener_manager
-from airflow.models import DAG, DagModel, DagRun
+from airflow.models import DAG, DagModel, DagRun, TaskInstance
 from airflow.models.dag_version import DagVersion
+from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.timetables.base import DataInterval
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 dag_run_router = AirflowRouter(tags=["DagRun"], 
prefix="/dags/{dag_id}/dagRuns")
 
 
+def get_dag_run_with_task_ids(dag_id: str, run_id: str, session: SessionDep) 
-> DagRun:
+    """Get the DagRun with the given task_ids."""
+    return session.scalar(
+        select(DagRun)
+        .options(joinedload(DagRun.task_instances).load_only("task_id"))
+        .filter_by(dag_id=dag_id, run_id=run_id)
+    )
+
+
+def get_dag_version_ids_among_dag_run(dag_id: str, dag_run_id: str, task_ids: 
list[str], session: SessionDep):
+    """Get the DagVersions from the TaskInstances and TaskInstanceHistories of 
the DagRun."""
+    task_instance_dag_version_ids = select(TaskInstance.dag_version_id).where(
+        TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, 
TaskInstance.task_id.in_(task_ids)
+    )
+    task_instance_history_dag_version_ids = 
select(TaskInstanceHistory.dag_version_id).where(
+        TaskInstanceHistory.dag_id == dag_id,
+        TaskInstanceHistory.run_id == dag_run_id,
+        TaskInstanceHistory.task_id.in_(task_ids),
+    )
+    return (
+        session.execute(
+            
task_instance_dag_version_ids.distinct().union(task_instance_history_dag_version_ids.distinct())
+        )
+        .scalars()
+        .all()
+    )

Review Comment:
   Ideally, we would have a proxy attribute on the DagRun ORM model that lets us 
retrieve all of these by accessing `DagRun.dag_versions` (a list of `DagVersion` 
ORM objects), or simply a list of ints representing the DAG version numbers.
   
   
   Importantly, we need to be able to eager-load it when querying DagRuns, so that 
no extra lazy queries are emitted later when this property is accessed.
   
   Take a look at relationships or association proxies. Basically, set up two 
relationships (`history_versions` and `current_versions`) and possibly add a 
property that returns the concatenation of both.
   
   Or, even better, a single relationship that aggregates the versions from the 
related TaskInstance and TaskInstanceHistory rows (though that might be harder to 
implement).



##########
airflow/api_fastapi/core_api/routes/public/dag_run.py:
##########
@@ -63,15 +65,85 @@
 from airflow.api_fastapi.logging.decorators import action_logging
 from airflow.exceptions import ParamValidationError
 from airflow.listeners.listener import get_listener_manager
-from airflow.models import DAG, DagModel, DagRun
+from airflow.models import DAG, DagModel, DagRun, TaskInstance
 from airflow.models.dag_version import DagVersion
+from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.timetables.base import DataInterval
 from airflow.utils.state import DagRunState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
 dag_run_router = AirflowRouter(tags=["DagRun"], 
prefix="/dags/{dag_id}/dagRuns")
 
 
+def get_dag_run_with_task_ids(dag_id: str, run_id: str, session: SessionDep) 
-> DagRun:
+    """Get the DagRun with the given task_ids."""
+    return session.scalar(
+        select(DagRun)
+        .options(joinedload(DagRun.task_instances).load_only("task_id"))
+        .filter_by(dag_id=dag_id, run_id=run_id)
+    )
+
+
+def get_dag_version_ids_among_dag_run(dag_id: str, dag_run_id: str, task_ids: 
list[str], session: SessionDep):
+    """Get the DagVersions from the TaskInstances and TaskInstanceHistories of 
the DagRun."""
+    task_instance_dag_version_ids = select(TaskInstance.dag_version_id).where(
+        TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, 
TaskInstance.task_id.in_(task_ids)
+    )
+    task_instance_history_dag_version_ids = 
select(TaskInstanceHistory.dag_version_id).where(
+        TaskInstanceHistory.dag_id == dag_id,
+        TaskInstanceHistory.run_id == dag_run_id,
+        TaskInstanceHistory.task_id.in_(task_ids),
+    )
+    return (
+        session.execute(
+            
task_instance_dag_version_ids.distinct().union(task_instance_history_dag_version_ids.distinct())
+        )
+        .scalars()
+        .all()
+    )

Review Comment:
   This is only a fallback solution, because it requires emitting an extra custom 
query to fetch those `version_ids` and implementing a lot of application logic to 
get the expected behavior.
   
   Also, this code should not live in the `routes` folder.



##########
airflow/api_fastapi/core_api/datamodels/dag_run.py:
##########
@@ -69,6 +70,7 @@ class DAGRunResponse(BaseModel):
     triggered_by: DagRunTriggeredByType
     conf: dict
     note: str | None
+    dag_versions: list[UUID]

Review Comment:
   We use `version_number` as the key on other endpoints. I think we should 
return `dag_version_numbers`, which also makes more sense from a user's point of 
view (UUIDs don't convey much information).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to