This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 1afc42210b47fff7d3f3a7ba9516193c8bc344c7 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Tue Oct 29 15:45:41 2024 +0000 [skip-ci] --- task_sdk/src/airflow/sdk/definitions/taskgroup.py | 29 ++++++++--------------- tests/utils/test_task_group.py | 5 ++-- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/task_sdk/src/airflow/sdk/definitions/taskgroup.py b/task_sdk/src/airflow/sdk/definitions/taskgroup.py index 70f4537aa1a..de4bd0c771a 100644 --- a/task_sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task_sdk/src/airflow/sdk/definitions/taskgroup.py @@ -47,21 +47,6 @@ if TYPE_CHECKING: from airflow.sdk.definitions.mixins import DependencyMixin from airflow.serialization.enums import DagAttributeTypes -# TODO: The following mapping is used to validate that the arguments passed to the TaskGroup are of the -# correct type. This is a temporary solution until we find a more sophisticated method for argument -# validation. One potential method is to use get_type_hints from the typing module. However, this is not -# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python -# version that supports `get_type_hints` effectively or find a better approach, we can replace this -# manual type-checking method. -TASKGROUP_ARGS_EXPECTED_TYPES = { - "group_id": str, - "prefix_group_id": bool, - "tooltip": str, - "ui_color": str, - "ui_fgcolor": str, - "add_suffix_on_collision": bool, -} - def _default_parent_group() -> TaskGroup | None: from airflow.sdk.definitions.contextmanager import TaskGroupContext @@ -117,12 +102,14 @@ class TaskGroup(DAGNode): automatically add `__1` etc suffixes """ - _group_id: str | None + _group_id: str | None = attrs.field( + validator=attrs.validators.optional(attrs.validators.instance_of(str)) + ) prefix_group_id: bool = attrs.field(default=True) parent_group: TaskGroup | None = attrs.field(factory=_default_parent_group) dag: DAG = attrs.field(default=attrs.Factory(_default_dag, takes_self=True)) default_args: dict[str, Any] = attrs.field(factory=dict, converter=copy.deepcopy) - tooltip: str = "" + tooltip: str = attrs.field(default="", validator=attrs.validators.instance_of(str)) children: dict[str, DAGNode] = attrs.field(factory=dict, init=False) upstream_group_ids: set[str | None] = attrs.field(factory=set, init=False) @@ -136,8 +123,8 @@ class TaskGroup(DAGNode): on_setattr=attrs.setters.frozen, ) - ui_color: str = "CornflowerBlue" - ui_fgcolor: str = "#000" + ui_color: str = attrs.field(default="CornflowerBlue", validator=attrs.validators.instance_of(str)) + ui_fgcolor: str = attrs.field(default="#000", validator=attrs.validators.instance_of(str)) add_suffix_on_collision: bool = False @@ -151,6 +138,10 @@ class TaskGroup(DAGNode): # https://github.com/python-attrs/attrs/issues/342 self._check_for_group_id_collisions(self.add_suffix_on_collision) + if self._group_id and not self.parent_group and self.dag: + # Support `tg = TaskGroup(x, dag=dag)` + self.parent_group = self.dag.task_group + if self.parent_group: self.parent_group.add(self) if self.parent_group.default_args: diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py index 85c225cc9b6..b973a1f615e 100644 --- a/tests/utils/test_task_group.py +++ b/tests/utils/test_task_group.py @@ -1659,8 +1659,7 @@ def test_task_group_arrow_with_setup_group_deeper_setup(): def test_task_group_with_invalid_arg_type_raises_error(): - error_msg = "'ui_color' has an invalid type <class 'int'> with value 123, expected type is <class 'str'>" + error_msg = r"'ui_color' must be <class 'str'> \(got 123 that is a <class 'int'>\)\." with DAG(dag_id="dag_with_tg_invalid_arg_type", schedule=None): with pytest.raises(TypeError, match=error_msg): - with TaskGroup("group_1", ui_color=123): - EmptyOperator(task_id="task1") + _ = TaskGroup("group_1", ui_color=123)
