This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9e8627faa7 Fix overriding `default_args` in nested task groups (#31608)
9e8627faa7 is described below
commit 9e8627faa71e9d2047816b291061c28585809508
Author: Hussein Awala <[email protected]>
AuthorDate: Tue May 30 17:31:34 2023 +0300
Fix overriding `default_args` in nested task groups (#31608)
* add unit tests for default_args overriding in task group
Signed-off-by: Hussein Awala <[email protected]>
* fix overriding default args in nested task groups
Signed-off-by: Hussein Awala <[email protected]>
* Update airflow/utils/task_group.py
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---------
Signed-off-by: Hussein Awala <[email protected]>
Co-authored-by: Ash Berlin-Taylor <[email protected]>
---
airflow/utils/task_group.py | 5 +++
tests/decorators/test_task_group.py | 74 ++++++++++++++++++++++++++++++++++++
tests/utils/test_task_group.py | 75 +++++++++++++++++++++++++++++++++++++
3 files changed, 154 insertions(+)
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 85ec1eb0d3..4b8392be7b 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -140,6 +140,7 @@ class TaskGroup(DAGNode):
if parent_group:
parent_group.add(self)
+ self._update_default_args(parent_group)
self.used_group_ids.add(self.group_id)
if self.group_id:
@@ -176,6 +177,10 @@ class TaskGroup(DAGNode):
else:
self._group_id = f"{base}__{suffixes[-1] + 1}"
+    def _update_default_args(self, parent_group: TaskGroup):
+        if parent_group.default_args:  # inherit the parent's defaults; this group's own keys win
+            self.default_args = {**parent_group.default_args, **self.default_args}
+
@classmethod
def create_root(cls, dag: DAG) -> TaskGroup:
"""Create a root TaskGroup with no group_id or parent."""
diff --git a/tests/decorators/test_task_group.py b/tests/decorators/test_task_group.py
index 38b54cf1c5..3462c3a1d8 100644
--- a/tests/decorators/test_task_group.py
+++ b/tests/decorators/test_task_group.py
@@ -17,11 +17,14 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
+
import pendulum
import pytest
from airflow.decorators import dag, task_group
from airflow.models.expandinput import DictOfListsExpandInput, ListOfDictsExpandInput, MappedArgument
+from airflow.operators.empty import EmptyOperator
from airflow.utils.task_group import MappedTaskGroup
@@ -186,3 +189,74 @@ def test_expand_kwargs_create_mapped():
assert tg._expand_input == ListOfDictsExpandInput([{"b": "x"}, {"b": None}])
assert saved == {"a": 1, "b": MappedArgument(input=tg._expand_input, key="b")}
+
+
+def test_override_dag_default_args():
+ @dag(
+ dag_id="test_dag",
+ start_date=pendulum.parse("20200101"),
+ default_args={
+ "retries": 1,
+ "owner": "x",
+ },
+ )
+ def pipeline():
+ @task_group(
+ group_id="task_group",
+ default_args={
+ "owner": "y",
+ "execution_timeout": timedelta(seconds=10),
+ },
+ )
+ def tg():
+ EmptyOperator(task_id="task")
+
+ tg()
+
+ test_dag = pipeline()
+ test_task = test_dag.task_group_dict["task_group"].children["task_group.task"]
+ assert test_task.retries == 1
+ assert test_task.owner == "y"
+ assert test_task.execution_timeout == timedelta(seconds=10)
+
+
+def test_override_dag_default_args_nested_tg():
+ @dag(
+ dag_id="test_dag",
+ start_date=pendulum.parse("20200101"),
+ default_args={
+ "retries": 1,
+ "owner": "x",
+ },
+ )
+ def pipeline():
+ @task_group(
+ group_id="task_group",
+ default_args={
+ "owner": "y",
+ "execution_timeout": timedelta(seconds=10),
+ },
+ )
+ def tg():
+ @task_group(group_id="nested_task_group")
+ def nested_tg():
+ @task_group(group_id="another_task_group")
+ def another_tg():
+ EmptyOperator(task_id="task")
+
+ another_tg()
+
+ nested_tg()
+
+ tg()
+
+ test_dag = pipeline()
+ test_task = (
+ test_dag.task_group_dict["task_group"]
+ .children["task_group.nested_task_group"]
+ .children["task_group.nested_task_group.another_task_group"]
+ .children["task_group.nested_task_group.another_task_group.task"]
+ )
+ assert test_task.retries == 1
+ assert test_task.owner == "y"
+ assert test_task.execution_timeout == timedelta(seconds=10)
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index ba6f174773..c9927eb3ba 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+from datetime import timedelta
+
import pendulum
import pytest
@@ -1301,3 +1303,76 @@ def test_iter_tasks():
"section_2.task3",
"section_2.bash_task",
]
+
+
+def test_override_dag_default_args():
+ with DAG(
+ dag_id="test_dag",
+ start_date=pendulum.parse("20200101"),
+ default_args={
+ "retries": 1,
+ "owner": "x",
+ },
+ ):
+ with TaskGroup(
+ group_id="task_group",
+ default_args={
+ "owner": "y",
+ "execution_timeout": timedelta(seconds=10),
+ },
+ ):
+ task = EmptyOperator(task_id="task")
+
+ assert task.retries == 1
+ assert task.owner == "y"
+ assert task.execution_timeout == timedelta(seconds=10)
+
+
+def test_override_dag_default_args_in_nested_tg():
+ with DAG(
+ dag_id="test_dag",
+ start_date=pendulum.parse("20200101"),
+ default_args={
+ "retries": 1,
+ "owner": "x",
+ },
+ ):
+ with TaskGroup(
+ group_id="task_group",
+ default_args={
+ "owner": "y",
+ "execution_timeout": timedelta(seconds=10),
+ },
+ ):
+ with TaskGroup(group_id="nested_task_group"):
+ task = EmptyOperator(task_id="task")
+
+ assert task.retries == 1
+ assert task.owner == "y"
+ assert task.execution_timeout == timedelta(seconds=10)
+
+
+def test_override_dag_default_args_in_multi_level_nested_tg():
+ with DAG(
+ dag_id="test_dag",
+ start_date=pendulum.parse("20200101"),
+ default_args={
+ "retries": 1,
+ "owner": "x",
+ },
+ ):
+ with TaskGroup(
+ group_id="task_group",
+ default_args={
+ "owner": "y",
+ "execution_timeout": timedelta(seconds=10),
+ },
+ ):
+ with TaskGroup(group_id="first_nested_task_group"):
+ with TaskGroup(group_id="second_nested_task_group"):
+ with TaskGroup(group_id="third_nested_task_group"):
+ task = EmptyOperator(task_id="task")
+
+ assert task.retries == 1
+ assert task.owner == "y"
+ assert task.execution_timeout == timedelta(seconds=10)