[AIRFLOW-719] Fix race condition in ShortCircuit, Branch and LatestOnly

The ShortCircuitOperator, BranchPythonOperator and LatestOnlyOperator
were arbitrarily changing the states of TaskInstances without locking
them in the database. As the scheduler checks the state of dag runs
asynchronously, the dag run state could be set to failed while the
operators are updating the downstream tasks.

A better fix would be to use the dag run itself in the context of the
operator.


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/eb705fd5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/eb705fd5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/eb705fd5

Branch: refs/heads/master
Commit: eb705fd55c30cea778282140d927f51b4a649c73
Parents: 92965e8
Author: Bolke de Bruin <[email protected]>
Authored: Tue Mar 28 16:29:39 2017 -0700
Committer: Bolke de Bruin <[email protected]>
Committed: Mon Apr 3 10:38:12 2017 +0200

----------------------------------------------------------------------
 airflow/operators/latest_only_operator.py |  30 ++++-
 airflow/operators/python_operator.py      |  82 +++++++++---
 scripts/ci/requirements.txt               |   1 +
 tests/operators/__init__.py               |   2 +
 tests/operators/latest_only_operator.py   |   2 +-
 tests/operators/python_operator.py        | 167 ++++++++++++++++++++++++-
 6 files changed, 258 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/airflow/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index 8b4e614..9d5defb 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator):
     def execute(self, context):
         # If the DAG Run is externally triggered, then return without
         # skipping downstream tasks
-        if context['dag_run'].external_trigger:
+        if context['dag_run'] and context['dag_run'].external_trigger:
             logging.info("""Externally triggered DAG_Run:
                          allowing execution to proceed.""")
             return
@@ -46,17 +46,39 @@ class LatestOnlyOperator(BaseOperator):
         logging.info(
             'Checking latest only with left_window: %s right_window: %s '
             'now: %s', left_window, right_window, now)
+
         if not left_window < now <= right_window:
             logging.info('Not latest execution, skipping downstream.')
             session = settings.Session()
-            for task in context['task'].downstream_list:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
+            # Lock this DAG's downstream TIs so the scheduler can't race us.
+            TI = TaskInstance
+            tis = session.query(TI).filter(
+                TI.dag_id == context['ti'].dag_id,
+                TI.execution_date == context['ti'].execution_date,
+                TI.task_id.in_(context['task'].downstream_task_ids)
+            ).with_for_update().all()
+            for ti in tis:
                 logging.info('Skipping task: %s', ti.task_id)
                 ti.state = State.SKIPPED
                 ti.start_date = now
                 ti.end_date = now
                 session.merge(ti)
+
+            # Defensive: cover downstream tasks with no TI in the dag run yet.
+            skipped_task_ids = set(ti.task_id for ti in tis)
+            for task in context['task'].downstream_list:
+                if task.task_id in skipped_task_ids:
+                    continue
+                logging.warning("Task {} was not part of a dag run. "
+                                "This should not happen."
+                                .format(task))
+                now = datetime.datetime.now()
+                ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+                ti.state = State.SKIPPED
+                ti.start_date = now
+                ti.end_date = now
+                session.merge(ti)
+
             session.commit()
             session.close()
             logging.info('Done.')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index a17e6fa..cf240f2 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -109,14 +109,36 @@ class BranchPythonOperator(PythonOperator):
         logging.info("Following branch " + branch)
         logging.info("Marking other directly downstream tasks as skipped")
         session = settings.Session()
+
+        # Lock this DAG's downstream TIs so the scheduler can't race us.
+        TI = TaskInstance
+        tis = session.query(TI).filter(
+            TI.dag_id == context['ti'].dag_id,
+            TI.execution_date == context['ti'].execution_date,
+            TI.task_id.in_(context['task'].downstream_task_ids),
+            TI.task_id != branch,
+        ).with_for_update().all()
+
+        for ti in tis:
+            logging.info('Skipping task: %s', ti.task_id)
+            ti.state = State.SKIPPED
+            ti.start_date = ti.end_date = datetime.now()
+
+        # Defensive: cover downstream tasks with no TI in the dag run yet.
+        skipped_task_ids = set(ti.task_id for ti in tis)
         for task in context['task'].downstream_list:
-            if task.task_id != branch:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = datetime.now()
-                ti.end_date = datetime.now()
-                session.merge(ti)
+            if task.task_id in skipped_task_ids:
+                continue
+            if task.task_id == branch:
+                continue
+
+            logging.warning("Task {} was not part of a dag run. This should not happen."
+                            .format(task))
+            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+            ti.state = State.SKIPPED
+            ti.start_date = ti.end_date = datetime.now()
+            session.merge(ti)
+
         session.commit()
         session.close()
         logging.info("Done.")
@@ -137,19 +159,39 @@ class ShortCircuitOperator(PythonOperator):
     def execute(self, context):
         condition = super(ShortCircuitOperator, self).execute(context)
         logging.info("Condition result is {}".format(condition))
+
         if condition:
             logging.info('Proceeding with downstream tasks...')
             return
