Repository: incubator-airflow Updated Branches: refs/heads/v1-8-test 9070a8277 -> dff6d21bf
Merge pull request #2195 from bolkedebruin/AIRFLOW-719 (cherry picked from commit 4a6bef69d1817a5fc3ddd6ffe14c2578eaa49cf0) Signed-off-by: Bolke de Bruin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dff6d21b Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dff6d21b Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dff6d21b Branch: refs/heads/v1-8-test Commit: dff6d21bfd9a2585ca484fc8fd56aa100f640908 Parents: 9070a82 Author: Bolke de Bruin <[email protected]> Authored: Tue Apr 4 17:04:12 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Wed Apr 5 19:16:22 2017 +0200 ---------------------------------------------------------------------- airflow/operators/latest_only_operator.py | 30 ++- airflow/operators/python_operator.py | 82 +++++-- airflow/ti_deps/deps/trigger_rule_dep.py | 6 +- scripts/ci/requirements.txt | 1 + tests/dags/test_dagrun_short_circuit_false.py | 38 ---- tests/models.py | 77 +++---- tests/operators/__init__.py | 2 + tests/operators/latest_only_operator.py | 12 +- tests/operators/python_operator.py | 244 +++++++++++++++++++++ 9 files changed, 384 insertions(+), 108 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py index 8b4e614..9d5defb 100644 --- a/airflow/operators/latest_only_operator.py +++ b/airflow/operators/latest_only_operator.py @@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator): def execute(self, context): # If the DAG Run is externally triggered, then return without # skipping downstream tasks - if context['dag_run'].external_trigger: + if context['dag_run'] and context['dag_run'].external_trigger: logging.info("""Externally triggered DAG_Run: allowing execution to proceed.""") return @@ -46,17 +46,39 @@ class LatestOnlyOperator(BaseOperator): logging.info( 'Checking latest only with left_window: %s right_window: %s ' 'now: %s', left_window, right_window, now) + if not left_window < now <= right_window: logging.info('Not latest execution, skipping downstream.') session = settings.Session() - for task in context['task'].downstream_list: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids) + ).with_for_update().all() + + for ti in tis: logging.info('Skipping task: %s', ti.task_id) ti.state = State.SKIPPED ti.start_date = now ti.end_date = now session.merge(ti) + + # this is defensive against dag runs that are not complete + for task in context['task'].downstream_list: + if task.task_id in tis: + continue + + logging.warning("Task {} was not part of a dag run. " + "This should not happen." + .format(task)) + now = datetime.datetime.now() + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = now + ti.end_date = now + session.merge(ti) + session.commit() session.close() logging.info('Done.') http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index b5f6386..114bc7e 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -106,14 +106,36 @@ class BranchPythonOperator(PythonOperator): logging.info("Following branch " + branch) logging.info("Marking other directly downstream tasks as skipped") session = settings.Session() + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids), + TI.task_id != branch, + ).with_for_update().all() + + for ti in tis: + logging.info('Skipping task: %s', ti.task_id) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + + # this is defensive against dag runs that are not complete for task in context['task'].downstream_list: - if task.task_id != branch: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) + if task.task_id in tis: + continue + + if task.task_id == branch: + continue + + logging.warning("Task {} was not part of a dag run. This should not happen." + .format(task)) + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + session.merge(ti) + session.commit() session.close() logging.info("Done.") @@ -134,19 +156,39 @@ class ShortCircuitOperator(PythonOperator): def execute(self, context): condition = super(ShortCircuitOperator, self).execute(context) logging.info("Condition result is {}".format(condition)) + if condition: logging.info('Proceeding with downstream tasks...') return - else: - logging.info('Skipping downstream tasks...') - session = settings.Session() - for task in context['task'].downstream_list: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) - session.commit() - session.close() - logging.info("Done.") + + logging.info('Skipping downstream tasks...') + session = settings.Session() + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids), + ).with_for_update().all() + + for ti in tis: + logging.info('Skipping task: %s', ti.task_id) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + + # this is defensive against dag runs that are not complete + for task in context['task'].downstream_list: + if task.task_id in tis: + continue + + logging.warning("Task {} was not part of a dag run. This should not happen." + .format(task)) + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + session.merge(ti) + + session.commit() + session.close() + logging.info("Done.") http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/ti_deps/deps/trigger_rule_dep.py ---------------------------------------------------------------------- diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index da13bba..281ed51 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -135,7 +135,7 @@ class TriggerRuleDep(BaseTIDep): if tr == TR.ALL_SUCCESS: if upstream_failed or failed: ti.set_state(State.UPSTREAM_FAILED, session) - elif skipped == upstream: + elif skipped: ti.set_state(State.SKIPPED, session) elif tr == TR.ALL_FAILED: if successes or skipped: @@ -148,7 +148,7 @@ class TriggerRuleDep(BaseTIDep): ti.set_state(State.SKIPPED, session) if tr == TR.ONE_SUCCESS: - if successes <= 0 and skipped <= 0: + if successes <= 0: yield self._failing_status( reason="Task's trigger rule '{0}' requires one upstream " "task success, but none were found. " @@ -162,7 +162,7 @@ class TriggerRuleDep(BaseTIDep): "upstream_tasks_state={1}, upstream_task_ids={2}" .format(tr, upstream_tasks_state, task.upstream_task_ids)) elif tr == TR.ALL_SUCCESS: - num_failures = upstream - (successes + skipped) + num_failures = upstream - successes if num_failures > 0: yield self._failing_status( reason="Task's trigger rule '{0}' requires all upstream " http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/scripts/ci/requirements.txt ---------------------------------------------------------------------- diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt index a5786f6..9a2bce2 100644 --- a/scripts/ci/requirements.txt +++ b/scripts/ci/requirements.txt @@ -20,6 +20,7 @@ flask-cache flask-login==0.2.11 Flask-WTF flower +freezegun future gunicorn hdfs http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/dags/test_dagrun_short_circuit_false.py ---------------------------------------------------------------------- diff --git a/tests/dags/test_dagrun_short_circuit_false.py b/tests/dags/test_dagrun_short_circuit_false.py deleted file mode 100644 index 805ab67..0000000 --- a/tests/dags/test_dagrun_short_circuit_false.py +++ /dev/null @@ -1,38 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime - -from airflow.models import DAG -from airflow.operators.python_operator import ShortCircuitOperator -from airflow.operators.dummy_operator import DummyOperator - - -# DAG that has its short circuit op fail and skip multiple downstream tasks -dag = DAG( - dag_id='test_dagrun_short_circuit_false', - start_date=datetime(2017, 1, 1) -) -dag_task1 = ShortCircuitOperator( - task_id='test_short_circuit_false', - dag=dag, - python_callable=lambda: False) -dag_task2 = DummyOperator( - task_id='test_state_skipped1', - dag=dag) -dag_task3 = DummyOperator( - task_id='test_state_skipped2', - dag=dag) -dag_task1.set_downstream(dag_task2) -dag_task2.set_downstream(dag_task3) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/models.py ---------------------------------------------------------------------- diff --git a/tests/models.py b/tests/models.py index 83183f8..9478088 100644 --- a/tests/models.py +++ b/tests/models.py @@ -31,11 +31,12 @@ from airflow.models import DagModel from airflow.operators.dummy_operator import DummyOperator from airflow.operators.bash_operator import BashOperator from airflow.operators.python_operator import PythonOperator +from airflow.operators.python_operator import ShortCircuitOperator from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils.state import State from mock import patch from nose_parameterized import parameterized -from tests.core import TEST_DAG_FOLDER + DEFAULT_DATE = datetime.datetime(2016, 1, 1) TEST_DAGS_FOLDER = os.path.join( @@ -235,17 +236,13 @@ class DagTest(unittest.TestCase): class DagRunTest(unittest.TestCase): - def setUp(self): - self.dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER) - - def create_dag_run(self, dag_id, state=State.RUNNING, task_states=None): + def create_dag_run(self, dag, state=State.RUNNING, task_states=None): now = datetime.datetime.now() - dag = self.dagbag.get_dag(dag_id) dag_run = dag.create_dagrun( run_id='manual__' + now.isoformat(), execution_date=now, start_date=now, - state=State.RUNNING, + state=state, external_trigger=False, ) @@ -298,33 +295,34 @@ class DagRunTest(unittest.TestCase): self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True))) self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False))) - def test_dagrun_running_when_upstream_skipped(self): - """ - Tests that a DAG run is not failed when an upstream task is skipped - """ - initial_task_states = { - 'test_short_circuit_false': State.SUCCESS, - 'test_state_skipped1': State.SKIPPED, - 'test_state_skipped2': State.NONE, - } - # dags/test_dagrun_short_circuit_false.py - dag_run = self.create_dag_run('test_dagrun_short_circuit_false', - state=State.RUNNING, - task_states=initial_task_states) - updated_dag_state = dag_run.update_state() - self.assertEqual(State.RUNNING, updated_dag_state) - def test_dagrun_success_when_all_skipped(self): """ Tests that a DAG run succeeds when all tasks are skipped """ + dag = DAG( + dag_id='test_dagrun_success_when_all_skipped', + start_date=datetime.datetime(2017, 1, 1) + ) + dag_task1 = ShortCircuitOperator( + task_id='test_short_circuit_false', + dag=dag, + python_callable=lambda: False) + dag_task2 = DummyOperator( + task_id='test_state_skipped1', + dag=dag) + dag_task3 = DummyOperator( + task_id='test_state_skipped2', + dag=dag) + dag_task1.set_downstream(dag_task2) + dag_task2.set_downstream(dag_task3) + initial_task_states = { 'test_short_circuit_false': State.SUCCESS, 'test_state_skipped1': State.SKIPPED, 'test_state_skipped2': State.SKIPPED, } - # dags/test_dagrun_short_circuit_false.py - dag_run = self.create_dag_run('test_dagrun_short_circuit_false', + + dag_run = self.create_dag_run(dag=dag, state=State.RUNNING, task_states=initial_task_states) updated_dag_state = dag_run.update_state() @@ -385,10 +383,17 @@ class DagRunTest(unittest.TestCase): """ Make sure that a proper value is returned when a dagrun has no task instances """ + dag = DAG( + dag_id='test_get_task_instance_on_empty_dagrun', + start_date=datetime.datetime(2017, 1, 1) + ) + dag_task1 = ShortCircuitOperator( + task_id='test_short_circuit_false', + dag=dag, + python_callable=lambda: False) + session = settings.Session() - # Any dag will work for this - dag = self.dagbag.get_dag('test_dagrun_short_circuit_false') now = datetime.datetime.now() # Don't use create_dagrun since it will create the task instances too which we @@ -784,7 +789,7 @@ class TaskInstanceTest(unittest.TestCase): self.assertEqual(dt, ti.end_date+max_delay) def test_depends_on_past(self): - dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER) + dagbag = models.DagBag() dag = dagbag.get_dag('test_depends_on_past') dag.clear() task = dag.tasks[0] @@ -813,11 +818,10 @@ class TaskInstanceTest(unittest.TestCase): # # Tests for all_success # - ['all_success', 5, 0, 0, 0, 5, True, None, True], - ['all_success', 2, 0, 0, 0, 2, True, None, False], - ['all_success', 2, 0, 1, 0, 3, True, ST.UPSTREAM_FAILED, False], - ['all_success', 2, 1, 0, 0, 3, True, None, False], - ['all_success', 0, 5, 0, 0, 5, True, ST.SKIPPED, True], + ['all_success', 5, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False], # # Tests for one_success # @@ -825,7 +829,6 @@ class TaskInstanceTest(unittest.TestCase): ['one_success', 2, 0, 0, 0, 2, True, None, True], ['one_success', 2, 0, 1, 0, 3, True, None, True], ['one_success', 2, 1, 0, 0, 3, True, None, True], - ['one_success', 0, 2, 0, 0, 2, True, None, True], # # Tests for all_failed # @@ -837,9 +840,9 @@ class TaskInstanceTest(unittest.TestCase): # # Tests for one_failed # - ['one_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False], - ['one_failed', 2, 0, 0, 0, 2, True, None, False], - ['one_failed', 2, 0, 1, 0, 2, True, None, True], + ['one_failed', 5, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, True, None, True], ['one_failed', 2, 1, 0, 0, 3, True, None, False], ['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False], # http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/__init__.py ---------------------------------------------------------------------- diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py index 1fb0e5e..aeb243c 100644 --- a/tests/operators/__init__.py +++ b/tests/operators/__init__.py @@ -18,3 +18,5 @@ from .operators import * from .sensors import * from .hive_operator import * from .s3_to_hive_operator import * +from .python_operator import * +from .latest_only_operator import * http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py index 37aec38..9137491 100644 --- a/tests/operators/latest_only_operator.py +++ b/tests/operators/latest_only_operator.py @@ -77,17 +77,17 @@ class LatestOnlyOperatorTest(unittest.TestCase): latest_instances = get_task_instances('latest') exec_date_to_latest_state = { ti.execution_date: ti.state for ti in latest_instances} - assert exec_date_to_latest_state == { + self.assertEqual({ datetime.datetime(2016, 1, 1): 'success', datetime.datetime(2016, 1, 1, 12): 'success', - datetime.datetime(2016, 1, 2): 'success', - } + datetime.datetime(2016, 1, 2): 'success', }, + exec_date_to_latest_state) downstream_instances = get_task_instances('downstream') exec_date_to_downstream_state = { ti.execution_date: ti.state for ti in downstream_instances} - assert exec_date_to_downstream_state == { + self.assertEqual({ datetime.datetime(2016, 1, 1): 'skipped', datetime.datetime(2016, 1, 1, 12): 'skipped', - datetime.datetime(2016, 1, 2): 'success', - } + datetime.datetime(2016, 1, 2): 'success',}, + exec_date_to_downstream_state) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py new file mode 100644 index 0000000..3aa8b6c --- /dev/null +++ b/tests/operators/python_operator.py @@ -0,0 +1,244 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function, unicode_literals + +import datetime +import unittest + +from airflow import configuration, DAG +from airflow.models import TaskInstance as TI +from airflow.operators.python_operator import PythonOperator, BranchPythonOperator +from airflow.operators.python_operator import ShortCircuitOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.settings import Session +from airflow.utils.state import State + +from airflow.exceptions import AirflowException + +DEFAULT_DATE = datetime.datetime(2016, 1, 1) +END_DATE = datetime.datetime(2016, 1, 2) +INTERVAL = datetime.timedelta(hours=12) +FROZEN_NOW = datetime.datetime(2016, 1, 2, 12, 1, 1) + + +class PythonOperatorTest(unittest.TestCase): + + def setUp(self): + super(PythonOperatorTest, self).setUp() + configuration.load_test_config() + self.dag = DAG( + 'test_dag', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL) + self.addCleanup(self.dag.clear) + self.clear_run() + self.addCleanup(self.clear_run) + + def do_run(self): + self.run = True + + def clear_run(self): + self.run = False + + def is_run(self): + return self.run + + def test_python_operator_run(self): + """Tests that the python callable is invoked on task run.""" + task = PythonOperator( + python_callable=self.do_run, + task_id='python_operator', + dag=self.dag) + self.assertFalse(self.is_run()) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.assertTrue(self.is_run()) + + def test_python_operator_python_callable_is_callable(self): + """Tests that PythonOperator will only instantiate if + the python_callable argument is callable.""" + not_callable = {} + with self.assertRaises(AirflowException): + PythonOperator( + python_callable=not_callable, + task_id='python_operator', + dag=self.dag) + not_callable = None + with self.assertRaises(AirflowException): + PythonOperator( + python_callable=not_callable, + task_id='python_operator', + dag=self.dag) + + +class BranchOperatorTest(unittest.TestCase): + def setUp(self): + self.dag = DAG('branch_operator_test', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL) + self.branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + + self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) + self.branch_1.set_upstream(self.branch_op) + self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) + self.branch_2.set_upstream(self.branch_op) + self.dag.clear() + + def test_without_dag_run(self): + """This checks the defensive against non existent tasks in a dag run""" + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + session = Session() + tis = session.query(TI).filter( + TI.dag_id == self.dag.dag_id, + TI.execution_date == DEFAULT_DATE + ) + session.close() + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + # should not exist + raise + elif ti.task_id == 'branch_2': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + def test_with_dag_run(self): + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=datetime.datetime.now(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + elif ti.task_id == 'branch_2': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + +class ShortCircuitOperatorTest(unittest.TestCase): + def setUp(self): + self.dag = DAG('shortcircuit_operator_test', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL) + self.short_op = ShortCircuitOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: self.value) + + self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) + self.branch_1.set_upstream(self.short_op) + self.upstream = DummyOperator(task_id='upstream', dag=self.dag) + self.upstream.set_downstream(self.short_op) + self.dag.clear() + + self.value = True + + def test_without_dag_run(self): + """This checks the defensive against non existent tasks in a dag run""" + self.value = False + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + session = Session() + tis = session.query(TI).filter( + TI.dag_id == self.dag.dag_id, + TI.execution_date == DEFAULT_DATE + ) + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + # should not exist + raise + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + self.value = True + self.dag.clear() + + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + # should not exist + raise + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + else: + raise + + session.close() + + def test_with_dag_run(self): + self.value = False + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=datetime.datetime.now(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + self.value = True + self.dag.clear() + + self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + else: + raise
