This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cd7e7bcb23 Don't ignore setups when arrowing from group (#33097)
cd7e7bcb23 is described below

commit cd7e7bcb2310dea19f7ee946716a7c91ed610c68
Author: Daniel Standish <[email protected]>
AuthorDate: Tue Aug 8 12:02:47 2023 -0700

    Don't ignore setups when arrowing from group (#33097)
    
    This enables us to have a group with just setups in it.
---
 airflow/utils/task_group.py    |  14 +++--
 tests/utils/test_task_group.py | 120 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 128 insertions(+), 6 deletions(-)

diff --git a/airflow/utils/task_group.py b/airflow/utils/task_group.py
index 3c3a01bc7d..b6c40a14a9 100644
--- a/airflow/utils/task_group.py
+++ b/airflow/utils/task_group.py
@@ -370,21 +370,25 @@ class TaskGroup(DAGNode):
         tasks = list(self)
         ids = {x.task_id for x in tasks}
 
-        def recurse_for_first_non_setup_teardown(task):
+        def recurse_for_first_non_teardown(task):
             for upstream_task in task.upstream_list:
                 if upstream_task.task_id not in ids:
+                    # upstream task is not in task group
+                    continue
+                elif upstream_task.is_teardown:
+                    yield from recurse_for_first_non_teardown(upstream_task)
+                elif task.is_teardown and upstream_task.is_setup:
+                    # don't go through the teardown-to-setup path
                     continue
-                if upstream_task.is_setup or upstream_task.is_teardown:
-                    yield from recurse_for_first_non_setup_teardown(upstream_task)
                 else:
                     yield upstream_task
 
         for task in tasks:
             if task.downstream_task_ids.isdisjoint(ids):
-                if not (task.is_teardown or task.is_setup):
+                if not task.is_teardown:
                     yield task
                 else:
-                    yield from recurse_for_first_non_setup_teardown(task)
+                    yield from recurse_for_first_non_teardown(task)
 
     def child_id(self, label):
         """Prefix label with group_id if prefix_group_id is True. Otherwise return the label as-is."""
diff --git a/tests/utils/test_task_group.py b/tests/utils/test_task_group.py
index c021d98b88..a9f61debc6 100644
--- a/tests/utils/test_task_group.py
+++ b/tests/utils/test_task_group.py
@@ -22,7 +22,13 @@ from datetime import timedelta
 import pendulum
 import pytest
 
-from airflow.decorators import dag, task as task_decorator, task_group as task_group_decorator
+from airflow.decorators import (
+    dag,
+    setup,
+    task as task_decorator,
+    task_group as task_group_decorator,
+    teardown,
+)
 from airflow.exceptions import TaskAlreadyInTaskGroup
 from airflow.models.baseoperator import BaseOperator
 from airflow.models.dag import DAG
@@ -1479,3 +1485,115 @@ def test_task_group_arrow_with_setups_teardowns():
         tg1 >> w2
     assert t1.downstream_task_ids == set()
     assert w1.downstream_task_ids == {"tg1.t1", "w2"}
+
+
+def test_task_group_arrow_with_setup_group():
+    with DAG(dag_id="setup_group_teardown_group", start_date=pendulum.now()):
+        with TaskGroup("group_1") as g1:
+
+            @setup
+            def setup_1():
+                ...
+
+            @setup
+            def setup_2():
+                ...
+
+            s1 = setup_1()
+            s2 = setup_2()
+
+        with TaskGroup("group_2") as g2:
+
+            @teardown
+            def teardown_1():
+                ...
+
+            @teardown
+            def teardown_2():
+                ...
+
+            t1 = teardown_1()
+            t2 = teardown_2()
+
+        @task_decorator
+        def work():
+            ...
+
+        w1 = work()
+        g1 >> w1 >> g2
+        t1.as_teardown(setups=s1)
+        t2.as_teardown(setups=s2)
+    assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
+    assert set(s2.operator.downstream_task_ids) == {"work", "group_2.teardown_2"}
+    assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
+    assert set(t1.operator.downstream_task_ids) == set()
+    assert set(t2.operator.downstream_task_ids) == set()
+
+    def get_nodes(group):
+        d = task_group_to_dict(group)
+        new_d = {}
+        new_d["id"] = d["id"]
+        new_d["children"] = [{"id": x["id"]} for x in d["children"]]
+        return new_d
+
+    assert get_nodes(g1) == {
+        "id": "group_1",
+        "children": [
+            {"id": "group_1.setup_1"},
+            {"id": "group_1.setup_2"},
+            {"id": "group_1.downstream_join_id"},
+        ],
+    }
+
+
+def test_task_group_arrow_with_setup_group_deeper_setup():
+    """
+    When recursing upstream for a non-teardown leaf, we should ignore setups that
+    are direct upstream of a teardown.
+    """
+    with DAG(dag_id="setup_group_teardown_group_2", start_date=pendulum.now()):
+        with TaskGroup("group_1") as g1:
+
+            @setup
+            def setup_1():
+                ...
+
+            @setup
+            def setup_2():
+                ...
+
+            @teardown
+            def teardown_0():
+                ...
+
+            s1 = setup_1()
+            s2 = setup_2()
+            t0 = teardown_0()
+            s2 >> t0
+
+        with TaskGroup("group_2") as g2:
+
+            @teardown
+            def teardown_1():
+                ...
+
+            @teardown
+            def teardown_2():
+                ...
+
+            t1 = teardown_1()
+            t2 = teardown_2()
+
+        @task_decorator
+        def work():
+            ...
+
+        w1 = work()
+        g1 >> w1 >> g2
+        t1.as_teardown(setups=s1)
+        t2.as_teardown(setups=s2)
+    assert set(s1.operator.downstream_task_ids) == {"work", "group_2.teardown_1"}
+    assert set(s2.operator.downstream_task_ids) == {"group_1.teardown_0", "group_2.teardown_2"}
+    assert set(w1.operator.downstream_task_ids) == {"group_2.teardown_1", "group_2.teardown_2"}
+    assert set(t1.operator.downstream_task_ids) == set()
+    assert set(t2.operator.downstream_task_ids) == set()

Reply via email to