This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9e8627faa7 Fix overriding `default_args` in nested task groups (#31608)
9e8627faa7 is described below

commit 9e8627faa71e9d2047816b291061c28585809508
Author: Hussein Awala <[email protected]>
AuthorDate: Tue May 30 17:31:34 2023 +0300

    Fix overriding `default_args` in nested task groups (#31608)
    
    * add unit tests for default_args overriding in task group
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * fix overriding default args in nested task groups
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * Update airflow/utils/task_group.py
    
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
    
    ---------
    
    Signed-off-by: Hussein Awala <[email protected]>
    Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
 airflow/utils/task_group.py         |  5 +++
 tests/decorators/test_task_group.py | 74 ++++++++++++++++++++++++++++++++++++
 tests/utils/test_task_group.py      | 75 +++++++++++++++++++++++++++++++++++++
 3 files changed, 154 insertions(+)

diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 85ec1eb0d3..4b8392be7b 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -140,6 +140,7 @@ class TaskGroup(DAGNode):
 
         if parent_group:
             parent_group.add(self)
+            self._update_default_args(parent_group)
 
         self.used_group_ids.add(self.group_id)
         if self.group_id:
@@ -176,6 +177,10 @@ class TaskGroup(DAGNode):
             else:
                 self._group_id = f"{base}__{suffixes[-1] + 1}"
 
+    def _update_default_args(self, parent_group: TaskGroup):  # inherit the parent's default_args
+        if parent_group.default_args:  # keys set on this group itself must override the parent's
+            self.default_args = {**parent_group.default_args, **self.default_args}
+
     @classmethod
     def create_root(cls, dag: DAG) -> TaskGroup:
         """Create a root TaskGroup with no group_id or parent."""
diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py
index 38b54cf1c5..3462c3a1d8 100644
--- a/tests/decorators/test_task_group.py
+++ b/tests/decorators/test_task_group.py
@@ -17,11 +17,14 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
+
 import pendulum
 import pytest
 
 from airflow.decorators import dag, task_group
 from airflow.models.expandinput import DictOfListsExpandInput, 
ListOfDictsExpandInput, MappedArgument
+from airflow.operators.empty import EmptyOperator
 from airflow.utils.task_group import MappedTaskGroup
 
 
@@ -186,3 +189,74 @@ def test_expand_kwargs_create_mapped():
     assert tg._expand_input == ListOfDictsExpandInput([{"b": "x"}, {"b": 
None}])
 
     assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, 
key="b")}
+
+
+def test_override_dag_default_args():  # @task_group default_args layer over the DAG's
+    @dag(
+        dag_id="test_dag",
+        start_date=pendulum.parse("20200101"),
+        default_args={
+            "retries": 1,
+            "owner": "x",
+        },
+    )
+    def pipeline():
+        @task_group(
+            group_id="task_group",
+            default_args={
+                "owner": "y",
+                "execution_timeout": timedelta(seconds=10),
+            },
+        )
+        def tg():
+            EmptyOperator(task_id="task")
+
+        tg()
+
+    test_dag = pipeline()
+    test_task = test_dag.task_group_dict["task_group"].children["task_group.task"]
+    assert test_task.retries == 1  # inherited from the DAG
+    assert test_task.owner == "y"  # task group value overrides the DAG's "x"
+    assert test_task.execution_timeout == timedelta(seconds=10)  # set only on the task group
+
+
+def test_override_dag_default_args_nested_tg():  # defaults propagate through nested @task_group
+    @dag(
+        dag_id="test_dag",
+        start_date=pendulum.parse("20200101"),
+        default_args={
+            "retries": 1,
+            "owner": "x",
+        },
+    )
+    def pipeline():
+        @task_group(
+            group_id="task_group",
+            default_args={
+                "owner": "y",
+                "execution_timeout": timedelta(seconds=10),
+            },
+        )
+        def tg():
+            @task_group(group_id="nested_task_group")  # no default_args of its own
+            def nested_tg():
+                @task_group(group_id="another_task_group")  # none here either
+                def another_tg():
+                    EmptyOperator(task_id="task")
+
+                another_tg()
+
+            nested_tg()
+
+        tg()
+
+    test_dag = pipeline()
+    test_task = (  # NOTE(review): no case here where a nested group overrides its parent
+        test_dag.task_group_dict["task_group"]
+        .children["task_group.nested_task_group"]
+        .children["task_group.nested_task_group.another_task_group"]
+        .children["task_group.nested_task_group.another_task_group.task"]
+    )
+    assert test_task.retries == 1  # from the DAG
+    assert test_task.owner == "y"  # outermost group overrides the DAG
+    assert test_task.execution_timeout == timedelta(seconds=10)  # outermost group only
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index ba6f174773..c9927eb3ba 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -17,6 +17,8 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
+
 import pendulum
 import pytest
 
@@ -1301,3 +1303,76 @@ def test_iter_tasks():
         "section_2.task3",
         "section_2.bash_task",
     ]
+
+
+def test_override_dag_default_args():  # TaskGroup default_args layer over the DAG's
+    with DAG(
+        dag_id="test_dag",
+        start_date=pendulum.parse("20200101"),
+        default_args={
+            "retries": 1,
+            "owner": "x",
+        },
+    ):
+        with TaskGroup(
+            group_id="task_group",
+            default_args={
+                "owner": "y",
+                "execution_timeout": timedelta(seconds=10),
+            },
+        ):
+            task = EmptyOperator(task_id="task")
+
+    assert task.retries == 1  # inherited from the DAG
+    assert task.owner == "y"  # task group value overrides the DAG's "x"
+    assert task.execution_timeout == timedelta(seconds=10)  # set only on the task group
+
+
+def test_override_dag_default_args_in_nested_tg():  # inherits outer group's defaults
+    with DAG(
+        dag_id="test_dag",
+        start_date=pendulum.parse("20200101"),
+        default_args={
+            "retries": 1,
+            "owner": "x",
+        },
+    ):
+        with TaskGroup(
+            group_id="task_group",
+            default_args={
+                "owner": "y",
+                "execution_timeout": timedelta(seconds=10),
+            },
+        ):
+            with TaskGroup(group_id="nested_task_group"):  # no default_args of its own
+                task = EmptyOperator(task_id="task")
+
+    assert task.retries == 1  # from the DAG
+    assert task.owner == "y"  # NOTE(review): also cover a nested group overriding this
+    assert task.execution_timeout == timedelta(seconds=10)  # from the outer group
+
+
+def test_override_dag_default_args_in_multi_level_nested_tg():  # inherited through three nested groups
+    with DAG(
+        dag_id="test_dag",
+        start_date=pendulum.parse("20200101"),
+        default_args={
+            "retries": 1,
+            "owner": "x",
+        },
+    ):
+        with TaskGroup(
+            group_id="task_group",
+            default_args={
+                "owner": "y",
+                "execution_timeout": timedelta(seconds=10),
+            },
+        ):
+            with TaskGroup(group_id="first_nested_task_group"):
+                with TaskGroup(group_id="second_nested_task_group"):
+                    with TaskGroup(group_id="third_nested_task_group"):
+                        task = EmptyOperator(task_id="task")
+
+    assert task.retries == 1  # from the DAG
+    assert task.owner == "y"  # outer group overrides the DAG
+    assert task.execution_timeout == timedelta(seconds=10)  # from the outer group

Reply via email to