jason810496 commented on code in PR #58092:
URL: https://github.com/apache/airflow/pull/58092#discussion_r2610557255
##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+ """Task group filter - returns all tasks in the specified group."""
+
+ def __init__(self, dag=None, skip_none: bool = True):
+ super().__init__(skip_none=skip_none)
+ self._dag: None | Any = dag
+
+ @property
+ def dag(self) -> None | Any:
+ return self._dag
+
+ @dag.setter
+ def dag(self, value: None | Any) -> None:
+ self._dag = value
+
+ def to_orm(self, select: Select) -> Select:
+ if self.value is None and self.skip_none:
+ return select
+
+ if not self.dag:
+ raise ValueError("DAG must be set before calling to_orm")
+
+ if not hasattr(self.dag, "task_group"):
+ return select
+
+ matching_task_ids = []
+
+ # Exact matching on group_id
+ task_groups = self.dag.task_group.get_task_group_dict()
+ task_group = task_groups.get(self.value)
+ if task_group:
+ matching_task_ids = [task.task_id for task in task_group.iter_tasks()]
+
+ return select.where(TaskInstance.task_id.in_(matching_task_ids))
Review Comment:
It seems we could simplify the filter as:
```suggestion
return select.where(TaskInstance.task_id.in_(task.task_id for task in task_group.iter_tasks()))
```
##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+ """Task group filter - returns all tasks in the specified group."""
+
+ def __init__(self, dag=None, skip_none: bool = True):
+ super().__init__(skip_none=skip_none)
+ self._dag: None | Any = dag
+
+ @property
+ def dag(self) -> None | Any:
+ return self._dag
+
+ @dag.setter
+ def dag(self, value: None | Any) -> None:
+ self._dag = value
Review Comment:
The `SerializedDAG` type can be imported from
`airflow.serialization.serialized_objects` inside a `TYPE_CHECKING` block.
```suggestion
self._dag: None | SerializedDAG = dag
@property
def dag(self) -> None | SerializedDAG:
return self._dag
@dag.setter
def dag(self, value: SerializedDAG) -> None:
self._dag = value
```
##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+ """Task group filter - returns all tasks in the specified group."""
+
+ def __init__(self, dag=None, skip_none: bool = True):
+ super().__init__(skip_none=skip_none)
+ self._dag: None | Any = dag
+
+ @property
+ def dag(self) -> None | Any:
+ return self._dag
+
+ @dag.setter
+ def dag(self, value: None | Any) -> None:
+ self._dag = value
+
+ def to_orm(self, select: Select) -> Select:
+ if self.value is None and self.skip_none:
+ return select
+
+ if not self.dag:
+ raise ValueError("DAG must be set before calling to_orm")
Review Comment:
```suggestion
raise ValueError("Dag must be set before calling to_orm")
```
##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,53 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
raise NotImplementedError("Use search_param_factory instead , depends is not implemented.")
+class QueryTaskInstanceTaskGroupFilter(BaseParam[str]):
+ """Task group filter - returns all tasks in the specified group."""
+
+ def __init__(self, dag=None, skip_none: bool = True):
+ super().__init__(skip_none=skip_none)
+ self._dag: None | Any = dag
+
+ @property
+ def dag(self) -> None | Any:
+ return self._dag
+
+ @dag.setter
+ def dag(self, value: None | Any) -> None:
+ self._dag = value
+
+ def to_orm(self, select: Select) -> Select:
+ if self.value is None and self.skip_none:
+ return select
+
+ if not self.dag:
+ raise ValueError("DAG must be set before calling to_orm")
+
+ if not hasattr(self.dag, "task_group"):
+ return select
+
+ matching_task_ids = []
+
+ # Exact matching on group_id
Review Comment:
I'm not sure whether we should raise a Not Found error when the task group is not found, as is done here:
https://github.com/apache/airflow/blob/7c6fcecc535de447482b18c85fb98235655e745d/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py#L950-L976
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]