This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4dc98451ba8 fix: handle unmapped task deadlock when upstream tasks are removed (#62034)
4dc98451ba8 is described below

commit 4dc98451ba83a956a8c96072df171eddcf1b8775
Author: Zhen-Lun (Kevin) Hong <[email protected]>
AuthorDate: Sun Jun 14 19:36:14 2026 +0800

    fix: handle unmapped task deadlock when upstream tasks are removed (#62034)
    
    * fix: prevent deadlock when the number of mapped tasks is reduced
    
    * chore: add unit tests
    
    * chore: add test to check rerunning with an upstream task removed
    
    * add unit test of unmapped tasks
---
 .../src/airflow/ti_deps/deps/trigger_rule_dep.py   |  23 +---
 airflow-core/tests/unit/models/test_dagrun.py      | 121 +++++++++++++++++++++
 .../unit/ti_deps/deps/test_trigger_rule_dep.py     |  90 +++++++++++++++
 3 files changed, 217 insertions(+), 17 deletions(-)

diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 9f9f9bd77aa..54d5a4d9309 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -503,9 +503,7 @@ class TriggerRuleDep(BaseTIDep):
                         )
                     )
             elif trigger_rule == TR.ALL_SUCCESS:
-                num_failures = upstream - success
-                if ti.map_index > -1:
-                    num_failures -= removed
+                num_failures = upstream - success - removed
                 if num_failures > 0:
                     yield self._failing_status(
                         reason=(
@@ -516,9 +514,7 @@ class TriggerRuleDep(BaseTIDep):
                         )
                     )
             elif trigger_rule == TR.ALL_FAILED:
-                num_success = upstream - failed - upstream_failed
-                if ti.map_index > -1:
-                    num_success -= removed
+                num_success = upstream - failed - upstream_failed - removed
                 if num_success > 0:
                     yield self._failing_status(
                         reason=(
@@ -539,9 +535,7 @@ class TriggerRuleDep(BaseTIDep):
                         )
                     )
             elif trigger_rule == TR.NONE_FAILED:
-                num_failures = upstream - success - skipped
-                if ti.map_index > -1:
-                    num_failures -= removed
+                num_failures = upstream - success - skipped - removed
                 if num_failures > 0:
                     yield self._failing_status(
                         reason=(
@@ -552,9 +546,7 @@ class TriggerRuleDep(BaseTIDep):
                         )
                     )
             elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
