This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch v2-8-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 96feeaabb53e274351378b8e2f5dfd41985e8dc8
Author: Steven Schaerer <[email protected]>
AuthorDate: Sat Mar 2 20:18:38 2024 +0100

    Improve code coverage for TriggerRuleDep (#37680)
    
    (cherry picked from commit 73a632a5a0af3fe51484e89a7bd4c771800707e5)
---
 airflow/ti_deps/deps/trigger_rule_dep.py    |   12 +-
 tests/ti_deps/deps/test_trigger_rule_dep.py | 1000 +++++++++++++++++----------
 2 files changed, 620 insertions(+), 392 deletions(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index ca2a6100a2..c381166634 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -435,8 +435,8 @@ class TriggerRuleDep(BaseTIDep):
                 if success + failed <= 0:
                     yield self._failing_status(
                         reason=(
-                            f"Task's trigger rule '{trigger_rule}'"
-                            "requires at least one upstream task failure or success"
+                            f"Task's trigger rule '{trigger_rule}' "
+                            "requires at least one upstream task failure or success "
                             f"but none were failed or success. upstream_states={upstream_states}, "
                             f"upstream_task_ids={task.upstream_task_ids}"
                         )
@@ -521,14 +521,6 @@ class TriggerRuleDep(BaseTIDep):
                             f"upstream_task_ids={task.upstream_task_ids}"
                         )
                     )
-                elif upstream_setup is None:  # for now, None only happens in mapped case
-                    yield self._failing_status(
-                        reason=(
-                            f"Task's trigger rule '{trigger_rule}' cannot have mapped tasks as upstream. "
-                            f"upstream_states={upstream_states}, "
-                            f"upstream_task_ids={task.upstream_task_ids}"
-                        )
-                    )
                 elif upstream_setup and not success_setup:
                     yield self._failing_status(
                         reason=(
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py
index 00cbcd449a..7656dcdec9 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -35,12 +35,16 @@ from airflow.utils.trigger_rule import TriggerRule
 
 pytestmark = pytest.mark.db_test
 
-
 if TYPE_CHECKING:
+    from sqlalchemy.orm.session import Session
+
     from airflow.models.dagrun import DagRun
 
 SKIPPED = TaskInstanceState.SKIPPED
 UPSTREAM_FAILED = TaskInstanceState.UPSTREAM_FAILED
+REMOVED = TaskInstanceState.REMOVED
+SUCCESS = TaskInstanceState.SUCCESS
+FAILED = TaskInstanceState.FAILED
 
 
 @pytest.fixture
@@ -92,9 +96,21 @@ def get_task_instance(monkeypatch, session, dag_maker):
 
 @pytest.fixture
 def get_mapped_task_dagrun(session, dag_maker):
-    def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=TaskInstanceState.SUCCESS):
+    def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=SUCCESS, add_setup_tasks: bool = False):
         from airflow.decorators import task
 
+        @task
+        def setup_1(i):
+            return 1
+
+        @task
+        def setup_2(i):
+            return 1
+
+        @task
+        def setup_3(i):
+            return 1
+
         @task
         def do_something(i):
             return 1
@@ -106,46 +122,76 @@ def get_mapped_task_dagrun(session, dag_maker):
         with dag_maker(dag_id="test_dag"):
             nums = do_something.expand(i=[i + 1 for i in range(5)])
             do_something_else.expand(i=nums)
+            if add_setup_tasks:
+                setup_nums = setup_1.expand(i=[i + 1 for i in range(5)])
+                setup_more_nums = setup_2.expand(i=setup_nums)
+                setup_other_nums = setup_3.expand(i=setup_more_nums)
+                setup_more_nums.as_setup() >> nums
+                setup_nums.as_setup() >> nums
+                setup_other_nums.as_setup() >> nums
 
         dr = dag_maker.create_dagrun()
 
-        ti = dr.get_task_instance("do_something_else", session=session)
-        ti.map_index = 0
-        for map_index in range(1, 5):
-            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
-            ti.dag_run = dr
-            session.add(ti)
-        session.flush()
-        tis = dr.get_task_instances()
-        for ti in tis:
-            if ti.task_id == "do_something":
-                if ti.map_index > 2:
-                    ti.state = TaskInstanceState.REMOVED
-                else:
-                    ti.state = state
-                session.merge(ti)
+        def _expand_tasks(task_instance: str, upstream: str) -> BaseOperator | None:
+            ti = dr.get_task_instance(task_instance, session=session)
+            ti.map_index = 0
+            for map_index in range(1, 5):
+                ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
+                ti.dag_run = dr
+                session.add(ti)
+            session.flush()
+            tis = dr.get_task_instances(session=session)
+            for ti in tis:
+                if ti.task_id == upstream:
+                    if ti.map_index > 2:
+                        ti.state = REMOVED
+                    else:
+                        ti.state = state
+                    session.merge(ti)
+            return ti.task
+
+        do_task = _expand_tasks("do_something_else", "do_something")
+        if add_setup_tasks:
+            _expand_tasks("setup_2", "setup_1")
+            setup_task = _expand_tasks("setup_3", "setup_2")
+        else:
+            setup_task = None
         session.commit()
-        return dr, ti.task
+        return dr, do_task, setup_task
 
     return _get_dagrun
 
 
 class TestTriggerRuleDep:
-    def test_no_upstream_tasks(self, get_task_instance):
+    def test_no_upstream_tasks(self, session, get_task_instance):
         """
         If the TI has no upstream TIs then there is nothing to check and the 
dep is passed
         """
         ti = get_task_instance(TriggerRule.ALL_DONE)
-        assert TriggerRuleDep().is_met(ti=ti)
+        dep_statuses = tuple(
+            TriggerRuleDep().get_dep_statuses(ti=ti, dep_context=DepContext(), 
session=session)
+        )
+        assert len(dep_statuses) == 1
+        assert dep_statuses[0].passed
+        assert dep_statuses[0].reason == "The task instance did not have any 
upstream tasks."
 
-    def test_always_tr(self, get_task_instance):
+    def test_always_tr(self, session, get_task_instance):
         """
         The always trigger rule should always pass this dep
         """
-        ti = get_task_instance(TriggerRule.ALWAYS)
-        assert TriggerRuleDep().is_met(ti=ti)
+        ti = get_task_instance(TriggerRule.ALWAYS, normal_tasks=["a"])
 
-    def test_one_success_tr_success(self, session, get_task_instance):
+        dep_statuses = tuple(
+            TriggerRuleDep().get_dep_statuses(ti=ti, dep_context=DepContext(), 
session=session)
+        )
+        assert len(dep_statuses) == 1
+        assert dep_statuses[0].passed
+        assert dep_statuses[0].reason == "The task had a always trigger rule 
set."
+
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_one_success_tr_success(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         One-success trigger rule success
         """
@@ -158,39 +204,67 @@ class TestTriggerRuleDep:
             upstream_failed=2,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 0
 
-    def test_one_success_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_one_success_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         One-success trigger rule failure
         """
+        ti = get_task_instance(
+            TriggerRule.ONE_SUCCESS,
+            success=0,
+            skipped=1,
+            failed=1,
+            removed=1,
+            upstream_failed=1,
+            done=4,
+        )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires one upstream task success, but none were 
found.",
+            expected_ti_state=expected_ti_state,
+        )
+
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_one_success_tr_failure_all_skipped(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
+        """
+        One-success trigger rule failure and all are skipped
+        """
         ti = get_task_instance(
             TriggerRule.ONE_SUCCESS,
             success=0,
             skipped=2,
-            failed=2,
+            failed=0,
             removed=0,
-            upstream_failed=2,
+            upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires one upstream task success, but none were 
found.",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_one_failure_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_one_failure_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         One-failure trigger rule failure
         """
@@ -203,17 +277,16 @@ class TestTriggerRuleDep:
             upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires one upstream task failure, but none were 
found.",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_one_failure_tr_success(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_one_failure_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
         One-failure trigger rule success
         """
@@ -226,16 +299,10 @@ class TestTriggerRuleDep:
             upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_one_failure_tr_success_no_failed(self, session, 
get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_one_failure_tr_success_no_failed(self, session, 
get_task_instance, flag_upstream_failed):
         """
         One-failure trigger rule success
         """
@@ -248,16 +315,10 @@ class TestTriggerRuleDep:
             upstream_failed=2,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_one_done_tr_success(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_one_done_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
         One-done trigger rule success
         """
@@ -270,16 +331,10 @@ class TestTriggerRuleDep:
             upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_one_done_tr_success_with_failed(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_one_done_tr_success_with_failed(self, session, get_task_instance, 
flag_upstream_failed):
         """
         One-done trigger rule success
         """
@@ -292,16 +347,10 @@ class TestTriggerRuleDep:
             upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_one_done_tr_skip(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_one_done_tr_skip(self, session, get_task_instance, 
flag_upstream_failed, expected_ti_state):
         """
         One-done trigger rule skip
         """
@@ -314,17 +363,18 @@ class TestTriggerRuleDep:
             upstream_failed=0,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires at least one upstream task failure or 
success but none",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_one_done_tr_upstream_failed(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_one_done_tr_upstream_failed(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         One-done trigger rule upstream_failed
         """
@@ -337,17 +387,16 @@ class TestTriggerRuleDep:
             upstream_failed=2,
             done=2,
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires at least one upstream task failure or 
success but none",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_all_success_tr_success(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_all_success_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
         All-success trigger rule success
         """
@@ -361,16 +410,14 @@ class TestTriggerRuleDep:
             done=1,
             normal_tasks=["FakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_all_success_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_all_success_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         All-success trigger rule failure
         """
@@ -384,20 +431,15 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have succeeded, 
but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    @pytest.mark.parametrize(
-        "flag_upstream_failed, expected_ti_state",
-        [(True, TaskInstanceState.SKIPPED), (False, None)],
-    )
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
     def test_all_success_tr_skip(self, session, get_task_instance, 
flag_upstream_failed, expected_ti_state):
         """
         All-success trigger rule fails when some upstream tasks are skipped.
@@ -412,18 +454,18 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have succeeded, 
but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
-        assert ti.state == expected_ti_state
 
-    def test_all_success_tr_skip_wait_for_past_depends_before_skipping(self, 
session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_all_success_tr_skip_wait_for_past_depends_before_skipping(
+        self, session, get_task_instance, flag_upstream_failed
+    ):
         """
         All-success trigger rule fails when some upstream tasks are skipped. 
The state of the ti
         should not be set to SKIPPED when flag_upstream_failed is True and
@@ -442,21 +484,21 @@ class TestTriggerRuleDep:
         ti.task.xcom_pull.return_value = None
         xcom_mock = Mock(return_value=None)
         with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", 
xcom_mock):
-            dep_statuses = tuple(
-                TriggerRuleDep()._evaluate_trigger_rule(
-                    ti=ti,
-                    dep_context=DepContext(
-                        flag_upstream_failed=True, 
wait_for_past_depends_before_skipping=True
-                    ),
-                    session=session,
-                )
+            _test_trigger_rule(
+                ti=ti,
+                session=session,
+                flag_upstream_failed=flag_upstream_failed,
+                wait_for_past_depends_before_skipping=True,
+                expected_reason=(
+                    "Task should be skipped but the past depends are not met"
+                    if flag_upstream_failed
+                    else "requires all upstream tasks to have succeeded, but 
found 1"
+                ),
             )
-            assert len(dep_statuses) == 1
-            assert not dep_statuses[0].passed
-            assert ti.state is None
 
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
     def 
test_all_success_tr_skip_wait_for_past_depends_before_skipping_past_depends_met(
-        self, session, get_task_instance
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
     ):
         """
         All-success trigger rule fails when some upstream tasks are skipped. 
The state of the ti
@@ -476,23 +518,19 @@ class TestTriggerRuleDep:
         ti.task.xcom_pull.return_value = None
         xcom_mock = Mock(return_value=True)
         with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", 
xcom_mock):
-            dep_statuses = tuple(
-                TriggerRuleDep()._evaluate_trigger_rule(
-                    ti=ti,
-                    dep_context=DepContext(
-                        flag_upstream_failed=True, 
wait_for_past_depends_before_skipping=True
-                    ),
-                    session=session,
-                )
+            _test_trigger_rule(
+                ti=ti,
+                session=session,
+                flag_upstream_failed=flag_upstream_failed,
+                wait_for_past_depends_before_skipping=True,
+                expected_ti_state=expected_ti_state,
+                expected_reason="requires all upstream tasks to have 
succeeded, but found 1",
             )
-            assert len(dep_statuses) == 1
-            assert not dep_statuses[0].passed
-            assert ti.state == TaskInstanceState.SKIPPED
 
     @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_none_failed_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
-        All success including skip trigger rule success
+        None failed trigger rule success
         """
         ti = get_task_instance(
             TriggerRule.NONE_FAILED,
@@ -504,19 +542,16 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
-        assert ti.state is None
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_none_failed_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_none_failed_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
-        All success including skip trigger rule failure
+        None failed trigger rule failure
         """
         ti = get_task_instance(
             TriggerRule.NONE_FAILED,
@@ -528,19 +563,45 @@ class TestTriggerRuleDep:
             done=3,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="all upstream tasks to have succeeded or been 
skipped, but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_none_failed_min_one_success_tr_success(self, session, 
get_task_instance):
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_none_failed_tr_failure_with_upstream_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
-        All success including skip trigger rule success
+        None failed skip trigger rule failure
+        """
+        ti = get_task_instance(
+            TriggerRule.NONE_FAILED,
+            success=1,
+            skipped=1,
+            failed=0,
+            removed=0,
+            upstream_failed=1,
+            done=3,
+            normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
+        )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="all upstream tasks to have succeeded or been 
skipped, but found 1",
+            expected_ti_state=expected_ti_state,
+        )
+
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_none_failed_min_one_success_tr_success(self, session, 
get_task_instance, flag_upstream_failed):
+        """
+        None failed min one success trigger rule success
         """
         ti = get_task_instance(
             TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
@@ -552,18 +613,14 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_none_failed_min_one_success_tr_skipped(self, session, 
get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_none_failed_min_one_success_tr_skipped(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
-        All success including all upstream skips trigger rule success
+        None failed min one success trigger rule success with all skipped
         """
         ti = get_task_instance(
             TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
@@ -575,19 +632,21 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=True),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 0
-        assert ti.state == TaskInstanceState.SKIPPED
 
-    def test_none_failed_min_one_success_tr_failure(self, session, 
get_task_instance):
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_none_failed_min_one_success_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
-        All success including skip trigger rule failure
+        None failed min one success trigger rule failure due to single failure
         """
         ti = get_task_instance(
             TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
@@ -599,17 +658,43 @@ class TestTriggerRuleDep:
             done=3,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="all upstream tasks to have succeeded or been 
skipped, but found 1",
+            expected_ti_state=expected_ti_state,
+        )
+
+    @pytest.mark.parametrize(
+        "flag_upstream_failed, expected_ti_state", [(True, UPSTREAM_FAILED), 
(False, None)]
+    )
+    def test_none_failed_min_one_success_tr_upstream_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
+        """
+        None failed min one success trigger rule failure due to single 
upstream failure
+        """
+        ti = get_task_instance(
+            TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+            success=1,
+            skipped=1,
+            failed=0,
+            removed=0,
+            upstream_failed=1,
+            done=3,
+            normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
+        )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="all upstream tasks to have succeeded or been 
skipped, but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_all_failed_tr_success(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_all_failed_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
         All-failed trigger rule success
         """
@@ -623,16 +708,10 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_all_failed_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_all_failed_tr_failure(self, session, get_task_instance, 
flag_upstream_failed, expected_ti_state):
         """
         All-failed trigger rule failure
         """
@@ -646,17 +725,16 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have failed, but 
found 2",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_all_done_tr_success(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_all_done_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
         """
         All-done trigger rule success
         """
@@ -670,14 +748,7 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
     @pytest.mark.parametrize(
         "task_cfg, states, exp_reason, exp_state",
@@ -761,8 +832,9 @@ class TestTriggerRuleDep:
             ),
         ],
     )
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_teardown_tr_not_all_done(
-        self, task_cfg, states, exp_reason, exp_state, session, 
get_task_instance
+        self, task_cfg, states, exp_reason, exp_state, session, 
get_task_instance, flag_upstream_failed
     ):
         """
         All-done trigger rule success
@@ -773,22 +845,18 @@ class TestTriggerRuleDep:
             normal_tasks=[f"w{x}" for x in range(task_cfg["work"])],
             setup_tasks=[f"s{x}" for x in range(task_cfg["setup"])],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti, dep_context=DepContext(flag_upstream_failed=True), 
session=session
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason=exp_reason,
+            expected_ti_state=exp_state if exp_state and flag_upstream_failed 
else None,
         )
-        if exp_reason:
-            dep_status = dep_statuses[0]
-            assert len(dep_statuses) == 1
-            assert exp_reason in dep_status.reason
-            assert dep_status.passed is False
-            assert ti.state == exp_state
-        else:
-            assert len(dep_statuses) == 0
-            assert ti.state is None
 
-    def test_all_skipped_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_all_skipped_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         All-skipped trigger rule failure
         """
@@ -802,17 +870,18 @@ class TestTriggerRuleDep:
             done=1,
             normal_tasks=["FakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have been skipped, 
but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_all_skipped_tr_failure_upstream_failed(self, session, 
get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_all_skipped_tr_failure_upstream_failed(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         All-skipped trigger rule failure if an upstream task is in a 
`upstream_failed` state
         """
@@ -826,15 +895,13 @@ class TestTriggerRuleDep:
             done=1,
             normal_tasks=["FakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have been skipped, 
but found 1",
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
     @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_all_skipped_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
@@ -851,16 +918,10 @@ class TestTriggerRuleDep:
             done=3,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    def test_all_done_tr_failure(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_all_done_tr_failure(self, session, get_task_instance, 
flag_upstream_failed):
         """
         All-done trigger rule failure
         """
@@ -876,15 +937,12 @@ class TestTriggerRuleDep:
         )
         EmptyOperator(task_id="OtherFakeTeakID", dag=ti.task.dag) >> ti.task  
# An unfinished upstream.
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to have completed, 
but found 1",
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
     @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_none_skipped_tr_success(self, session, get_task_instance, 
flag_upstream_failed):
@@ -901,17 +959,12 @@ class TestTriggerRuleDep:
             done=3,
             normal_tasks=["FakeTaskID", "OtherFakeTaskID", "FailedFakeTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
-                session=session,
-            )
-        )
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
-    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
-    def test_none_skipped_tr_failure(self, session, get_task_instance, 
flag_upstream_failed):
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, SKIPPED), (False, None)])
+    def test_none_skipped_tr_failure(
+        self, session, get_task_instance, flag_upstream_failed, 
expected_ti_state
+    ):
         """
         None-skipped trigger rule failure
         """
@@ -925,17 +978,16 @@ class TestTriggerRuleDep:
             done=2,
             normal_tasks=["FakeTaskID", "SkippedTaskID"],
         )
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                
dep_context=DepContext(flag_upstream_failed=flag_upstream_failed),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_ti_state=expected_ti_state,
+            expected_reason="requires all upstream tasks to not have been 
skipped, but found 1",
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_none_skipped_tr_failure_empty(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_none_skipped_tr_failure_empty(self, session, get_task_instance, 
flag_upstream_failed):
         """
         None-skipped trigger rule fails until all upstream tasks have 
completed execution
         """
@@ -950,17 +1002,15 @@ class TestTriggerRuleDep:
         )
         EmptyOperator(task_id="FakeTeakID", dag=ti.task.dag) >> ti.task  # An 
unfinished upstream.
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="requires all upstream tasks to not have been 
skipped, but found 0",
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
-    def test_unknown_tr(self, session, get_task_instance):
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    def test_unknown_tr(self, session, get_task_instance, 
flag_upstream_failed):
         """
         Unknown trigger rules should cause this dep to fail
         """
@@ -975,15 +1025,12 @@ class TestTriggerRuleDep:
         )
         ti.task.trigger_rule = "Unknown Trigger Rule"
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_reason="No strategy to evaluate trigger rule 'Unknown 
Trigger Rule'.",
         )
-        assert len(dep_statuses) == 1
-        assert not dep_statuses[0].passed
 
     def test_UpstreamTIStates(self, session, dag_maker):
         """
@@ -1002,11 +1049,11 @@ class TestTriggerRuleDep:
         dr = dag_maker.create_dagrun()
         tis = {ti.task_id: ti for ti in dr.task_instances}
 
-        tis["op1"].state = TaskInstanceState.SUCCESS
-        tis["op2"].state = TaskInstanceState.FAILED
-        tis["op3"].state = TaskInstanceState.SUCCESS
-        tis["op4"].state = TaskInstanceState.SUCCESS
-        tis["op5"].state = TaskInstanceState.SUCCESS
+        tis["op1"].state = SUCCESS
+        tis["op2"].state = FAILED
+        tis["op3"].state = SUCCESS
+        tis["op4"].state = SUCCESS
+        tis["op5"].state = SUCCESS
 
         def _get_finished_tis(task_id: str) -> Iterator[TaskInstance]:
             return (ti for ti in tis.values() if ti.task_id in 
tis[task_id].task.upstream_task_ids)
@@ -1019,16 +1066,19 @@ class TestTriggerRuleDep:
         dr.update_state(session=session)
         assert dr.state == DagRunState.SUCCESS
 
+    @pytest.mark.parametrize("flag_upstream_failed, expected_ti_state", 
[(True, REMOVED), (False, None)])
     def test_mapped_task_upstream_removed_with_all_success_trigger_rules(
         self,
         monkeypatch,
         session,
         get_mapped_task_dagrun,
+        flag_upstream_failed,
+        expected_ti_state,
     ):
         """
         Test ALL_SUCCESS trigger rule with mapped task upstream removed
         """
-        dr, task = get_mapped_task_dagrun()
+        dr, task, _ = get_mapped_task_dagrun()
 
         # ti with removed upstream ti
         ti = dr.get_task_instance(task_id="do_something_else", map_index=3, 
session=session)
@@ -1046,28 +1096,26 @@ class TestTriggerRuleDep:
         )
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: 
upstream_states)
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                # Marks the task as removed if upstream is removed.
-                dep_context=DepContext(flag_upstream_failed=True),
-                session=session,
-            )
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            expected_ti_state=expected_ti_state,
         )
-        assert len(dep_statuses) == 0
-        assert ti.state == TaskInstanceState.REMOVED
 
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_mapped_task_upstream_removed_with_all_failed_trigger_rules(
         self,
         monkeypatch,
         session,
         get_mapped_task_dagrun,
+        flag_upstream_failed,
     ):
         """
         Test ALL_FAILED trigger rule with mapped task upstream removed
         """
 
-        dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, 
state=TaskInstanceState.FAILED)
+        dr, task, _ = 
get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=FAILED)
 
         # ti with removed upstream ti
         ti = dr.get_task_instance(task_id="do_something_else", map_index=3, 
session=session)
@@ -1085,31 +1133,24 @@ class TestTriggerRuleDep:
         )
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: 
upstream_states)
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
     @pytest.mark.parametrize(
-        "trigger_rule",
-        [TriggerRule.NONE_FAILED, TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS],
+        "trigger_rule", [TriggerRule.NONE_FAILED, 
TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS]
     )
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
     def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(
         self,
         monkeypatch,
         session,
         get_mapped_task_dagrun,
         trigger_rule,
+        flag_upstream_failed,
     ):
         """
         Test NONE_FAILED trigger rule with mapped task upstream removed
         """
-        dr, task = get_mapped_task_dagrun(trigger_rule=trigger_rule)
+        dr, task, _ = get_mapped_task_dagrun(trigger_rule=trigger_rule)
 
         # ti with removed upstream ti
         ti = dr.get_task_instance(task_id="do_something_else", map_index=3, 
session=session)
@@ -1127,15 +1168,7 @@ class TestTriggerRuleDep:
         )
         monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: 
upstream_states)
 
-        dep_statuses = tuple(
-            TriggerRuleDep()._evaluate_trigger_rule(
-                ti=ti,
-                dep_context=DepContext(flag_upstream_failed=False),
-                session=session,
-            )
-        )
-
-        assert len(dep_statuses) == 0
+        _test_trigger_rule(ti=ti, session=session, 
flag_upstream_failed=flag_upstream_failed)
 
 
 def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
@@ -1148,12 +1181,18 @@ def 
test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
             return x
 
         @task_group
-        def tg(x):
+        def tg1(x):
             t1 = t.override(task_id="t1")(x=x)
             return t.override(task_id="t2")(x=t1)
 
-        t2 = tg.expand(x=[1, 2, 3])
-        t.override(task_id="t3")(x=t2)
+        t2 = tg1.expand(x=[1, 2, 3])
+
+        @task_group
+        def tg2(x):
+            return t.override(task_id="t3")(x=t2)
+
+        vals2 = tg2.expand(x=[4, 5, 6])
+        t.override(task_id="t4")(x=vals2)
 
     dr: DagRun = dag_maker.create_dagrun()
 
@@ -1163,29 +1202,40 @@ def 
test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
 
     # Initial decision.
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 0), ("tg.t1", 1), ("tg.t1", 2)]
+    assert sorted(tis) == [("tg1.t1", 0), ("tg1.t1", 1), ("tg1.t1", 2)]
 
     # After running the first t1, the first t2 becomes immediately available.
-    tis["tg.t1", 0].run()
+    tis["tg1.t1", 0].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 1), ("tg.t1", 2), ("tg.t2", 0)]
+    assert sorted(tis) == [("tg1.t1", 1), ("tg1.t1", 2), ("tg1.t2", 0)]
 
     # Similarly for the subsequent t2 instances.
-    tis["tg.t1", 2].run()
+    tis["tg1.t1", 2].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t1", 1), ("tg.t2", 0), ("tg.t2", 2)]
+    assert sorted(tis) == [("tg1.t1", 1), ("tg1.t2", 0), ("tg1.t2", 2)]
 
     # But running t2 partially does not make t3 available.
-    tis["tg.t1", 1].run()
-    tis["tg.t2", 0].run()
-    tis["tg.t2", 2].run()
+    tis["tg1.t1", 1].run()
+    tis["tg1.t2", 0].run()
+    tis["tg1.t2", 2].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("tg.t2", 1)]
+    assert sorted(tis) == [("tg1.t2", 1)]
 
     # Only after all t2 instances are run does t3 become available.
-    tis["tg.t2", 1].run()
+    tis["tg1.t2", 1].run()
+    tis = _one_scheduling_decision_iteration()
+    assert sorted(tis) == [("tg2.t3", 0), ("tg2.t3", 1), ("tg2.t3", 2)]
+
+    # But running t3 partially does not make t4 available.
+    tis["tg2.t3", 0].run()
+    tis["tg2.t3", 2].run()
+    tis = _one_scheduling_decision_iteration()
+    assert sorted(tis) == [("tg2.t3", 1)]
+
+    # Only after all t3 instances are run does t4 become available.
+    tis["tg2.t3", 1].run()
     tis = _one_scheduling_decision_iteration()
-    assert sorted(tis) == [("t3", -1)]
+    assert sorted(tis) == [("t4", -1)]
 
 
 def test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, 
session):
@@ -1216,7 +1266,12 @@ def 
test_upstream_in_mapped_group_when_mapped_tasks_list_is_empty(dag_maker, ses
     assert tis == {}
 
 
-def test_mapped_task_check_before_expand(dag_maker, session):
[email protected]("flag_upstream_failed", [True, False])
+def test_mapped_task_check_before_expand(dag_maker, session, 
flag_upstream_failed):
+    """
+    t3 depends on t2, which depends on t1 for expansion. Since t1 has not yet 
run, t2 has not expanded yet,
+    and we need to guarantee this lack of expansion does not fail the 
dependency-checking logic.
+    """
     with dag_maker(session=session):
 
         @task
@@ -1232,17 +1287,46 @@ def test_mapped_task_check_before_expand(dag_maker, 
session):
         tg.expand(a=t([1, 2, 3]))
 
     dr: DagRun = dag_maker.create_dagrun()
-    result_iterator = TriggerRuleDep()._evaluate_trigger_rule(
-        # t3 depends on t2, which depends on t1 for expansion. Since t1 has not
-        # yet run, t2 has not expanded yet, and we need to guarantee this lack
-        # of expansion does not fail the dependency-checking logic.
+
+    _test_trigger_rule(
         ti=next(ti for ti in dr.task_instances if ti.task_id == "tg.t3" and 
ti.map_index == -1),
-        dep_context=DepContext(),
         session=session,
+        flag_upstream_failed=flag_upstream_failed,
+        expected_reason="requires all upstream tasks to have succeeded, but 
found 1",
+    )
+
+
[email protected]("flag_upstream_failed, expected_ti_state", [(True, 
SKIPPED), (False, None)])
+def test_mapped_task_group_finished_upstream_before_expand(
+    dag_maker, session, flag_upstream_failed, expected_ti_state
+):
+    """
+    t3 depends on t2, which was skipped before it was expanded. We need to 
guarantee this lack of expansion
+    does not fail the dependency-checking logic.
+    """
+    with dag_maker(session=session):
+
+        @task
+        def t(x):
+            return x
+
+        @task_group
+        def tg(x):
+            return t.override(task_id="t3")(x=x)
+
+        t.override(task_id="t2").expand(x=t.override(task_id="t1")([1, 2])) >> 
tg.expand(x=[1, 2])
+
+    dr: DagRun = dag_maker.create_dagrun()
+    tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
+    tis["t2"].set_state(SKIPPED, session=session)
+    session.flush()
+    _test_trigger_rule(
+        ti=tis["tg.t3"],
+        session=session,
+        flag_upstream_failed=flag_upstream_failed,
+        expected_reason="requires all upstream tasks to have succeeded, but 
found 1",
+        expected_ti_state=expected_ti_state,
     )
-    results = list(result_iterator)
-    assert len(results) == 1
-    assert results[0].passed is False
 
 
 class TestTriggerRuleDepSetupConstraint:
@@ -1407,3 +1491,155 @@ class TestTriggerRuleDepSetupConstraint:
             (status,) = self.get_dep_statuses(dr, "w2", 
flag_upstream_failed=True, session=session)
         assert status.reason.startswith("All setup tasks must complete 
successfully")
         assert self.get_ti(dr, "w2").state == expected
+
+
[email protected](
+    "map_index, flag_upstream_failed, expected_ti_state",
+    [(2, True, None), (3, True, REMOVED), (4, True, REMOVED), (3, False, 
None)],
+)
+def test_setup_constraint_mapped_task_upstream_removed_and_success(
+    dag_maker,
+    session,
+    get_mapped_task_dagrun,
+    map_index,
+    flag_upstream_failed,
+    expected_ti_state,
+):
+    """
+    Dynamically mapped setup task with successful and removed upstream tasks. 
Expect the rule to
+    pass. With flag_upstream_failed set, the state becomes REMOVED for map indexes >= n successes.
+    """
+    dr, _, setup_task = get_mapped_task_dagrun(add_setup_tasks=True)
+
+    ti = dr.get_task_instance(task_id="setup_3", map_index=map_index, 
session=session)
+    ti.task = setup_task
+
+    _test_trigger_rule(
+        ti=ti,
+        session=session,
+        flag_upstream_failed=flag_upstream_failed,
+        expected_ti_state=expected_ti_state,
+    )
+
+
[email protected](
+    "flag_upstream_failed, wait_for_past_depends_before_skipping, 
past_depends_met, expected_ti_state, expect_failure",
+    [
+        (False, True, True, None, False),
+        (False, True, False, None, False),
+        (False, False, False, None, False),
+        (False, False, True, None, False),
+        (True, False, False, SKIPPED, False),
+        (True, False, True, SKIPPED, False),
+        (True, True, False, None, True),
+        (True, True, True, SKIPPED, False),
+    ],
+)
+def test_setup_constraint_wait_for_past_depends_before_skipping(
+    dag_maker,
+    session,
+    get_task_instance,
+    monkeypatch,
+    flag_upstream_failed,
+    wait_for_past_depends_before_skipping,
+    past_depends_met,
+    expected_ti_state,
+    expect_failure,
+):
+    """
+    Setup task with a skipped upstream task.
+    * If flag_upstream_failed is False then expect neither a failure nor 
a modified state.
+    * If flag_upstream_failed is True and 
wait_for_past_depends_before_skipping is False then expect the
+      state to be set to SKIPPED but no failure.
+    * If both flag_upstream_failed and wait_for_past_depends_before_skipping 
are True then if the past
+      depends are met the state is expected to be SKIPPED and no failure, 
otherwise the state is not
+      expected to change but the trigger rule should fail.
+    """
+    ti = get_task_instance(
+        trigger_rule=TriggerRule.ALL_DONE,
+        success=1,
+        skipped=1,
+        failed=0,
+        removed=0,
+        upstream_failed=0,
+        done=2,
+        setup_tasks=["FakeTaskID", "OtherFakeTaskID"],
+    )
+
+    ti.task.xcom_pull.return_value = None
+    xcom_mock = Mock(return_value=True if past_depends_met else None)
+    with mock.patch("airflow.models.taskinstance.TaskInstance.xcom_pull", 
xcom_mock):
+        _test_trigger_rule(
+            ti=ti,
+            session=session,
+            flag_upstream_failed=flag_upstream_failed,
+            
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            expected_ti_state=expected_ti_state,
+            expected_reason=(
+                "Task should be skipped but the past depends are not met" if 
expect_failure else ""
+            ),
+        )
+
+
[email protected]("flag_upstream_failed, expected_ti_state", [(True, 
SKIPPED), (False, None)])
+def test_setup_mapped_task_group_finished_upstream_before_expand(
+    dag_maker, session, flag_upstream_failed, expected_ti_state
+):
+    """
+    t3 indirectly depends on t1, which was skipped before it was expanded. We 
need to guarantee this lack of
+    expansion does not fail the dependency-checking logic.
+    """
+    with dag_maker(session=session):
+
+        @task(trigger_rule=TriggerRule.ALL_DONE)
+        def t(x):
+            return x
+
+        @task_group
+        def tg(x):
+            return t.override(task_id="t3")(x=x)
+
+        vals = t.override(task_id="t1")([1, 2]).as_setup()
+        t.override(task_id="t2").expand(x=vals).as_setup() >> tg.expand(x=[1, 
2]).as_setup()
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    tis = {ti.task_id: ti for ti in dr.get_task_instances(session=session)}
+    tis["t1"].set_state(SKIPPED, session=session)
+    tis["t2"].set_state(SUCCESS, session=session)
+    session.flush()
+    _test_trigger_rule(
+        ti=tis["tg.t3"],
+        session=session,
+        flag_upstream_failed=flag_upstream_failed,
+        expected_reason="All setup tasks must complete successfully.",
+        expected_ti_state=expected_ti_state,
+    )
+
+
+def _test_trigger_rule(
+    ti: TaskInstance,
+    session: Session,
+    flag_upstream_failed: bool,
+    wait_for_past_depends_before_skipping: bool = False,
+    expected_reason: str = "",
+    expected_ti_state: TaskInstanceState | None = None,
+) -> None:
+    assert ti.state is None
+    dep_statuses = tuple(
+        TriggerRuleDep()._evaluate_trigger_rule(
+            ti=ti,
+            dep_context=DepContext(
+                flag_upstream_failed=flag_upstream_failed,
+                
wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+            ),
+            session=session,
+        )
+    )
+    if expected_reason:
+        assert len(dep_statuses) == 1
+        assert not dep_statuses[0].passed
+        assert expected_reason in dep_statuses[0].reason
+    else:
+        assert not dep_statuses
+    assert ti.state == expected_ti_state

Reply via email to