ephraimbuddy commented on code in PR #62294:
URL: https://github.com/apache/airflow/pull/62294#discussion_r2986927177


##########
airflow-core/src/airflow/api_fastapi/core_api/security.py:
##########
@@ -182,7 +182,14 @@ def inner(
 class PermittedDagFilter(OrmClause[set[str]]):
     """A parameter that filters the permitted dags for the user."""
 
+    def __init__(self, value: set[str] | None = None, skip_filter: bool = 
False):
+        super().__init__(value=value)
+        self.skip_filter = skip_filter
+
     def to_orm(self, select: Select) -> Select:
+        if self.skip_filter:
+            return select

Review Comment:
   I think other specialized filters could benefit from this pattern too, e.g.:
   ```python
   class PermittedDagFilter(OrmClause[set[str]]):
       dag_id_column = DagModel.dag_id
   
       def __init__(self, value: set[str] | None = None, skip_filter: bool = False):
           super().__init__(value=value)
           self.skip_filter = skip_filter
   
       def to_orm(self, select: Select) -> Select:
           if self.skip_filter:
               return select
           return select.where(self.dag_id_column.in_(self.value or set()))
   ```
   then:
   ```python
   class PermittedDagRunFilter(PermittedDagFilter):
       dag_id_column = DagRun.dag_id
   
   class PermittedTIFilter(PermittedDagFilter):
       dag_id_column = TI.dag_id
   
   class PermittedDagVersionFilter(PermittedDagFilter):
       dag_id_column = DagVersion.dag_id
   ```
   The event log filter can still do:
   ```python
   if self.skip_filter:
       return select
   return select.where(or_(Log.dag_id.in_(self.value or set()), Log.dag_id.is_(None)))
   ```



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -495,47 +562,53 @@ def get_task_instances(
                 status.HTTP_404_NOT_FOUND,
                 f"DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` 
was not found",
             )
-        query = query.where(TI.run_id == dag_run_id)
+        filters.append(FilterParam(TI.run_id, dag_run_id))
     if dag_id != "~":
         dag = get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, 
session)
-        query = query.where(TI.dag_id == dag_id)
+        filters.append(FilterParam(TI.dag_id, dag_id))
         if dag:
             task_group_id.dag = dag
 
-    task_instance_select, total_entries = paginated_select(
-        statement=query,
-        filters=[
-            run_after_range,
-            logical_date_range,
-            start_date_range,
-            end_date_range,
-            update_at_range,
-            duration_range,
-            state,
-            pool,
-            pool_name_pattern,
-            queue,
-            queue_name_pattern,
-            executor,
-            task_id,
-            task_display_name_pattern,
-            task_group_id,
-            dag_id_pattern,
-            run_id_pattern,
-            version_number,
-            readable_ti_filter,
-            try_number,
-            operator,
-            operator_name_pattern,
-            map_index,
-        ],
+    task_instance_id_select = 
apply_filters_to_select(statement=task_instance_id_select, filters=filters)
+    total_entries = get_query_count(task_instance_id_select, session=session, 
allow_estimation=True)
+
+    # Apply ordering/pagination to the ID query.
+    # Join DagRun only when sorting by DagRun fields
+    order_by_column_names = [v.lstrip("-") for v in (order_by.value or [])]
+    dag_run_column_names = [
+        col.name for col in dag_run_columns + (DagRun.data_interval_start, 
DagRun.data_interval_end)
+    ]
+    if any(name in dag_run_column_names for name in order_by_column_names):
+        task_instance_id_select = task_instance_id_select.join(TI.dag_run)
+    task_instance_id_select, _ = paginated_select(
+        statement=task_instance_id_select,
         order_by=order_by,
         offset=offset,
         limit=limit,
-        session=session,
+        return_total_entries=False,
     )
+    task_instance_id = 
list(session.scalars(task_instance_id_select.with_only_columns(TI.id)))
+
+    # Fetch full TI rows for the paginated ids
+    query = (
+        select(TI)
+        .where(TI.id.in_(task_instance_id))
+        .join(TI.dag_run)
+        .options(
+            load_only(*task_instance_columns),
+            contains_eager(TI.dag_run).load_only(*dag_run_columns),
+            joinedload(TI.trigger),
+        )
+    )
+    query = query.outerjoin(DagRun.dag_model).options(
+        
contains_eager(TI.dag_run).contains_eager(DagRun.dag_model).load_only(*dag_model_columns)
+    )
+    query = 
query.join(TI.dag_version).options(contains_eager(TI.dag_version).joinedload(DagVersion.bundle))

Review Comment:
   This does an inner join, which can exclude TIs whose `dag_version_id` is null
(left over from an Airflow 2 upgrade). I'm not sure whether those TIs'
`dag_version_id` ever gets updated somewhere else, so we need to check that.



##########
airflow-core/src/airflow/api_fastapi/core_api/routes/public/task_instances.py:
##########
@@ -495,47 +562,53 @@ def get_task_instances(
                 status.HTTP_404_NOT_FOUND,
                 f"DagRun with dag_id: `{dag_id}` and run_id: `{dag_run_id}` 
was not found",
             )
-        query = query.where(TI.run_id == dag_run_id)
+        filters.append(FilterParam(TI.run_id, dag_run_id))
     if dag_id != "~":
         dag = get_dag_for_run_or_latest_version(dag_bag, dag_run, dag_id, 
session)
-        query = query.where(TI.dag_id == dag_id)
+        filters.append(FilterParam(TI.dag_id, dag_id))
         if dag:
             task_group_id.dag = dag
 
-    task_instance_select, total_entries = paginated_select(
-        statement=query,
-        filters=[
-            run_after_range,
-            logical_date_range,
-            start_date_range,
-            end_date_range,
-            update_at_range,
-            duration_range,
-            state,
-            pool,
-            pool_name_pattern,
-            queue,
-            queue_name_pattern,
-            executor,
-            task_id,
-            task_display_name_pattern,
-            task_group_id,
-            dag_id_pattern,
-            run_id_pattern,
-            version_number,
-            readable_ti_filter,
-            try_number,
-            operator,
-            operator_name_pattern,
-            map_index,
-        ],
+    task_instance_id_select = 
apply_filters_to_select(statement=task_instance_id_select, filters=filters)

Review Comment:
   This is applied before any DagRun join exists, and line 581 only joins
DagRun when sorting on a DagRun field. The `logical_date_*` / `run_after_*`
filters reference DagRun columns. Without an explicit join condition, DagRun ends
up as an extra, unjoined entry in the FROM clause. On
/dags/<dag>/dagRuns/~/taskInstances for a DAG with more than one run, these
filters will therefore behave like a cartesian join. They will overcount or
mis-page task instances instead of matching each TI to its own DagRun.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to