This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 1afc42210b47fff7d3f3a7ba9516193c8bc344c7
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue Oct 29 15:45:41 2024 +0000

    [skip-ci]
---
 task_sdk/src/airflow/sdk/definitions/taskgroup.py | 29 ++++++++---------------
 tests/utils/test_task_group.py                    |  5 ++--
 2 files changed, 12 insertions(+), 22 deletions(-)

diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
index 70f4537aa1a..de4bd0c771a 100644
--- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -47,21 +47,6 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.mixins import DependencyMixin
     from airflow.serialization.enums import DagAttributeTypes
 
-# TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the
-#  correct type. This is a temporary solution until we find a more sophisticated method for argument
-#  validation. One potential method is to use get_type_hints from the typing module. However, this is not
-#  fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python
-#  version that supports `get_type_hints` effectively or find a better approach, we can replace this
-#  manual type-checking method.
-TASKGROUP_ARGS_EXPECTED_TYPES = {
-    "group_id": str,
-    "prefix_group_id": bool,
-    "tooltip": str,
-    "ui_color": str,
-    "ui_fgcolor": str,
-    "add_suffix_on_collision": bool,
-}
-
 
 def _default_parent_group() -> TaskGroup | None:
     from airflow.sdk.definitions.contextmanager import TaskGroupContext
@@ -117,12 +102,14 @@ class TaskGroup(DAGNode):
         automatically add `__1` etc suffixes
     """
 
-    _group_id: str | None
+    _group_id: str | None = attrs.field(
+        validator=attrs.validators.optional(attrs.validators.instance_of(str))
+    )
     prefix_group_id: bool = attrs.field(default=True)
     parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group)
     dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True))
     default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy)
-    tooltip: str = ""
+    tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str))
     children: dict[str, DAGNode] = attrs.field(factory=dict, init=False)
 
     upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False)
@@ -136,8 +123,8 @@ class TaskGroup(DAGNode):
         on_setattr=attrs.setters.frozen,
     )
 
-    ui_color: str = "CornflowerBlue"
-    ui_fgcolor: str = "#000"
+    ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str))
+    ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str))
 
     add_suffix_on_collision: bool = False
 
@@ -151,6 +138,10 @@ class TaskGroup(DAGNode):
         # https://github.com/python-attrs/attrs/issues/342
         self._check_for_group_id_collisions(self.add_suffix_on_collision)
 
+        if self._group_id and not self.parent_group and self.dag:
+            # Support `tg = TaskGroup(x, dag=dag)`
+            self.parent_group = self.dag.task_group
+
         if self.parent_group:
             self.parent_group.add(self)
             if self.parent_group.default_args:
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index 85c225cc9b6..b973a1f615e 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -1659,8 +1659,7 @@ def test_task_group_arrow_with_setup_group_deeper_setup():
 
 
 def test_task_group_with_invalid_arg_type_raises_error():
-    error_msg = "'ui_color' has an invalid type <class 'int'> with value 123, expected type is <class 'str'>"
+    error_msg = r"'ui_color' must be <class 'str'> \(got 123 that is a <class 'int'>\)\."
     with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None):
         with pytest.raises(TypeError, match=error_msg):
-            with TaskGroup("group_1", ui_color=123):
-                EmptyOperator(task_id="task1")
+            _ = TaskGroup("group_1", ui_color=123)

Reply via email to