This is an automated email from the ASF dual-hosted git repository.
dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 05b39cf2ad Ignore teardowns and setups when arrowing from groups (#32157)
05b39cf2ad is described below
commit 05b39cf2adb2998c01ff27057aeda585b4320d00
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Jun 27 15:07:18 2023 -0700
Ignore teardowns and setups when arrowing from groups (#32157)
---
airflow/utils/task_group.py | 17 +++++++++++++++--
tests/models/test_dag.py | 21 ++++++++++++++-------
tests/utils/test_task_group.py | 16 +++++++++++++++-
3 files changed, 44 insertions(+), 10 deletions(-)
diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 4d79d74e24..f20a5032a8 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -365,9 +365,22 @@ class TaskGroup(DAGNode):
         Returns a generator of tasks that are leaf tasks, i.e. those with no downstream
         dependencies within the TaskGroup.
         """
+
+        def recurse_for_first_non_setup_teardown(group, task):
+            for upstream_task in task.upstream_list:
+                if not group.has_task(upstream_task):
+                    continue
+                if upstream_task.is_setup or upstream_task.is_teardown:
+                    yield from recurse_for_first_non_setup_teardown(group, upstream_task)
+                else:
+                    yield upstream_task
+
         for task in self:
-            if not any(self.has_task(child) for child in task.get_direct_relatives(upstream=False)):
-                yield task
+            if not any(self.has_task(x) for x in task.get_direct_relatives(upstream=False)):
+                if not (task.is_teardown or task.is_setup):
+                    yield task
+                else:
+                    yield from recurse_for_first_non_setup_teardown(self, task)
 
     def child_id(self, label):
         """
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 9eec01c44a..bd460e6cc7 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -3772,13 +3772,20 @@ class TestTaskClearingSetupTeardownBehavior:
         g2_w3 = dag.task_dict["g2.w3"]
         g2_group_teardown = dag.task_dict["g2.group_teardown"]
 
-        with pytest.raises(Exception):
-            # fix_me
-            # the line `dag_setup >> tg >> dag_teardown` should be equivalent to
-            # dag_setup >> group_setup; w3 >> dag_teardown
-            # i.e. not group_teardown >> dag_teardown
-            assert g2_group_teardown.downstream_task_ids == {}
-            assert g2_w3.downstream_task_ids == {"g2.group_teardown", "dag_teardown"}
+        # the line `dag_setup >> tg >> dag_teardown` should be equivalent to
+        # dag_setup >> group_setup; w3 >> dag_teardown
+        # i.e. not group_teardown >> dag_teardown
+        # this way the two teardowns can run in parallel
+        # so first, check that dag_teardown not downstream of group 2 teardown
+        # this means they can run in parallel
+        assert "dag_teardown" not in g2_group_teardown.downstream_task_ids
+        # and just document that g2 teardown is in effect a dag leaf
+        assert g2_group_teardown.downstream_task_ids == set()
+        # group 2 task w3 is in the scope of 2 teardowns -- the dag teardown and the group teardown
+        # it is arrowed to both of them
+        assert g2_w3.downstream_task_ids == {"g2.group_teardown", "dag_teardown"}
+        # dag teardown should have 3 upstreams: the last work task in groups 1 and 2, and its setup
+        assert dag_teardown.upstream_task_ids == {"g1.w3", "g2.w3", "dag_setup"}
         assert {x.task_id for x in g2_w2.get_upstreams_only_setups_and_teardowns()} == {
             "dag_setup",
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index a5262aa2eb..e7088afbd0 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -24,7 +24,8 @@ import pytest
 
 from airflow.decorators import dag, task_group as task_group_decorator
 from airflow.exceptions import TaskAlreadyInTaskGroup
-from airflow.models import DAG
+from airflow.models.baseoperator import BaseOperator
+from airflow.models.dag import DAG
 from airflow.models.xcom_arg import XComArg
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
@@ -1381,3 +1382,16 @@ def test_override_dag_default_args_in_multi_level_nested_tg():
         assert task.retries == 1
         assert task.owner == "z"
         assert task.execution_timeout == timedelta(seconds=10)
+
+
+def test_task_group_arrow_with_setups_teardowns():
+    with DAG(dag_id="hi", start_date=pendulum.datetime(2022, 1, 1)):
+        with TaskGroup(group_id="tg1") as tg1:
+            s1 = BaseOperator(task_id="s1")
+            w1 = BaseOperator(task_id="w1")
+            t1 = BaseOperator(task_id="t1")
+            s1 >> w1 >> t1.as_teardown(setups=s1)
+        w2 = BaseOperator(task_id="w2")
+        tg1 >> w2
+    assert t1.downstream_task_ids == set()
+    assert w1.downstream_task_ids == {"tg1.t1", "w2"}