pierrejeambrun commented on code in PR #58092:
URL: https://github.com/apache/airflow/pull/58092#discussion_r2603117761


##########
airflow-core/src/airflow/api_fastapi/common/parameters.py:
##########
@@ -185,6 +185,64 @@ def depends(cls, *args: Any, **kwargs: Any) -> Self:
         raise NotImplementedError("Use search_param_factory instead , depends 
is not implemented.")
 
 
+class QueryTITaskGroupDisplayNamePattern(BaseParam[str]):
+    """Task group display name pattern filter - returns all tasks in matching 
groups."""
+
+    def __init__(self, dag=None, skip_none: bool = True):
+        super().__init__(skip_none=skip_none)
+        self.dag = dag
+
+    def to_orm(self, select: Select) -> Select:
+        if self.value is None and self.skip_none:
+            return select
+
+        if self.dag and hasattr(self.dag, "task_group"):
+            task_groups = self.dag.task_group.get_task_group_dict()
+
+            # Pattern matching on both group display name and group_id
+            matching_task_ids = []
+            for group_id, task_group in task_groups.items():
+                if group_id is None:  # Skip root group
+                    continue
+
+                # Check both the display name (label) and the group_id for 
pattern matching
+                display_name = getattr(task_group, "label", None)
+                if (
+                    display_name and self._matches_pattern(display_name, 
self.value)
+                ) or self._matches_pattern(group_id, self.value):
+                    matching_task_ids.extend([task.task_id for task in 
task_group.iter_tasks()])
+
+            if matching_task_ids:
+                return 
select.where(TaskInstance.task_id.in_(matching_task_ids))
+
+        return select.where(TaskInstance.task_id.is_(None))
+
+    def _matches_pattern(self, display_name: str, pattern: str) -> bool:
+        """Check if display_name matches the SQL LIKE pattern or exact 
match."""
+        import re
+
+        if "%" in pattern:
+            pattern_temp = pattern.replace("%", "\x00").replace("_", "\x01")
+            escaped = re.escape(pattern_temp)
+            regex_pattern = escaped.replace("\x00", ".*").replace("\x01", ".")
+            return bool(re.match(f"^{regex_pattern}$", display_name, 
re.IGNORECASE))
+        return display_name.lower() == pattern.lower()

Review Comment:
   To keep things simple, maybe we should just do an exact match on the group.
   
   It's not really a 'group pattern' filter, because we are not returning all groups 
that match a pattern. 
   
   We are returning all TIs that belong to one precise group. So a plain 
`QueryTaskInstanceTaskGroupFilter` is enough, I think, and will be simpler 
to implement.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to