This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b4c88f8e44 Fix tasks being wrongly skipped by schedule_after_task_execution (#23181)
b4c88f8e44 is described below
commit b4c88f8e44e61a92408ec2cb0a5490eeaf2f0dba
Author: Tanel Kiis <[email protected]>
AuthorDate: Tue Apr 26 13:53:45 2022 +0300
Fix tasks being wrongly skipped by schedule_after_task_execution (#23181)
In the reproducing example, once branch finishes, it creates a partial_dag
which includes `task_a`, `task_b` and `task_d` (but does not include `task_c`
because it's not downstream of `branch`). Looking at only this partial_dag, the
"mini scheduler" determines that task_d can be skipped because its only
upstream task in partial_dag, `task_a`, is in skipped state. This happens in
`DagRun._get_ready_tis()` when calling `st.are_dependencies_met()`.
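For reference, a minimal DAG with the shape described above; the operators,
trigger rule and dates here are illustrative assumptions, not taken from the
original report:

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.empty import EmptyOperator
    from airflow.operators.python import BranchPythonOperator
    from airflow.utils.trigger_rule import TriggerRule

    with DAG(dag_id="mini_scheduler_repro", start_date=datetime(2022, 1, 1), schedule_interval=None):
        # `branch` picks task_b, so task_a ends up skipped.
        branch = BranchPythonOperator(task_id="branch", python_callable=lambda: "task_b")
        task_a = EmptyOperator(task_id="task_a")
        task_b = EmptyOperator(task_id="task_b")
        task_c = EmptyOperator(task_id="task_c")  # not downstream of `branch`
        # Trigger rule chosen for illustration: task_d should wait for task_c
        # rather than be skipped just because task_a was skipped.
        task_d = EmptyOperator(task_id="task_d", trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)

        branch >> [task_a, task_b]
        task_a >> task_d
        task_c >> task_d

Before this fix, once `branch` finished and `task_a` was skipped, the mini
scheduler could mark `task_d` skipped even though `task_c` had not run yet.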
---
airflow/models/dag.py | 16 +++++-
tests/jobs/test_local_task_job.py | 114 +++++++++++++++++++++++++++++++++++++-
2 files changed, 125 insertions(+), 5 deletions(-)
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index a96c24ca29..527928adb9 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -18,6 +18,7 @@
import copy
import functools
+import itertools
import logging
import os
import pathlib
@@ -1972,7 +1973,10 @@ class DAG(LoggingMixin):
tasks, in addition to matched tasks.
:param include_upstream: Include all upstream tasks of matched tasks,
in addition to matched tasks.
+ :param include_direct_upstream: Include all tasks directly upstream of matched
+ and downstream (if include_downstream = True) tasks
"""
+
from airflow.models.baseoperator import BaseOperator
from airflow.models.mappedoperator import MappedOperator
@@ -1992,9 +1996,12 @@ class DAG(LoggingMixin):
also_include.extend(t.get_flat_relatives(upstream=False))
if include_upstream:
also_include.extend(t.get_flat_relatives(upstream=True))
- elif include_direct_upstream:
+
+ direct_upstreams: List[Operator] = []
+ if include_direct_upstream:
+ for t in itertools.chain(matched_tasks, also_include):
upstream = (u for u in t.upstream_list if isinstance(u, (BaseOperator, MappedOperator)))
- also_include.extend(upstream)
+ direct_upstreams.extend(upstream)
# Compiling the unique list of tasks that made the cut
# Make sure to not recursively deepcopy the dag or task_group while copying the task.
@@ -2003,7 +2010,10 @@ class DAG(LoggingMixin):
memo.setdefault(id(t.task_group), None)
return copy.deepcopy(t, memo)
- dag.task_dict = {t.task_id: _deepcopy_task(t) for t in matched_tasks + also_include}
+ dag.task_dict = {
+ t.task_id: _deepcopy_task(t)
+ for t in itertools.chain(matched_tasks, also_include, direct_upstreams)
+ }
def filter_task_group(group, parent_group):
"""Exclude tasks not included in the subdag from the given
TaskGroup."""
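The dag.py hunk above changes how `DAG.partial_subset` collects direct
upstreams: they are now gathered for the matched tasks and for everything
pulled in via include_downstream, instead of only for the matched tasks (and
only when include_upstream was False). A rough sketch against the DAG above;
the keyword arguments are based on the current `partial_subset` signature and
are not part of this commit:

    # Build the subset the mini scheduler would look at after `branch` finishes.
    partial_dag = dag.partial_subset(
        task_ids_or_regex=["branch"],
        include_downstream=True,
        include_upstream=False,
        include_direct_upstream=True,
    )
    # With this fix, task_c (a direct upstream of the downstream task task_d) is
    # part of the subset, so the dependency check can see that task_d still has
    # an unfinished upstream; previously task_c was missing and task_d could be
    # skipped based on task_a alone.
    assert "task_c" in partial_dag.task_dict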
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 34cfc263c8..f888f5ab0c 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -30,20 +30,21 @@ import psutil
import pytest
from airflow import settings
-from airflow.exceptions import AirflowException, AirflowFailException
+from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import PythonOperator
+from airflow.operators.python import BranchPythonOperator, PythonOperator
from airflow.task.task_runner.standard_task_runner import StandardTaskRunner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
+from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.types import DagRunType
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
@@ -799,6 +800,115 @@ class TestLocalTaskJob:
assert failed_deps[0].dep_name == "Previous Dagrun State"
assert not failed_deps[0].passed
+ @pytest.mark.parametrize(
+ "exception, trigger_rule",
+ [
+ (AirflowFailException(), TriggerRule.ALL_DONE),
+ (AirflowFailException(), TriggerRule.ALL_FAILED),
+ (AirflowSkipException(), TriggerRule.ALL_DONE),
+ (AirflowSkipException(), TriggerRule.ALL_SKIPPED),
+ (AirflowSkipException(), TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS),
+ ],
+ )
+ @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'})
+ def test_mini_scheduler_works_with_skipped_and_failed(
+ self, exception, trigger_rule, caplog, session, dag_maker
+ ):
+ """
+ In these cases D is running, so no decision can be made about C.
+ """
+
+ def raise_():
+ raise exception
+
+ with dag_maker(catchup=False) as dag:
+ task_a = PythonOperator(task_id='A', python_callable=raise_)
+ task_b = PythonOperator(task_id='B', python_callable=lambda: True)
+ task_c = PythonOperator(task_id='C', python_callable=lambda: True, trigger_rule=trigger_rule)
+ task_d = PythonOperator(task_id='D', python_callable=lambda: True)
+ task_a >> task_b >> task_c
+ task_d >> task_c
+
+ dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE)
+ ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED)
+ ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE)
+ ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE)
+ ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING)
+
+ session.merge(ti_a)
+ session.merge(ti_b)
+ session.merge(ti_c)
+ session.merge(ti_d)
+ session.flush()
+
+ job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor())
+ job1.task_runner = StandardTaskRunner(job1)
+ job1.run()
+
+ ti_b.refresh_from_db(session)
+ ti_c.refresh_from_db(session)
+ assert ti_b.state in (State.SKIPPED, State.UPSTREAM_FAILED)
+ assert ti_c.state == State.NONE
+ assert "0 downstream tasks scheduled from follow-on schedule" in
caplog.text
+
+ failed_deps = list(ti_c.get_failed_dep_statuses(session=session))
+ assert len(failed_deps) == 1
+ assert failed_deps[0].dep_name == "Trigger Rule"
+ assert not failed_deps[0].passed
+
+ @pytest.mark.parametrize(
+ "trigger_rule",
+ [
+ TriggerRule.ONE_SUCCESS,
+ TriggerRule.ALL_SKIPPED,
+ TriggerRule.NONE_FAILED,
+ TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+ ],
+ )
+ @conf_vars({('scheduler', 'schedule_after_task_execution'): 'True'})
+ def test_mini_scheduler_works_with_branch_python_operator(self, trigger_rule, caplog, session, dag_maker):
+ """
+ In these cases D is running, so no decision can be made about C.
+ """
+ with dag_maker(catchup=False) as dag:
+ task_a = BranchPythonOperator(task_id='A', python_callable=lambda: [])
+ task_b = PythonOperator(task_id='B', python_callable=lambda: True)
+ task_c = PythonOperator(
+ task_id='C',
+ python_callable=lambda: True,
+ trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS,
+ )
+ task_d = PythonOperator(task_id='D', python_callable=lambda: True)
+ task_a >> task_b >> task_c
+ task_d >> task_c
+
+ dr = dag.create_dagrun(run_id='test_1', state=State.RUNNING, execution_date=DEFAULT_DATE)
+ ti_a = TaskInstance(task_a, run_id=dr.run_id, state=State.QUEUED)
+ ti_b = TaskInstance(task_b, run_id=dr.run_id, state=State.NONE)
+ ti_c = TaskInstance(task_c, run_id=dr.run_id, state=State.NONE)
+ ti_d = TaskInstance(task_d, run_id=dr.run_id, state=State.RUNNING)
+
+ session.merge(ti_a)
+ session.merge(ti_b)
+ session.merge(ti_c)
+ session.merge(ti_d)
+ session.flush()
+
+ job1 = LocalTaskJob(task_instance=ti_a, ignore_ti_state=True, executor=SequentialExecutor())
+ job1.task_runner = StandardTaskRunner(job1)
+ job1.run()
+
+ ti_b.refresh_from_db(session)
+ ti_c.refresh_from_db(session)
+ assert ti_b.state == State.SKIPPED
+ assert ti_c.state == State.NONE
+ assert "0 downstream tasks scheduled from follow-on schedule" in
caplog.text
+
+ failed_deps = list(ti_c.get_failed_dep_statuses(session=session))
+ assert len(failed_deps) == 1
+ assert failed_deps[0].dep_name == "Trigger Rule"
+ assert not failed_deps[0].passed
+
@patch('airflow.utils.process_utils.subprocess.check_call')
def test_task_sigkill_works_with_retries(self, _check_call, caplog, dag_maker):
"""