-                num_failures = upstream - success - skipped
-                if ti.map_index > -1:
-                    num_failures -= removed
+                num_failures = upstream - success - skipped - removed
                 if num_failures > 0:
                     yield self._failing_status(
                         reason=(
@@ -614,11 +606,8 @@ class TriggerRuleDep(BaseTIDep):
                     )
             elif trigger_rule == TR.ALL_DONE_MIN_ONE_SUCCESS:
                 # For this trigger rule, skipped tasks are not considered "done"
-                non_skipped_done = success + failed + upstream_failed + removed
-                non_skipped_upstream = upstream - skipped
-                if ti.map_index > -1:
-                    non_skipped_upstream -= removed
-                    non_skipped_done -= removed
+                non_skipped_done = success + failed + upstream_failed
+                non_skipped_upstream = upstream - skipped - removed
 
                 if skipped > 0:
                     yield self._failing_status(
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index a8a831fb6fa..a8d384d9480 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -3229,6 +3229,127 @@ def test_mapped_task_rerun_with_different_length_of_args(session, dag_maker, rer
     assert len(success_tis) == rerun_length
 
 
+def test_mapped_task_length_reduction_rerun_downstream_not_deadlocked(session, dag_maker):
+    @task
+    def producer():
+        context = get_current_context()
+        if context["ti"].try_number == 0:
+            return [i for i in range(3)]
+        return [i for i in range(2)]
+
+    @task
+    def work(arg):
+        return arg
+
+    @task
+    def finish(data):
+        return sum(data)
+
+    def _task_ids(tis):
+        return [(ti.task_id, ti.map_index) for ti in tis]
+
+    with dag_maker(session=session):
+        produced = producer()
+        mapped = work.expand(arg=produced)
+        done = finish(produced)
+        mapped >> done
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    # First run with 3 mapped task instances.
+    dag_maker.run_ti("producer", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1), ("work", 2)]
+
+    for ti in decision.schedulable_tis:
+        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
+    dag_maker.run_ti("finish", dr)
+
+    # Clear and rerun with one fewer mapped task instance.
+    clear_task_instances(dr.get_task_instances(session=session), session=session)
+    ti = dr.get_task_instance(task_id="producer", session=session)
+    ti.try_number += 1
+    session.merge(ti)
+
+    dag_maker.run_ti("producer", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("work", 0), ("work", 1)]
+
+    mapped_states = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == "work", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
+        .order_by(TI.map_index)
+    ).all()
+    assert mapped_states == [
+        (0, State.NONE),
+        (1, State.NONE),
+        (2, TaskInstanceState.REMOVED),
+    ]
+
+    for ti in decision.schedulable_tis:
+        dag_maker.run_ti(ti.task_id, dr, map_index=ti.map_index)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("finish", -1)]
+
+    dag_maker.run_ti("finish", dr)
+    finish_ti = dr.get_task_instance(task_id="finish", map_index=-1, session=session)
+    assert finish_ti
+    assert finish_ti.state == TaskInstanceState.SUCCESS
+
+
+def test_rerun_with_upstream_task_removed(session, dag_maker):
+    def _task_ids(tis):
+        return [(ti.task_id, ti.map_index) for ti in tis]
+
+    with dag_maker("test", session=session):
+        upstream_1 = EmptyOperator(task_id="upstream_1")
+        upstream_2 = EmptyOperator(task_id="upstream_2")
+        downstream = EmptyOperator(task_id="downstream")
+        [upstream_1, upstream_2] >> downstream
+
+    dr: DagRun = dag_maker.create_dagrun()
+
+    dag_maker.run_ti("upstream_1", dr)
+    dag_maker.run_ti("upstream_2", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]
+
+    dag_maker.run_ti("downstream", dr)
+    dr.update_state(session=session)
+    assert dr.state == DagRunState.SUCCESS
+
+    # Rerun with upstream_1 removed
+    with dag_maker("test", session=session, serialized=True) as dag:
+        upstream_2 = EmptyOperator(task_id="upstream_2")
+        downstream = EmptyOperator(task_id="downstream")
+        upstream_2 >> downstream
+
+    latest_version = DagVersion.get_latest_version(dag.dag_id)
+    assert latest_version.version_number == 2
+
+    clear_task_instances(
+        dr.get_task_instances(session=session),
+        session=session,
+        run_on_latest_version=True,
+    )
+
+    upstream_1 = dr.get_task_instance(task_id="upstream_1", map_index=-1, session=session)
+    assert upstream_1.state == TaskInstanceState.REMOVED
+
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("upstream_2", -1)]
+
+    dag_maker.run_ti("upstream_2", dr)
+    decision = dr.task_instance_scheduling_decisions(session=session)
+    assert _task_ids(decision.schedulable_tis) == [("downstream", -1)]
+
+    dag_maker.run_ti("downstream", dr)
+    dr.update_state(session=session)
+    assert dr.state == DagRunState.SUCCESS
+
+
 def test_operator_mapped_task_group_receives_value(dag_maker, session):
     with dag_maker(session=session):
 
diff --git a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
index bd4a576f9f1..c5f805cd553 100644
--- a/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
+++ b/airflow-core/tests/unit/ti_deps/deps/test_trigger_rule_dep.py
@@ -1561,6 +1561,96 @@ class TestTriggerRuleDep:
             expected_ti_state=expected_ti_state,
         )
 
+    @pytest.mark.parametrize("flag_upstream_failed", [True, False])
+    @pytest.mark.parametrize(
+        ("trigger_rule", "upstream_states"),
+        [
+            (
+                TriggerRule.ALL_SUCCESS,
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.ALL_FAILED,
+                _UpstreamTIStates(
+                    success=0,
+                    skipped=0,
+                    failed=3,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.NONE_FAILED,
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+            (
+                TriggerRule.ALL_DONE_MIN_ONE_SUCCESS,
+                _UpstreamTIStates(
+                    success=3,
+                    skipped=0,
+                    failed=0,
+                    upstream_failed=0,
+                    removed=2,
+                    done=5,
+                    skipped_setup=0,
+                    success_setup=0,
+                ),
+            ),
+        ],
+    )
+    def test_non_mapped_task_ignores_removed_upstream_tis(
+        self,
+        monkeypatch,
+        session,
+        get_task_instance,
+        flag_upstream_failed,
+        trigger_rule,
+        upstream_states,
+    ):
+        """
+        Non-mapped trigger-rule checks should exclude removed upstream task instances.
+        """
+        ti = get_task_instance(
+            trigger_rule,
+            normal_tasks=["upstream_1", "upstream_2", "upstream_3", "upstream_4", "upstream_5"],
+        )
+        monkeypatch.setattr(_UpstreamTIStates, "calculate", lambda *_: upstream_states)
+        _test_trigger_rule(ti=ti, session=session, flag_upstream_failed=flag_upstream_failed)
+
 
 def test_upstream_in_mapped_group_triggers_only_relevant(dag_maker, session):
     from airflow.sdk import task, task_group

Reply via email to