-        else:
-            logging.info('Skipping downstream tasks...')
-            session = settings.Session()
-            for task in context['task'].downstream_list:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = datetime.now()
-                ti.end_date = datetime.now()
-                session.merge(ti)
-            session.commit()
-            session.close()
-            logging.info("Done.")
+
+        logging.info('Skipping downstream tasks...')
+        session = settings.Session()
+
+        # Lock this DAG's downstream TIs so the scheduler can't race us.
+        TI = TaskInstance
+        tis = session.query(TI).filter(
+            TI.dag_id == context['ti'].dag_id,
+            TI.execution_date == context['ti'].execution_date,
+            TI.task_id.in_(context['task'].downstream_task_ids),
+        ).with_for_update().all()
+
+        for ti in tis:
+            logging.info('Skipping task: %s', ti.task_id)
+            ti.state = State.SKIPPED
+            ti.start_date = ti.end_date = datetime.now()
+
+        # Defensive: cover downstream tasks with no TI in the dag run yet.
+        skipped_task_ids = set(ti.task_id for ti in tis)
+        for task in context['task'].downstream_list:
+            if task.task_id in skipped_task_ids:
+                continue
+            logging.warning("Task {} was not part of a dag run. This should not happen."
+                            .format(task))
+            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+            ti.state = State.SKIPPED
+            ti.start_date = ti.end_date = datetime.now()
+            session.merge(ti)
+
+        session.commit()
+        session.close()
+        logging.info("Done.")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index 7fdd18e..d206f16 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -22,6 +22,7 @@ flask-cache
 flask-login==0.2.11
 Flask-WTF
 flower
+freezegun
 future
 gunicorn
 hdfs

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 7a517a1..e6f6830 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -19,3 +19,5 @@ from .sensors import *
 from .hive_operator import *
 from .s3_to_hive_operator import *
 from .python_operator import *
+from .latest_only_operator import *
+

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py
index 3ac5fac..9137491 100644
--- a/tests/operators/latest_only_operator.py
+++ b/tests/operators/latest_only_operator.py
@@ -80,7 +80,7 @@ class LatestOnlyOperatorTest(unittest.TestCase):
         self.assertEqual({
             datetime.datetime(2016, 1, 1): 'success',
             datetime.datetime(2016, 1, 1, 12): 'success',
-            datetime.datetime(2016, 1, 2): 'success', }, 
+            datetime.datetime(2016, 1, 2): 'success', },
             exec_date_to_latest_state)
 
         downstream_instances = get_task_instances('downstream')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eb705fd5/tests/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py
index 621172f..3aa8b6c 100644
--- a/tests/operators/python_operator.py
+++ b/tests/operators/python_operator.py
@@ -18,7 +18,12 @@ import datetime
 import unittest
 
 from airflow import configuration, DAG
-from airflow.operators.python_operator import PythonOperator
+from airflow.models import TaskInstance as TI
+from airflow.operators.python_operator import PythonOperator, 
BranchPythonOperator
+from airflow.operators.python_operator import ShortCircuitOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.settings import Session
+from airflow.utils.state import State
 
 from airflow.exceptions import AirflowException
 
@@ -77,3 +82,163 @@ class PythonOperatorTest(unittest.TestCase):
                 python_callable=not_callable,
                 task_id='python_operator',
                 dag=self.dag)
+
+
+class BranchOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.dag = DAG('branch_operator_test',
+                       default_args={
+                           'owner': 'airflow',
+                           'start_date': DEFAULT_DATE},
+                       schedule_interval=INTERVAL)
+        self.branch_op = BranchPythonOperator(task_id='make_choice',
+                                              dag=self.dag,
+                                              python_callable=lambda: 'branch_1')
+
+        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
+    def test_without_dag_run(self):
+        """Check the fallback for downstream tasks that have no dag run TI."""
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        session = Session()
+        tis = session.query(TI).filter(
+            TI.dag_id == self.dag.dag_id,
+            TI.execution_date == DEFAULT_DATE
+        ).all()
+        session.close()
+
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                # should not exist
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+            elif ti.task_id == 'branch_2':
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+
+    def test_with_dag_run(self):
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=datetime.datetime.now(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEqual(ti.state, State.NONE)
+            elif ti.task_id == 'branch_2':
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+
+
+class ShortCircuitOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.dag = DAG('shortcircuit_operator_test',
+                       default_args={
+                           'owner': 'airflow',
+                           'start_date': DEFAULT_DATE},
+                       schedule_interval=INTERVAL)
+        self.short_op = ShortCircuitOperator(task_id='make_choice',
+                                             dag=self.dag,
+                                             python_callable=lambda: self.value)
+
+        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+        self.branch_1.set_upstream(self.short_op)
+        self.upstream = DummyOperator(task_id='upstream', dag=self.dag)
+        self.upstream.set_downstream(self.short_op)
+        self.dag.clear()
+
+        self.value = True
+
+    def test_without_dag_run(self):
+        """Check the fallback for downstream tasks that have no dag run TI."""
+        self.value = False
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        session = Session()
+        tis = session.query(TI).filter(
+            TI.dag_id == self.dag.dag_id,
+            TI.execution_date == DEFAULT_DATE
+        )
+
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                # should not exist
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+            elif ti.task_id == 'branch_1':
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+
+        self.value = True
+        self.dag.clear()
+
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                # should not exist
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+            elif ti.task_id == 'branch_1':
+                self.assertEqual(ti.state, State.NONE)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+
+        session.close()
+
+    def test_with_dag_run(self):
+        self.value = False
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=datetime.datetime.now(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+        self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEqual(ti.state, State.SKIPPED)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))
+
+        self.value = True
+        self.dag.clear()
+
+        self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                self.assertEqual(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEqual(ti.state, State.NONE)
+            else:
+                self.fail("unexpected task instance: {}".format(ti.task_id))

Reply via email to