[AIRFLOW-719] Prevent DAGs from ending prematurely

DAGs using the ALL_SUCCESS and ONE_SUCCESS trigger rules were ending
prematurely when upstream tasks were skipped. With this change, the
ALL_SUCCESS and ONE_SUCCESS trigger rules encompass both SUCCESS and
SKIPPED upstream tasks.
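For readers skimming the diff, here is a minimal sketch (not part of this
commit) of the rule semantics after the change. The helper names
all_success_satisfied and one_success_satisfied are hypothetical and exist
only for illustration; the counters successes, skipped, and upstream mirror
the values computed in airflow/ti_deps/deps/trigger_rule_dep.py.

    # Hypothetical helpers for illustration only; not Airflow code.

    def all_success_satisfied(successes, skipped, upstream):
        # After this change the dep computes
        #   num_failures = upstream - (successes + skipped)
        # so skipped upstream tasks no longer count as failures.
        return upstream - (successes + skipped) <= 0

    def one_success_satisfied(successes, skipped):
        # After this change ONE_SUCCESS is also satisfied when the only
        # finished upstream tasks were skipped rather than successful.
        return successes > 0 or skipped > 0

    # Two upstream tasks: one succeeded, one was skipped by a short circuit.
    assert all_success_satisfied(successes=1, skipped=1, upstream=2)
    # Both upstream tasks were skipped: the downstream task can still run.
    assert one_success_satisfied(successes=0, skipped=2)

Before the change the equivalent checks were "upstream - successes <= 0" and
"successes > 0", which is what caused DAG runs to end prematurely when
upstream tasks were skipped.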
Closes #2125 from dhuang/AIRFLOW-719

Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/4077c6de
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/4077c6de
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/4077c6de

Branch: refs/heads/v1-8-test
Commit: 4077c6de297566a4c598065867a9a27324ae6eb1
Parents: 157054e
Author: Daniel Huang <[email protected]>
Authored: Sat Mar 4 17:33:23 2017 +0100
Committer: Bolke de Bruin <[email protected]>
Committed: Sun Mar 12 08:27:30 2017 -0700

----------------------------------------------------------------------
 airflow/ti_deps/deps/trigger_rule_dep.py      |  6 +-
 tests/dags/test_dagrun_short_circuit_false.py | 38 +++++++++++
 tests/models.py                               | 79 +++++++++++++++++++---
 3 files changed, 111 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4077c6de/airflow/ti_deps/deps/trigger_rule_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index 281ed51..da13bba 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -135,7 +135,7 @@ class TriggerRuleDep(BaseTIDep):
             if tr == TR.ALL_SUCCESS:
                 if upstream_failed or failed:
                     ti.set_state(State.UPSTREAM_FAILED, session)
-                elif skipped:
+                elif skipped == upstream:
                     ti.set_state(State.SKIPPED, session)
             elif tr == TR.ALL_FAILED:
                 if successes or skipped:
@@ -148,7 +148,7 @@ class TriggerRuleDep(BaseTIDep):
                     ti.set_state(State.SKIPPED, session)
 
         if tr == TR.ONE_SUCCESS:
-            if successes <= 0:
+            if successes <= 0 and skipped <= 0:
                 yield self._failing_status(
                     reason="Task's trigger rule '{0}' requires one upstream "
                     "task success, but none were found. "
@@ -162,7 +162,7 @@ class TriggerRuleDep(BaseTIDep):
                     "upstream_tasks_state={1}, upstream_task_ids={2}"
                     .format(tr, upstream_tasks_state, task.upstream_task_ids))
         elif tr == TR.ALL_SUCCESS:
-            num_failures = upstream - successes
+            num_failures = upstream - (successes + skipped)
             if num_failures > 0:
                 yield self._failing_status(
                     reason="Task's trigger rule '{0}' requires all upstream "

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4077c6de/tests/dags/test_dagrun_short_circuit_false.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_dagrun_short_circuit_false.py b/tests/dags/test_dagrun_short_circuit_false.py
new file mode 100644
index 0000000..805ab67
--- /dev/null
+++ b/tests/dags/test_dagrun_short_circuit_false.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from datetime import datetime
+
+from airflow.models import DAG
+from airflow.operators.python_operator import ShortCircuitOperator
+from airflow.operators.dummy_operator import DummyOperator
+
+
+# DAG that has its short circuit op fail and skip multiple downstream tasks
+dag = DAG(
+    dag_id='test_dagrun_short_circuit_false',
+    start_date=datetime(2017, 1, 1)
+)
+dag_task1 = ShortCircuitOperator(
+    task_id='test_short_circuit_false',
+    dag=dag,
+    python_callable=lambda: False)
+dag_task2 = DummyOperator(
+    task_id='test_state_skipped1',
+    dag=dag)
+dag_task3 = DummyOperator(
+    task_id='test_state_skipped2',
+    dag=dag)
+dag_task1.set_downstream(dag_task2)
+dag_task2.set_downstream(dag_task3)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/4077c6de/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 7ca01e7..d904ff3 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -34,6 +34,7 @@ from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils.state import State
 from mock import patch
 from nose_parameterized import parameterized
+from tests.core import TEST_DAG_FOLDER
 
 DEFAULT_DATE = datetime.datetime(2016, 1, 1)
 TEST_DAGS_FOLDER = os.path.join(
@@ -117,13 +118,71 @@ class DagTest(unittest.TestCase):
         self.assertEqual(dag.dag_id, 'creating_dag_in_cm')
         self.assertEqual(dag.tasks[0].task_id, 'op6')
 
+
 class DagRunTest(unittest.TestCase):
+
+    def setUp(self):
+        self.dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
+
+    def create_dag_run(self, dag_id, state=State.RUNNING, task_states=None):
+        now = datetime.datetime.now()
+        dag = self.dagbag.get_dag(dag_id)
+        dag_run = dag.create_dagrun(
+            run_id='manual__' + now.isoformat(),
+            execution_date=now,
+            start_date=now,
+            state=State.RUNNING,
+            external_trigger=False,
+        )
+
+        if task_states is not None:
+            session = settings.Session()
+            for task_id, state in task_states.items():
+                ti = dag_run.get_task_instance(task_id)
+                ti.set_state(state, session)
+            session.close()
+
+        return dag_run
+
     def test_id_for_date(self):
         run_id = models.DagRun.id_for_date(
             datetime.datetime(2015, 1, 2, 3, 4, 5, 6, None))
-        assert run_id == 'scheduled__2015-01-02T03:04:05', (
+        self.assertEqual(
+            'scheduled__2015-01-02T03:04:05', run_id,
             'Generated run_id did not match expectations: {0}'.format(run_id))
 
+    def test_dagrun_running_when_upstream_skipped(self):
+        """
+        Tests that a DAG run is not failed when an upstream task is skipped
+        """
+        initial_task_states = {
+            'test_short_circuit_false': State.SUCCESS,
+            'test_state_skipped1': State.SKIPPED,
+            'test_state_skipped2': State.NONE,
+        }
+        # dags/test_dagrun_short_circuit_false.py
+        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
+                                      state=State.RUNNING,
+                                      task_states=initial_task_states)
+        updated_dag_state = dag_run.update_state()
+        self.assertEqual(State.RUNNING, updated_dag_state)
+
+    def test_dagrun_success_when_all_skipped(self):
+        """
+        Tests that a DAG run succeeds when all tasks are skipped
+        """
+        initial_task_states = {
+            'test_short_circuit_false': State.SUCCESS,
+            'test_state_skipped1': State.SKIPPED,
+            'test_state_skipped2': State.SKIPPED,
+        }
+        # dags/test_dagrun_short_circuit_false.py
+        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
+                                      state=State.RUNNING,
+                                      task_states=initial_task_states)
+        updated_dag_state = dag_run.update_state()
+        self.assertEqual(State.SUCCESS, updated_dag_state)
+
 
 class DagBagTest(unittest.TestCase):
 
@@ -501,7 +560,7 @@ class TaskInstanceTest(unittest.TestCase):
         self.assertEqual(dt, ti.end_date+max_delay)
 
     def test_depends_on_past(self):
-        dagbag = models.DagBag()
+        dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
         dag = dagbag.get_dag('test_depends_on_past')
         dag.clear()
         task = dag.tasks[0]
@@ -530,10 +589,11 @@ class TaskInstanceTest(unittest.TestCase):
         #
         # Tests for all_success
         #
-        ['all_success', 5, 0, 0, 0, 0, True, None, True],
-        ['all_success', 2, 0, 0, 0, 0, True, None, False],
-        ['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False],
-        ['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False],
+        ['all_success', 5, 0, 0, 0, 5, True, None, True],
+        ['all_success', 2, 0, 0, 0, 2, True, None, False],
+        ['all_success', 2, 0, 1, 0, 3, True, ST.UPSTREAM_FAILED, False],
+        ['all_success', 2, 1, 0, 0, 3, True, None, False],
+        ['all_success', 0, 5, 0, 0, 5, True, ST.SKIPPED, True],
         #
        # Tests for one_success
         #
@@ -541,6 +601,7 @@ class TaskInstanceTest(unittest.TestCase):
         ['one_success', 2, 0, 0, 0, 2, True, None, True],
         ['one_success', 2, 0, 1, 0, 3, True, None, True],
         ['one_success', 2, 1, 0, 0, 3, True, None, True],
+        ['one_success', 0, 2, 0, 0, 2, True, None, True],
         #
         # Tests for all_failed
         #
@@ -552,9 +613,9 @@ class TaskInstanceTest(unittest.TestCase):
         #
         # Tests for one_failed
         #
-        ['one_failed', 5, 0, 0, 0, 0, True, None, False],
-        ['one_failed', 2, 0, 0, 0, 0, True, None, False],
-        ['one_failed', 2, 0, 1, 0, 0, True, None, True],
+        ['one_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False],
+        ['one_failed', 2, 0, 0, 0, 2, True, None, False],
+        ['one_failed', 2, 0, 1, 0, 2, True, None, True],
         ['one_failed', 2, 1, 0, 0, 3, True, None, False],
         ['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False],
         #
