jason810496 commented on code in PR #58092:
URL: https://github.com/apache/airflow/pull/58092#discussion_r2610557255


##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
         raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
 
 
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+    """Task group filter - returns all tasks in the specified group."""
+
+    def __init__(self, dag=None, skip_none: bool = True):
+        super().__init__(skip_none=skip_none)
+        self._dag: None | Any = dag
+
+    @property
+    def dag(self) -> None | Any:
+        return self._dag
+
+    @dag.setter
+    def dag(self, value: None | Any) -> None:
+        self._dag = value
+
+    def to_orm(self, select: Select) -> Select:
+        if self.value is None and self.skip_none:
+            return select
+
+        if not self.dag:
+            raise ValueError("DAG must be set before calling to_orm")
+
+        if not hasattr(self.dag, "task_group"):
+            return select
+
+        matching_task_ids = []
+
+        # Exact matching on group_id
+        task_groups = self.dag.task_group.get_task_group_dict()
+        task_group = task_groups.get(self.value)
+        if task_group:
+            matching_task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+        return select.where(TaskInstance.task_id.in_(matching_task_ids))

Review Comment:
   It seems we could simplify the filter as:
   ```suggestion
           return select.where(TaskInstance.task_id.in_(task.task_id for task in task_group.iter_tasks()))
   ```



##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
         raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
 
 
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+    """Task group filter - returns all tasks in the specified group."""
+
+    def __init__(self, dag=None, skip_none: bool = True):
+        super().__init__(skip_none=skip_none)
+        self._dag: None | Any = dag
+
+    @property
+    def dag(self) -> None | Any:
+        return self._dag
+
+    @dag.setter
+    def dag(self, value: None | Any) -> None:
+        self._dag = value

Review Comment:
   `SerializedDAG` can be imported from `airflow.serialization.serialized_objects` inside an `if TYPE_CHECKING:` block.
   
   ```suggestion
           self._dag: None | SerializedDAG = dag
   
       @property
       def dag(self) -> None | SerializedDAG:
           return self._dag
   
       @dag.setter
       def dag(self, value: SerializedDAG) -> None:
           self._dag = value
   ```



##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
         raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
 
 
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+    """Task group filter - returns all tasks in the specified group."""
+
+    def __init__(self, dag=None, skip_none: bool = True):
+        super().__init__(skip_none=skip_none)
+        self._dag: None | Any = dag
+
+    @property
+    def dag(self) -> None | Any:
+        return self._dag
+
+    @dag.setter
+    def dag(self, value: None | Any) -> None:
+        self._dag = value
+
+    def to_orm(self, select: Select) -> Select:
+        if self.value is None and self.skip_none:
+            return select
+
+        if not self.dag:
+            raise ValueError("DAG must be set before calling to_orm")

Review Comment:
   ```suggestion
               raise ValueError("Dag must be set before calling to_orm")
   ```



##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
         raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
 
 
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+    """Task group filter - returns all tasks in the specified group."""
+
+    def __init__(self, dag=None, skip_none: bool = True):
+        super().__init__(skip_none=skip_none)
+        self._dag: None | Any = dag
+
+    @property
+    def dag(self) -> None | Any:
+        return self._dag
+
+    @dag.setter
+    def dag(self, value: None | Any) -> None:
+        self._dag = value
+
+    def to_orm(self, select: Select) -> Select:
+        if self.value is None and self.skip_none:
+            return select
+
+        if not self.dag:
+            raise ValueError("DAG must be set before calling to_orm")
+
+        if not hasattr(self.dag, "task_group"):
+            return select
+
+        matching_task_ids = []
+
+        # Exact matching on group_id

Review Comment:
   Not sure whether we should raise a Not Found error if the task group is not found, like:
   
   
https://github.com/apache/airflow/blob/7c6fcecc535de447482b18c85fb98235655e745d/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py#L950-L976



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to