This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1d43826b5d Fix dag run state determination logic re ignoring teardowns 
(#31658)
1d43826b5d is described below

commit 1d43826b5ddd40eafca738d6b640cddfdad94cfe
Author: Daniel Standish <[email protected]>
AuthorDate: Thu Jun 1 22:46:01 2023 -0700

    Fix dag run state determination logic re ignoring teardowns (#31658)
    
    If a teardown task is not defined with `on_failure_fail_dagrun=True`, the dag run's leaf tasks should be calculated as though that teardown task were not there.
    
    Co-authored-by: Ephraim Anierobi <[email protected]>
---
 airflow/models/dagrun.py    | 43 ++++++++++++++++----------
 tests/models/test_dagrun.py | 73 +++++++++++++++++++++++++++++++++++++++++++--
 2 files changed, 98 insertions(+), 18 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 7579062c01..80d65350f0 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -533,6 +533,28 @@ class DagRun(Base, LoggingMixin):
             .first()
         )
 
+    def _tis_for_dagrun_state(self, *, dag, tis):
+        """
+        Return the collection of tasks that should be considered for 
evaluation of terminal dag run state.
+
+        Teardown tasks by default are not considered for the purpose of dag 
run state.  But
+        users may enable such consideration with on_failure_fail_dagrun.
+        """
+
+        def is_effective_leaf(task):
+            for down_task_id in task.downstream_task_ids:
+                down_task = dag.get_task(down_task_id)
+                if not down_task.is_teardown or 
down_task.on_failure_fail_dagrun:
+                    # we found a down task that is not ignorable; not a leaf
+                    return False
+            # we found no ignorable downstreams
+            # evaluate whether task is itself ignorable
+            return not task.is_teardown or task.on_failure_fail_dagrun
+
+        leaf_task_ids = {x.task_id for x in dag.tasks if is_effective_leaf(x)}
+        leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if 
ti.state != TaskInstanceState.REMOVED}
+        return leaf_tis
+
     @provide_session
     def update_state(
         self, session: Session = NEW_SESSION, execute_callbacks: bool = True
@@ -595,21 +617,10 @@ class DagRun(Base, LoggingMixin):
                     if changed_by_upstream:  # Something changed, we need to 
recalculate!
                         unfinished = unfinished.recalculate()
 
-        leaf_task_ids = {t.task_id for t in dag.leaves}
-        leaf_tis = {ti for ti in tis if ti.task_id in leaf_task_ids if 
ti.state != TaskInstanceState.REMOVED}
-        if dag.teardowns:
-            # when on_failure_fail_dagrun is `False`, the final state of the 
DagRun
-            # will be computed as if the teardown task simply didn't exist.
-            teardown_task_ids = {t.task_id for t in dag.teardowns}
-            upstream_of_teardowns = {t.task_id for t in 
dag.tasks_upstream_of_teardowns}
-            teardown_tis = {ti for ti in tis if ti.task_id in 
teardown_task_ids}
-            on_failure_fail_tis = (ti for ti in teardown_tis if 
ti.task.on_failure_fail_dagrun)
-            tis_upstream_of_teardowns = (ti for ti in tis if ti.task_id in 
upstream_of_teardowns)
-            leaf_tis -= teardown_tis
-            leaf_tis.update(on_failure_fail_tis, tis_upstream_of_teardowns)
-
-        # if all roots finished and at least one failed, the run failed
-        if not unfinished.tis and any(leaf_ti.state in State.failed_states for 
leaf_ti in leaf_tis):
+        tis_for_dagrun_state = self._tis_for_dagrun_state(dag=dag, tis=tis)
+
+        # if all tasks finished and at least one failed, the run failed
+        if not unfinished.tis and any(x.state in State.failed_states for x in 
tis_for_dagrun_state):
             self.log.error("Marking run %s failed", self)
             self.set_state(DagRunState.FAILED)
             self.notify_dagrun_state_changed(msg="task_failure")
@@ -630,7 +641,7 @@ class DagRun(Base, LoggingMixin):
                 )
 
         # if all leaves succeeded and no unfinished tasks, the run succeeded
