This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch v2-4-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 6c90cea1ab63b925d58cf4949d94ec0e3b37830e
Author: Ephraim Anierobi <[email protected]>
AuthorDate: Wed Sep 21 13:52:45 2022 +0100

    Fix deadlock when mapped task with removed upstream is rerun (#26518)
    
    When a DAG with a mapped downstream task that depends on a mapped upstream
    task with some of its mapped indexes removed is rerun, we run into a
    deadlock because the trigger rule evaluation does not account for removed
    task instances.
    
    The fix for the deadlock is to account for removed task instances in the
    trigger rule evaluation wherever possible.
    
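    This fix also adds a case for when flag_upstream_failed is set. The
    downstream counterpart of a removed upstream task instance is then itself
    marked as removed. For example, if the upstream task instance with map
    index 3 is removed, the downstream task instance with map index 3 is also
    removed when flag_upstream_failed is True. Concretely, a mapped downstream
    task instance whose map index is greater than or equal to the number of
    successful upstream task instances is set to REMOVED.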
    
    (cherry picked from commit e91637f8894cac19c6b467b6669cbcc13184be70)
---
 airflow/ti_deps/deps/trigger_rule_dep.py    |  18 ++-
 tests/models/test_dagrun.py                 |  40 ++++++
 tests/models/test_taskinstance.py           | 178 ++++++++++++++++++++-----
 tests/ti_deps/deps/test_trigger_rule_dep.py | 195 +++++++++++++++++++++++++++-
 4 files changed, 395 insertions(+), 36 deletions(-)

diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py 
b/airflow/ti_deps/deps/trigger_rule_dep.py
index 72fa783a8e..691dc3e3a5 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -59,6 +59,7 @@ class TriggerRuleDep(BaseTIDep):
             counter.get(State.SKIPPED, 0),
             counter.get(State.FAILED, 0),
             counter.get(State.UPSTREAM_FAILED, 0),
+            counter.get(State.REMOVED, 0),
             sum(counter.values()),
         )
 
@@ -73,7 +74,7 @@ class TriggerRuleDep(BaseTIDep):
             yield self._passing_status(reason="The task had a always trigger 
rule set.")
             return
         # see if the task name is in the task upstream for our task
-        successes, skipped, failed, upstream_failed, done = 
self._get_states_count_upstream_ti(
+        successes, skipped, failed, upstream_failed, removed, done = 
self._get_states_count_upstream_ti(
             task=ti.task, 
finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
         )
 
@@ -83,6 +84,7 @@ class TriggerRuleDep(BaseTIDep):
             skipped=skipped,
             failed=failed,
             upstream_failed=upstream_failed,
+            removed=removed,
             done=done,
             flag_upstream_failed=dep_context.flag_upstream_failed,
             dep_context=dep_context,
@@ -122,6 +124,7 @@ class TriggerRuleDep(BaseTIDep):
         skipped,
         failed,
         upstream_failed,
+        removed,
         done,
         flag_upstream_failed,
         dep_context: DepContext,
@@ -152,6 +155,7 @@ class TriggerRuleDep(BaseTIDep):
             "successes": successes,
             "skipped": skipped,
             "failed": failed,
+            "removed": removed,
             "upstream_failed": upstream_failed,
             "done": done,
         }
@@ -162,6 +166,9 @@ class TriggerRuleDep(BaseTIDep):
                     changed = ti.set_state(State.UPSTREAM_FAILED, session)
                 elif skipped:
                     changed = ti.set_state(State.SKIPPED, session)
+                elif removed and successes and ti.map_index > -1:
+                    if ti.map_index >= successes:
+                        changed = ti.set_state(State.REMOVED, session)
             elif trigger_rule == TR.ALL_FAILED:
                 if successes or skipped:
                     changed = ti.set_state(State.SKIPPED, session)
@@ -189,6 +196,7 @@ class TriggerRuleDep(BaseTIDep):
             elif trigger_rule == TR.ALL_SKIPPED:
                 if successes or failed:
                     changed = ti.set_state(State.SKIPPED, session)
+
         if changed:
             dep_context.have_changed_ti_states = True
 
@@ -212,6 +220,8 @@ class TriggerRuleDep(BaseTIDep):
                 )
         elif trigger_rule == TR.ALL_SUCCESS:
             num_failures = upstream - successes
+            if ti.map_index > -1:
+                num_failures -= removed
             if num_failures > 0:
                 yield self._failing_status(
                     reason=(
@@ -223,6 +233,8 @@ class TriggerRuleDep(BaseTIDep):
                 )
         elif trigger_rule == TR.ALL_FAILED:
             num_successes = upstream - failed - upstream_failed
+            if ti.map_index > -1:
+                num_successes -= removed
             if num_successes > 0:
                 yield self._failing_status(
                     reason=(
@@ -244,6 +256,8 @@ class TriggerRuleDep(BaseTIDep):
                 )
         elif trigger_rule == TR.NONE_FAILED:
             num_failures = upstream - successes - skipped
+            if ti.map_index > -1:
+                num_failures -= removed
             if num_failures > 0:
                 yield self._failing_status(
                     reason=(
@@ -255,6 +269,8 @@ class TriggerRuleDep(BaseTIDep):
                 )
         elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
             num_failures = upstream - successes - skipped
+            if ti.map_index > -1:
+                num_failures -= removed
             if num_failures > 0:
                 yield self._failing_status(
                     reason=(
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 16b892f76e..50e9e9a3d8 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1905,3 +1905,43 @@ def test_mapped_skip_upstream_not_deadlock(dag_maker):
     dr.update_state(session=session)
     assert dr.state == DagRunState.SUCCESS
     assert tis['add_one__1'].state == TaskInstanceState.SKIPPED
+
+
+def 
test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session, 
dag_maker):
+    from airflow.decorators import task
+
+    @task
+    def do_something(i):
+        return 1
+
+    @task
+    def do_something_else(i):
+        return 1
+
+    with dag_maker():
+        nums = do_something.expand(i=[i + 1 for i in range(5)])
+        do_something_else.expand(i=nums)
+
+    dr = dag_maker.create_dagrun()
+
+    ti = dr.get_task_instance('do_something_else', session=session)
+    ti.map_index = 0
+    task = ti.task
+    for map_index in range(1, 5):
+        ti = TI(task, run_id=dr.run_id, map_index=map_index)
+        ti.dag_run = dr
+        session.add(ti)
+    session.flush()
+    tis = dr.get_task_instances()
+    for ti in tis:
+        if ti.task_id == 'do_something':
+            if ti.map_index > 2:
+                ti.state = TaskInstanceState.REMOVED
+            else:
+                ti.state = TaskInstanceState.SUCCESS
+            session.merge(ti)
+    session.commit()
+    # The Upstream is done with 2 removed tis and 3 success tis
+    (tis, _) = dr.update_state()
+    assert len(tis)
+    assert dr.state != DagRunState.FAILED
diff --git a/tests/models/test_taskinstance.py 
b/tests/models/test_taskinstance.py
index a9b5dd69d2..4eff50092b 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -1065,55 +1065,55 @@ class TestTaskInstance:
     # Parameterized tests to check for the correct firing
     # of the trigger_rule under various circumstances
     # Numeric fields are in order:
-    #   successes, skipped, failed, upstream_failed, done
+    #   successes, skipped, failed, upstream_failed, done, removed
     @pytest.mark.parametrize(
-        "trigger_rule,successes,skipped,failed,upstream_failed,done,"
+        "trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
         "flag_upstream_failed,expect_state,expect_completed",
         [
             #
             # Tests for all_success
             #
-            ['all_success', 5, 0, 0, 0, 0, True, None, True],
-            ['all_success', 2, 0, 0, 0, 0, True, None, False],
-            ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
-            ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
+            ['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
+            ['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
+            ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
             #
             # Tests for one_success
             #
-            ['one_success', 5, 0, 0, 0, 5, True, None, True],
-            ['one_success', 2, 0, 0, 0, 2, True, None, True],
-            ['one_success', 2, 0, 1, 0, 3, True, None, True],
-            ['one_success', 2, 1, 0, 0, 3, True, None, True],
-            ['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False],
-            ['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False],
-            ['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False],
-            ['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False],
-            ['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False],
-            ['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False],
-            ['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False],
+            ['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
+            ['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
+            ['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
+            ['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
+            ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
+            ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, 
False],
             #
             # Tests for all_failed
             #
-            ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
-            ['all_failed', 0, 0, 5, 0, 5, True, None, True],
-            ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
-            ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
-            ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
+            ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
+            ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
+            ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
+            ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
+            ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
             #
             # Tests for one_failed
             #
-            ['one_failed', 5, 0, 0, 0, 0, True, None, False],
-            ['one_failed', 2, 0, 0, 0, 0, True, None, False],
-            ['one_failed', 2, 0, 1, 0, 0, True, None, True],
-            ['one_failed', 2, 1, 0, 0, 3, True, None, False],
-            ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
+            ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
+            ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
+            ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
+            ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
+            ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
             #
             # Tests for done
             #
-            ['all_done', 5, 0, 0, 0, 5, True, None, True],
-            ['all_done', 2, 0, 0, 0, 2, True, None, False],
-            ['all_done', 2, 0, 1, 0, 3, True, None, False],
-            ['all_done', 2, 1, 0, 0, 3, True, None, False],
+            ['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
+            ['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
+            ['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
+            ['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
         ],
     )
     def test_check_task_dependencies(
@@ -1122,6 +1122,7 @@ class TestTaskInstance:
         successes: int,
         skipped: int,
         failed: int,
+        removed: int,
         upstream_failed: int,
         done: int,
         flag_upstream_failed: bool,
@@ -1144,6 +1145,121 @@ class TestTaskInstance:
             successes=successes,
             skipped=skipped,
             failed=failed,
+            removed=removed,
+            upstream_failed=upstream_failed,
+            done=done,
+            dep_context=DepContext(),
+            flag_upstream_failed=flag_upstream_failed,
+        )
+        completed = all(dep.passed for dep in dep_results)
+
+        assert completed == expect_completed
+        assert ti.state == expect_state
+
+    # Parameterized tests to check for the correct firing
+    # of the trigger_rule under various circumstances of mapped task
+    # Numeric fields are in order:
+    #   successes, skipped, failed, upstream_failed, done,removed
+    @pytest.mark.parametrize(
+        "trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
+        "flag_upstream_failed,expect_state,expect_completed",
+        [
+            #
+            # Tests for all_success
+            #
+            ['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
+            ['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
+            ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
+            ['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True],  # 
ti.map_index >=successes
+            #
+            # Tests for one_success
+            #
+            ['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
+            ['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
+            ['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
+            ['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
+            ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
+            ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, 
False],
+            #
+            # Tests for all_failed
+            #
+            ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
+            ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
+            ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
+            ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
+            ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
+            ['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False],  # 
One removed
+            #
+            # Tests for one_failed
+            #
+            ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
+            ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
+            ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
+            ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
+            ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
+            ['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False],  # 
One removed
+            #
+            # Tests for done
+            #
+            ['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
+            ['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
+            ['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
+            ['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
+        ],
+    )
+    def test_check_task_dependencies_for_mapped(
+        self,
+        trigger_rule: str,
+        successes: int,
+        skipped: int,
+        failed: int,
+        removed: int,
+        upstream_failed: int,
+        done: int,
+        flag_upstream_failed: bool,
+        expect_state: State,
+        expect_completed: bool,
+        dag_maker,
+        session,
+    ):
+        from airflow.decorators import task
+
+        @task
+        def do_something(i):
+            return 1
+
+        @task(trigger_rule=trigger_rule)
+        def do_something_else(i):
+            return 1
+
+        with dag_maker(dag_id='test_dag'):
+            nums = do_something.expand(i=[i + 1 for i in range(5)])
+            do_something_else.expand(i=nums)
+
+        dr = dag_maker.create_dagrun()
+
+        ti = dr.get_task_instance('do_something_else', session=session)
+        ti.map_index = 0
+        for map_index in range(1, 5):
+            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
+            ti.dag_run = dr
+            session.add(ti)
+        session.flush()
+        downstream = ti.task
+        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, 
session=session)
+        ti.task = downstream
+        dep_results = TriggerRuleDep()._evaluate_trigger_rule(
+            ti=ti,
+            successes=successes,
+            skipped=skipped,
+            failed=failed,
+            removed=removed,
             upstream_failed=upstream_failed,
             done=done,
             dep_context=DepContext(),
diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py 
b/tests/ti_deps/deps/test_trigger_rule_dep.py
index fc6a4d546c..4deeebe254 100644
--- a/tests/ti_deps/deps/test_trigger_rule_dep.py
+++ b/tests/ti_deps/deps/test_trigger_rule_dep.py
@@ -25,12 +25,13 @@ import pytest
 from airflow import settings
 from airflow.models import DAG
 from airflow.models.baseoperator import BaseOperator
+from airflow.models.taskinstance import TaskInstance
 from airflow.operators.empty import EmptyOperator
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils import timezone
 from airflow.utils.session import create_session
-from airflow.utils.state import State
+from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.trigger_rule import TriggerRule
 from tests.models import DEFAULT_DATE
 from tests.test_utils.db import clear_db_runs
@@ -53,6 +54,46 @@ def get_task_instance(session, dag_maker):
     return _get_task_instance
 
 
[email protected]
+def get_mapped_task_dagrun(session, dag_maker):
+    def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=State.SUCCESS):
+        from airflow.decorators import task
+
+        @task
+        def do_something(i):
+            return 1
+
+        @task(trigger_rule=trigger_rule)
+        def do_something_else(i):
+            return 1
+
+        with dag_maker(dag_id='test_dag'):
+            nums = do_something.expand(i=[i + 1 for i in range(5)])
+            do_something_else.expand(i=nums)
+
+        dr = dag_maker.create_dagrun()
+
+        ti = dr.get_task_instance('do_something_else', session=session)
+        ti.map_index = 0
+        for map_index in range(1, 5):
+            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
+            ti.dag_run = dr
+            session.add(ti)
+        session.flush()
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'do_something':
+                if ti.map_index > 2:
+                    ti.state = TaskInstanceState.REMOVED
+                else:
+                    ti.state = state
+                session.merge(ti)
+        session.commit()
+        return dr, ti.task
+
+    return _get_dagrun
+
+
 class TestTriggerRuleDep:
     def test_no_upstream_tasks(self, get_task_instance):
         """
@@ -79,6 +120,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=2,
                 failed=2,
+                removed=0,
                 upstream_failed=2,
                 done=2,
                 flag_upstream_failed=False,
@@ -99,6 +141,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=2,
                 failed=2,
+                removed=0,
                 upstream_failed=2,
                 done=2,
                 flag_upstream_failed=False,
@@ -120,6 +163,7 @@ class TestTriggerRuleDep:
                 successes=2,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -141,6 +185,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=2,
                 failed=2,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -156,6 +201,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=2,
                 failed=0,
+                removed=0,
                 upstream_failed=2,
                 done=2,
                 flag_upstream_failed=False,
@@ -176,6 +222,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=1,
                 flag_upstream_failed=False,
@@ -196,6 +243,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=0,
                 failed=1,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -217,6 +265,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -239,6 +288,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=True,
@@ -261,6 +311,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -281,6 +332,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=2,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=True,
@@ -304,6 +356,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=1,
+                removed=0,
                 upstream_failed=0,
                 done=3,
                 flag_upstream_failed=False,
@@ -327,6 +380,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -349,6 +403,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=2,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=True,
@@ -373,6 +428,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=1,
                 failed=1,
+                removed=0,
                 upstream_failed=0,
                 done=3,
                 flag_upstream_failed=False,
@@ -394,6 +450,7 @@ class TestTriggerRuleDep:
                 successes=0,
                 skipped=0,
                 failed=2,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -414,6 +471,7 @@ class TestTriggerRuleDep:
                 successes=2,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -435,6 +493,7 @@ class TestTriggerRuleDep:
                 successes=2,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=2,
                 flag_upstream_failed=False,
@@ -455,6 +514,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=1,
                 flag_upstream_failed=False,
@@ -479,6 +539,7 @@ class TestTriggerRuleDep:
                     successes=0,
                     skipped=3,
                     failed=0,
+                    removed=0,
                     upstream_failed=0,
                     done=3,
                     flag_upstream_failed=False,
@@ -495,6 +556,7 @@ class TestTriggerRuleDep:
                     successes=0,
                     skipped=3,
                     failed=0,
+                    removed=0,
                     upstream_failed=0,
                     done=3,
                     flag_upstream_failed=True,
@@ -515,6 +577,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=1,
                 flag_upstream_failed=False,
@@ -539,6 +602,7 @@ class TestTriggerRuleDep:
                     successes=2,
                     skipped=0,
                     failed=1,
+                    removed=0,
                     upstream_failed=0,
                     done=3,
                     flag_upstream_failed=False,
@@ -555,6 +619,7 @@ class TestTriggerRuleDep:
                     successes=0,
                     skipped=0,
                     failed=3,
+                    removed=0,
                     upstream_failed=0,
                     done=3,
                     flag_upstream_failed=True,
@@ -577,6 +642,7 @@ class TestTriggerRuleDep:
                     successes=1,
                     skipped=1,
                     failed=0,
+                    removed=0,
                     upstream_failed=0,
                     done=2,
                     flag_upstream_failed=False,
@@ -594,6 +660,7 @@ class TestTriggerRuleDep:
                     successes=1,
                     skipped=1,
                     failed=0,
+                    removed=0,
                     upstream_failed=0,
                     done=2,
                     flag_upstream_failed=True,
@@ -611,6 +678,7 @@ class TestTriggerRuleDep:
                     successes=0,
                     skipped=0,
                     failed=0,
+                    removed=0,
                     upstream_failed=0,
                     done=0,
                     flag_upstream_failed=False,
@@ -633,6 +701,7 @@ class TestTriggerRuleDep:
                 successes=1,
                 skipped=0,
                 failed=0,
+                removed=0,
                 upstream_failed=0,
                 done=1,
                 flag_upstream_failed=False,
@@ -693,10 +762,128 @@ class TestTriggerRuleDep:
 
         # check handling with cases that tasks are triggered from backfill 
with no finished tasks
         finished_tis = DepContext().ensure_finished_tis(ti_op2.dag_run, 
session)
-        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op2) == (1, 0, 0, 0, 1)
+        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op2) == (1, 0, 0, 0, 0, 1)
         finished_tis = dr.get_task_instances(state=State.finished, 
session=session)
-        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op4) == (1, 0, 1, 0, 2)
-        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op5) == (2, 0, 1, 0, 3)
+        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op4) == (1, 0, 1, 0, 0, 2)
+        assert get_states_count_upstream_ti(finished_tis=finished_tis, 
task=op5) == (2, 0, 1, 0, 0, 3)
 
         dr.update_state()
         assert State.SUCCESS == dr.state
+
+    def test_mapped_task_upstream_removed_with_all_success_trigger_rules(
+        self, session, get_mapped_task_dagrun
+    ):
+        """
+        Test ALL_SUCCESS trigger rule with mapped task upstream removed
+        """
+        dr, task = get_mapped_task_dagrun()
+
+        # ti with removed upstream ti
+        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, 
session=session)
+        ti.task = task
+
+        dep_statuses = tuple(
+            TriggerRuleDep()._evaluate_trigger_rule(
+                ti=ti,
+                successes=3,
+                skipped=0,
+                failed=0,
+                removed=2,
+                upstream_failed=0,
+                done=5,
+                flag_upstream_failed=True,  # marks the task as removed if 
upstream is removed
+                dep_context=DepContext(),
+                session=session,
+            )
+        )
+
+        assert len(dep_statuses) == 0
+        assert ti.state == TaskInstanceState.REMOVED
+
+    def test_mapped_task_upstream_removed_with_all_failed_trigger_rules(
+        self, session, get_mapped_task_dagrun
+    ):
+        """
+        Test ALL_FAILED trigger rule with mapped task upstream removed
+        """
+
+        dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, 
state=State.FAILED)
+
+        # ti with removed upstream ti
+        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, 
session=session)
+        ti.task = task
+
+        dep_statuses = tuple(
+            TriggerRuleDep()._evaluate_trigger_rule(
+                ti=ti,
+                successes=0,
+                skipped=0,
+                failed=3,
+                removed=2,
+                upstream_failed=0,
+                done=5,
+                flag_upstream_failed=False,
+                dep_context=DepContext(),
+                session=session,
+            )
+        )
+
+        assert len(dep_statuses) == 0
+
+    def test_mapped_task_upstream_removed_with_none_failed_trigger_rules(
+        self, session, get_mapped_task_dagrun
+    ):
+        """
+        Test NONE_FAILED trigger rule with mapped task upstream removed
+        """
+        dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED)
+
+        # ti with removed upstream ti
+        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, 
session=session)
+        ti.task = task
+
+        dep_statuses = tuple(
+            TriggerRuleDep()._evaluate_trigger_rule(
+                ti=ti,
+                successes=3,
+                skipped=0,
+                failed=0,
+                removed=2,
+                upstream_failed=0,
+                done=5,
+                flag_upstream_failed=False,
+                dep_context=DepContext(),
+                session=session,
+            )
+        )
+
+        assert len(dep_statuses) == 0
+
+    def 
test_mapped_task_upstream_removed_with_none_failed_min_one_success_trigger_rules(
+        self, session, get_mapped_task_dagrun
+    ):
+        """
+        Test NONE_FAILED_MIN_ONE_SUCCESS trigger rule with mapped task 
upstream removed
+        """
+        dr, task = 
get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
+
+        # ti with removed upstream ti
+        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, 
session=session)
+        ti.task = task
+
+        dep_statuses = tuple(
+            TriggerRuleDep()._evaluate_trigger_rule(
+                ti=ti,
+                successes=3,
+                skipped=0,
+                failed=0,
+                removed=2,
+                upstream_failed=0,
+                done=5,
+                flag_upstream_failed=False,
+                dep_context=DepContext(),
+                session=session,
+            )
+        )
+
+        assert len(dep_statuses) == 0

Reply via email to