[AIRFLOW-719] Fix race condition in ShortCircuit, Branch and LatestOnly The ShortCircuitOperator, BranchPythonOperator and LatestOnlyOperator were all arbitrarily changing the states of TaskInstances without locking them in the database. As the scheduler checks the state of dag runs asynchronously, the dag run state could be set to failed while the operators are updating the downstream tasks.
A better fix would be to use the dag run itself in the context of the Operator. Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/eb705fd5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/eb705fd5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/eb705fd5 Branch: refs/heads/master Commit: eb705fd55c30cea778282140d927f51b4a649c73 Parents: 92965e8 Author: Bolke de Bruin <[email protected]> Authored: Tue Mar 28 16:29:39 2017 -0700 Committer: Bolke de Bruin <[email protected]> Committed: Mon Apr 3 10:38:12 2017 +0200 ---------------------------------------------------------------------- airflow/operators/latest_only_operator.py | 30 ++++- airflow/operators/python_operator.py | 82 +++++++++--- scripts/ci/requirements.txt | 1 + tests/operators/__init__.py | 2 + tests/operators/latest_only_operator.py | 2 +- tests/operators/python_operator.py | 167 ++++++++++++++++++++++++- 6 files changed, 258 insertions(+), 26 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/airflow/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py index 8b4e614..9d5defb 100644 --- a/airflow/operators/latest_only_operator.py +++ b/airflow/operators/latest_only_operator.py @@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator): def execute(self, context): # If the DAG Run is externally triggered, then return without # skipping downstream tasks - if context['dag_run'].external_trigger: + if context['dag_run'] and context['dag_run'].external_trigger: logging.info("""Externally triggered DAG_Run: allowing execution to proceed.""") return @@ -46,17 +46,39 @@ class LatestOnlyOperator(BaseOperator): logging.info( 'Checking 
latest only with left_window: %s right_window: %s ' 'now: %s', left_window, right_window, now) + if not left_window < now <= right_window: logging.info('Not latest execution, skipping downstream.') session = settings.Session() - for task in context['task'].downstream_list: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids) + ).with_for_update().all() + + for ti in tis: logging.info('Skipping task: %s', ti.task_id) ti.state = State.SKIPPED ti.start_date = now ti.end_date = now session.merge(ti) + + # this is defensive against dag runs that are not complete + for task in context['task'].downstream_list: + if task.task_id in tis: + continue + + logging.warning("Task {} was not part of a dag run. " + "This should not happen." + .format(task)) + now = datetime.datetime.now() + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = now + ti.end_date = now + session.merge(ti) + session.commit() session.close() logging.info('Done.') http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/airflow/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index a17e6fa..cf240f2 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -109,14 +109,36 @@ class BranchPythonOperator(PythonOperator): logging.info("Following branch " + branch) logging.info("Marking other directly downstream tasks as skipped") session = settings.Session() + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids), + TI.task_id != branch, + ).with_for_update().all() + + for ti 
in tis: + logging.info('Skipping task: %s', ti.task_id) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + + # this is defensive against dag runs that are not complete for task in context['task'].downstream_list: - if task.task_id != branch: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) + if task.task_id in tis: + continue + + if task.task_id == branch: + continue + + logging.warning("Task {} was not part of a dag run. This should not happen." + .format(task)) + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + session.merge(ti) + session.commit() session.close() logging.info("Done.") @@ -137,19 +159,39 @@ class ShortCircuitOperator(PythonOperator): def execute(self, context): condition = super(ShortCircuitOperator, self).execute(context) logging.info("Condition result is {}".format(condition)) + if condition: logging.info('Proceeding with downstream tasks...') return - else: - logging.info('Skipping downstream tasks...') - session = settings.Session() - for task in context['task'].downstream_list: - ti = TaskInstance( - task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) - session.commit() - session.close() - logging.info("Done.") + + logging.info('Skipping downstream tasks...') + session = settings.Session() + + TI = TaskInstance + tis = session.query(TI).filter( + TI.execution_date == context['ti'].execution_date, + TI.task_id.in_(context['task'].downstream_task_ids), + ).with_for_update().all() + + for ti in tis: + logging.info('Skipping task: %s', ti.task_id) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + + # this is 
defensive against dag runs that are not complete + for task in context['task'].downstream_list: + if task.task_id in tis: + continue + + logging.warning("Task {} was not part of a dag run. This should not happen." + .format(task)) + ti = TaskInstance(task, execution_date=context['ti'].execution_date) + ti.state = State.SKIPPED + ti.start_date = datetime.now() + ti.end_date = datetime.now() + session.merge(ti) + + session.commit() + session.close() + logging.info("Done.") http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/scripts/ci/requirements.txt ---------------------------------------------------------------------- diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt index 7fdd18e..d206f16 100644 --- a/scripts/ci/requirements.txt +++ b/scripts/ci/requirements.txt @@ -22,6 +22,7 @@ flask-cache flask-login==0.2.11 Flask-WTF flower +freezegun future gunicorn hdfs http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/__init__.py ---------------------------------------------------------------------- diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py index 7a517a1..e6f6830 100644 --- a/tests/operators/__init__.py +++ b/tests/operators/__init__.py @@ -19,3 +19,5 @@ from .sensors import * from .hive_operator import * from .s3_to_hive_operator import * from .python_operator import * +from .latest_only_operator import * + http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py index 3ac5fac..9137491 100644 --- a/tests/operators/latest_only_operator.py +++ b/tests/operators/latest_only_operator.py @@ -80,7 +80,7 @@ class LatestOnlyOperatorTest(unittest.TestCase): self.assertEqual({ datetime.datetime(2016, 1, 1): 'success', datetime.datetime(2016, 1, 1, 12): 
'success', - datetime.datetime(2016, 1, 2): 'success', }, + datetime.datetime(2016, 1, 2): 'success', }, exec_date_to_latest_state) downstream_instances = get_task_instances('downstream') http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py index 621172f..3aa8b6c 100644 --- a/tests/operators/python_operator.py +++ b/tests/operators/python_operator.py @@ -18,7 +18,12 @@ import datetime import unittest from airflow import configuration, DAG -from airflow.operators.python_operator import PythonOperator +from airflow.models import TaskInstance as TI +from airflow.operators.python_operator import PythonOperator, BranchPythonOperator +from airflow.operators.python_operator import ShortCircuitOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.settings import Session +from airflow.utils.state import State from airflow.exceptions import AirflowException @@ -77,3 +82,163 @@ class PythonOperatorTest(unittest.TestCase): python_callable=not_callable, task_id='python_operator', dag=self.dag) + + +class BranchOperatorTest(unittest.TestCase): + def setUp(self): + self.dag = DAG('branch_operator_test', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL) + self.branch_op = BranchPythonOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: 'branch_1') + + self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) + self.branch_1.set_upstream(self.branch_op) + self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) + self.branch_2.set_upstream(self.branch_op) + self.dag.clear() + + def test_without_dag_run(self): + """This checks the defensive against non existent tasks in a dag run""" + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + session = 
Session() + tis = session.query(TI).filter( + TI.dag_id == self.dag.dag_id, + TI.execution_date == DEFAULT_DATE + ) + session.close() + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + # should not exist + raise + elif ti.task_id == 'branch_2': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + def test_with_dag_run(self): + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=datetime.datetime.now(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + elif ti.task_id == 'branch_2': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + +class ShortCircuitOperatorTest(unittest.TestCase): + def setUp(self): + self.dag = DAG('shortcircuit_operator_test', + default_args={ + 'owner': 'airflow', + 'start_date': DEFAULT_DATE}, + schedule_interval=INTERVAL) + self.short_op = ShortCircuitOperator(task_id='make_choice', + dag=self.dag, + python_callable=lambda: self.value) + + self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) + self.branch_1.set_upstream(self.short_op) + self.upstream = DummyOperator(task_id='upstream', dag=self.dag) + self.upstream.set_downstream(self.short_op) + self.dag.clear() + + self.value = True + + def test_without_dag_run(self): + """This checks the defensive against non existent tasks in a dag run""" + self.value = False + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + session = Session() + tis = session.query(TI).filter( + TI.dag_id == self.dag.dag_id, + TI.execution_date == DEFAULT_DATE + ) + + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 
'upstream': + # should not exist + raise + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + self.value = True + self.dag.clear() + + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + # should not exist + raise + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + else: + raise + + session.close() + + def test_with_dag_run(self): + self.value = False + dr = self.dag.create_dagrun( + run_id="manual__", + start_date=datetime.datetime.now(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.SKIPPED) + else: + raise + + self.value = True + self.dag.clear() + + self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) + + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'make_choice': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'upstream': + self.assertEquals(ti.state, State.SUCCESS) + elif ti.task_id == 'branch_1': + self.assertEquals(ti.state, State.NONE) + else: + raise