-        elif not unfinished.tis and all(leaf_ti.state in State.success_states 
for leaf_ti in leaf_tis):
+        elif not unfinished.tis and all(x.state in State.success_states for x 
in tis_for_dagrun_state):
             self.log.info("Marking run %s successful", self)
             self.set_state(DagRunState.SUCCESS)
             self.notify_dagrun_state_changed(msg="success")
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index c44b0a48df..1a8bb7bd36 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import datetime
+from functools import reduce
 from typing import Mapping
 from unittest import mock
 from unittest.mock import call
@@ -2364,7 +2365,7 @@ def test_teardown_failure_behaviour_on_dagrun(dag_maker, 
session, dag_run_state,
 @pytest.mark.parametrize(
     "dag_run_state, on_failure_fail_dagrun", [[DagRunState.SUCCESS, False], 
[DagRunState.FAILED, True]]
 )
-def test_teardown_failure_on_non_leave_behaviour_on_dagrun(
+def test_teardown_failure_on_non_leaf_behaviour_on_dagrun(
     dag_maker, session, dag_run_state, on_failure_fail_dagrun
 ):
     with dag_maker():
@@ -2435,7 +2436,7 @@ def 
test_work_task_failure_when_setup_teardown_are_successful(dag_maker, session
     assert dr.state == DagRunState.FAILED
 
 
-def test_failure_of_leave_task_not_connected_to_teardown_task(dag_maker, 
session):
+def test_failure_of_leaf_task_not_connected_to_teardown_task(dag_maker, 
session):
     with dag_maker():
 
         @setup
@@ -2469,3 +2470,71 @@ def 
test_failure_of_leave_task_not_connected_to_teardown_task(dag_maker, session
     session.flush()
     dr = session.query(DagRun).one()
     assert dr.state == DagRunState.FAILED
+
+
[email protected](
+    "input, expected",
+    [
+        (["s1 >> w1 >> t1"], {"w1"}),  # t1 ignored
+        (["s1 >> w1 >> t1", "s1 >> t1"], {"w1"}),  # t1 ignored; properly 
wired to setup
+        (["s1 >> w1"], {"w1"}),  # no teardown
+        (["s1 >> w1 >> t1_"], {"t1_"}),  # t1_ is natural leaf and OFFD=True;
+        (["s1 >> w1 >> t1_", "s1 >> t1_"], {"t1_"}),  # t1_ is natural leaf 
and OFFD=True; wired to setup
+        (["s1 >> w1 >> t1_ >> w2", "s1 >> t1_"], {"w2"}),  # t1_ is not a 
natural leaf so excluded anyway
+    ],
+)
+def test_tis_considered_for_state(dag_maker, session, input, expected):
+    """
+    We use a convenience notation to wire up test scenarios:
+
+    t<num> -- teardown task
+    t<num>_ -- teardown task with on_failure_fail_dagrun = True
+    s<num> -- setup task
+    w<num> -- work task (a.k.a. normal task)
+
+    In the test input, each line is a statement. We'll automatically create 
the tasks and wire them up
+    as indicated in the test input.
+    """
+
+    @teardown()
+    def teardown_task():
+        print(1)
+
+    # todo: should not have to do this; should be able to use override
+    @teardown(on_failure_fail_dagrun=True)
+    def teardown_task_offd():
+        print(1)
+
+    @task
+    def work_task():
+        print(1)
+
+    @task
+    def setup_task():
+        print(1)
+
+    def make_task(task_id, dag):
+        """
+        Task factory helper.
+
+        Will give a setup, teardown, work, or teardown-with-dagrun-failure 
task depending on input.
+        """
+        if task_id.startswith("s"):
+            factory = setup_task
+        elif task_id.startswith("w"):
+            factory = work_task
+        elif task_id.endswith("_"):
+            factory = teardown_task_offd
+        else:
+            factory = teardown_task
+        return dag.task_dict.get(task_id) or 
factory.override(task_id=task_id)()
+
+    with dag_maker() as dag:
+        for line in input:
+            tasks = [make_task(x, dag) for x in line.split(" >> ")]
+            reduce(lambda x, y: x >> y, tasks)
+
+    dr = dag_maker.create_dagrun()
+    tis = dr.task_instance_scheduling_decisions(session).tis
+    tis_for_state = {x.task_id for x in dr._tis_for_dagrun_state(dag=dag, 
tis=tis)}
+    assert tis_for_state == expected

Reply via email to