This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new b4c88f8e44 Fix tasks being wrongly skipped by schedule_after_task_execution (#23181)
b4c88f8e44 is described below

commit b4c88f8e44e61a92408ec2cb0a5490eeaf2f0dba
Author: Tanel Kiis <[email protected]>
AuthorDate: Tue Apr 26 13:53:45 2022 +0300

    Fix tasks being wrongly skipped by schedule_after_task_execution (#23181)
    
    In the reproducing example, once branch finishes, it creates a partial_dag
    which includes `task_a`, `task_b` and `task_d` (but does not include `task_c`
    because it's not downstream of `branch`). Looking at only this partial_dag, the
    "mini scheduler" determines that task_d can be skipped because its only
    upstream task in partial_dag `task_a` is in skipped state. This happens in
    `DagRun._get_ready_tis()` when calling `st.are_dependencies_met()`.
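    
    A minimal sketch of a DAG shape matching that description (the task ids,
    wiring and dates below are illustrative assumptions drawn from the paragraph
    above, not the exact reproducing DAG):
    
        import pendulum
    
        from airflow.models.dag import DAG
        from airflow.operators.empty import EmptyOperator
        from airflow.operators.python import BranchPythonOperator
    
        with DAG(
            dag_id="mini_scheduler_repro",
            start_date=pendulum.datetime(2022, 1, 1),
            schedule_interval=None,
        ) as dag:
            # The branch callable returns no task ids, so everything directly
            # downstream of `branch` gets skipped once it finishes.
            branch = BranchPythonOperator(task_id="branch", python_callable=lambda: [])
            task_a = EmptyOperator(task_id="task_a")
            task_b = EmptyOperator(task_id="task_b")
            task_c = EmptyOperator(task_id="task_c")  # NOT downstream of branch
            task_d = EmptyOperator(task_id="task_d")
    
            branch >> task_a >> task_b
            task_a >> task_d
            task_c >> task_d  # task_d's other upstream, outside the partial dag
    
    Here only `task_a`, `task_b` and `task_d` are downstream of `branch`, which
    is the partial_dag described above.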
---
 airflow/models/dag.py             |  16 +++++-
 tests/jobs/test_local_task_job.py | 114 +++++++++++++++++++++++++++++++++++++-
 2 files changed, 125 insertions(+), 5 deletions(-)

diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a96c24ca29..527928adb9 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -18,6 +18,7 @@
 
 import copy
 import functools
+import itertools
 import logging
 import os
 import pathlib
@@ -1972,7 +1973,10 @@ class DAG(LoggingMixin):
             tasks, in addition to matched tasks.
         :param include_upstream: Include all upstream tasks of matched tasks,
             in addition to matched tasks.
+        :param include_direct_upstream: Include all tasks directly upstream of matched
+            and downstream (if include_downstream = True) tasks
         """
+
         from airflow.models.baseoperator import BaseOperator
         from airflow.models.mappedoperator import MappedOperator
 
@@ -1992,9 +1996,12 @@ class DAG(LoggingMixin):
                 also_include.extend(t.get_flat_relatives(upstream=False))
             if include_upstream:
                 also_include.extend(t.get_flat_relatives(upstream=True))
-            elif include_direct_upstream:
+
+        direct_upstreams: List[Operator] = []
+        if include_direct_upstream:
+            for t in itertools.chain(matched_tasks, also_include):
                 upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
-                also_include.extend(upstream)
+                direct_upstreams.extend(upstream)
 
         # Compiling the unique list of tasks that made the cut
         # Make sure to not recursively deepcopy the dag or task_group while copying the task.
@@ -2003,7 +2010,10 @@ class DAG(LoggingMixin):
             memo.setdefault(id(t.task_group), None)
             return copy.deepcopy(t, memo)
 
-        dag.task_dict = {t.task_id: _deepcopy_task(t) for t in matched_tasks + also_include}
+        dag.task_dict = {
+            t.task_id: _deepcopy_task(t)
+            for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
+        }
 
         def filter_task_group(group, parent_group):
             """Exclude tasks not included in the subdag from the given 
TaskGroup."""
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 34cfc263c8..f888f5ab0c 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -30,20 +30,21 @@ import psutil
 import pytest
 
 from airflow import settings
-from airflow.exceptions import AirflowException, AirflowFailException
+from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import PythonOperator
+from airflow.operators.python import BranchPythonOperator, PythonOperator
 from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
+from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
@@ -799,6 +800,115 @@ class TestLocalTaskJob:
         assert failed_deps[0].dep_name == "Previous Dagrun State"
         assert not failed_deps[0].passed
 
+    @pytest.mark.parametrize(
+        "exception, trigger_rule",
+        [
+            (AirflowFailException(), TriggerRule.ALL_DONE),
+            (AirflowFailException(), TriggerRule.ALL_FAILED),
+            (AirflowSkipException(), TriggerRule.ALL_DONE),
+            (AirflowSkipException(), TriggerRule.ALL_SKIPPED),
+            (AirflowSkipException(), TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS),
+        ],
+    )
+    @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'})
+    def test_mini_scheduler_works_with_skipped_and_failed(
+        self, exception, trigger_rule, caplog, session, dag_maker
+    ):
+        """
+        In these cases D is running, and no decision can be made about C.
+        """
+
+        def raise_():
+            raise exception
+
+        with dag_maker(catchup=False) as dag:
+            task_a = PythonOperator(task_id='A', python_callable=raise_)
+            task_b = PythonOperator(task_id='B', python_callable=lambda: True)
+            task_c = PythonOperator(task_id='C', python_callable=lambda: True, trigger_rule=trigger_rule)
+            task_d = PythonOperator(task_id='D', python_callable=lambda: True)
+            task_a >> task_b >> task_c
+            task_d >> task_c
+
+        dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE)
+        ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED)
+        ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE)
+        ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE)
+        ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING)
+
+        session.merge(ti_a)
+        session.merge(ti_b)
+        session.merge(ti_c)
+        session.merge(ti_d)
+        session.flush()
+
+        job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.run()
+
+        ti_b.refresh_from_db(session)
+        ti_c.refresh_from_db(session)
+        assert ti_b.state in (State.SKIPPED, State.UPSTREAM_FAILED)
+        assert ti_c.state == State.NONE
+        assert "0 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+        failed_deps = list(ti_c.get_failed_dep_statuses(session=session))
+        assert len(failed_deps) == 1
+        assert failed_deps[0].dep_name == "Trigger Rule"
+        assert not failed_deps[0].passed
+
+    @pytest.mark.parametrize(
+        "trigger_rule",
+        [
+            TriggerRule.ONE_SUCCESS,
+            TriggerRule.ALL_SKIPPED,
+            TriggerRule.NONE_FAILED,
+            TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+        ],
+    )
+    @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'})
+    def test_mini_scheduler_works_with_branch_python_operator(self, trigger_rule, caplog, session, dag_maker):
+        """
+        In these cases D is running, and no decision can be made about C.
+        """
+        with dag_maker(catchup=False) as dag:
+            task_a = BranchPythonOperator(task_id='A', python_callable=lambda: [])
+            task_b = PythonOperator(task_id='B', python_callable=lambda: True)
+            task_c = PythonOperator(
+                task_id='C',
+                python_callable=lambda: True,
+                trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+            )
+            task_d = PythonOperator(task_id='D', python_callable=lambda: True)
+            task_a >> task_b >> task_c
+            task_d >> task_c
+
+        dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE)
+        ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED)
+        ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE)
+        ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE)
+        ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING)
+
+        session.merge(ti_a)
+        session.merge(ti_b)
+        session.merge(ti_c)
+        session.merge(ti_d)
+        session.flush()
+
+        job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor())
+        job1.task_runner = StandardTaskRunner(job1)
+        job1.run()
+
+        ti_b.refresh_from_db(session)
+        ti_c.refresh_from_db(session)
+        assert ti_b.state == State.SKIPPED
+        assert ti_c.state == State.NONE
+        assert "0 downstream tasks scheduled from follow-on schedule" in 
caplog.text
+
+        failed_deps = list(ti_c.get_failed_dep_statuses(session=session))
+        assert len(failed_deps) == 1
+        assert failed_deps[0].dep_name == "Trigger Rule"
+        assert not failed_deps[0].passed
+
     @patch('airflow.utils.process_utils.subprocess.check_call')
     def test_task_sigkill_works_with_retries(self, _check_call, caplog, dag_maker):
         """

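As additional context, a hedged sketch of what the patched subset helper is
expected to keep for a DAG shaped like the reproducing example. It assumes the
method changed above is `DAG.partial_subset` and uses only the three
`include_*` parameters visible in this diff; the DAG setup itself is
illustrative:

    import pendulum

    from airflow.models.dag import DAG
    from airflow.operators.empty import EmptyOperator

    with DAG(
        dag_id="partial_subset_demo",
        start_date=pendulum.datetime(2022, 1, 1),
        schedule_interval=None,
    ) as dag:
        branch = EmptyOperator(task_id="branch")
        task_a = EmptyOperator(task_id="task_a")
        task_c = EmptyOperator(task_id="task_c")  # not downstream of branch
        task_d = EmptyOperator(task_id="task_d")
        branch >> task_a >> task_d
        task_c >> task_d

    # Downstream-of-branch subset plus the direct upstreams of every task that
    # made the cut. Before this change, direct upstreams were collected only for
    # the matched task itself (and only when include_upstream was False), so
    # task_c was dropped; with the fix it is kept and task_d is no longer judged
    # against task_a alone.
    partial_dag = dag.partial_subset(
        {"branch"},
        include_downstream=True,
        include_upstream=False,
        include_direct_upstream=True,
    )
    print(sorted(partial_dag.task_dict))  # task_c now appears alongside branch, task_a, task_d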