Repository: incubator-airflow
Updated Branches:
  refs/heads/master 2090011bb -> eaa03dbc7


[AIRFLOW-1786] Enforce correct behavior for soft-fail sensors

When a soft-fail sensor fails, all of its downstream
tasks are skipped. This change also makes it possible
to set up sensors that are both non-blocking and
soft-fail, in the same way as for regular sensors.

Closes #3509 from artem-kirillov/AIRFLOW-1786


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/eaa03dbc
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/eaa03dbc
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/eaa03dbc

Branch: refs/heads/master
Commit: eaa03dbc7404b1afbba2e232f31c1efb735a6f3b
Parents: 2090011
Author: Artem Kirillov <[email protected]>
Authored: Sun Jun 17 21:47:55 2018 +0200
Committer: Fokko Driesprong <[email protected]>
Committed: Sun Jun 17 21:47:55 2018 +0200

----------------------------------------------------------------------
 airflow/models.py                       |   6 +-
 airflow/sensors/base_sensor_operator.py |  16 ++-
 tests/sensors/test_base_sensor.py       | 160 +++++++++++++++++++++++++++
 3 files changed, 178 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index eda1511..4706c2d 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1742,7 +1742,7 @@ class TaskInstance(Base, LoggingMixin):
             # try_number contains the current try_number (not the next). We
             # only mark task instance as FAILED if the next task instance
             # try_number exceeds the max_tries.
-            if task.retries and self.try_number <= self.max_tries:
+            if self.is_eligible_to_retry():
                 self.state = State.UP_FOR_RETRY
                 self.log.info('Marking task as UP_FOR_RETRY')
                 if task.email_on_retry and task.email:
@@ -1773,6 +1773,10 @@ class TaskInstance(Base, LoggingMixin):
             session.merge(self)
         session.commit()
 
+    def is_eligible_to_retry(self):
+        """Is task instance is eligible for retry"""
+        return self.task.retries and self.try_number <= self.max_tries
+
     @provide_session
     def get_template_context(self, session=None):
         task = self.task

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/airflow/sensors/base_sensor_operator.py
----------------------------------------------------------------------
diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py
index 45e4a52..74b0e0f 100644
--- a/airflow/sensors/base_sensor_operator.py
+++ b/airflow/sensors/base_sensor_operator.py
@@ -22,12 +22,12 @@ from time import sleep
 
 from airflow.exceptions import AirflowException, AirflowSensorTimeout, \
     AirflowSkipException
-from airflow.models import BaseOperator
+from airflow.models import BaseOperator, SkipMixin
 from airflow.utils import timezone
 from airflow.utils.decorators import apply_defaults
 
 
-class BaseSensorOperator(BaseOperator):
+class BaseSensorOperator(BaseOperator, SkipMixin):
     """
     Sensor operators are derived from this class an inherit these attributes.
 
@@ -67,9 +67,19 @@ class BaseSensorOperator(BaseOperator):
         started_at = timezone.utcnow()
         while not self.poke(context):
             if (timezone.utcnow() - started_at).total_seconds() > self.timeout:
-                if self.soft_fail:
+                # If sensor is in soft fail mode but will be retried then
+                # give it a chance and fail with timeout.
+                # This gives the ability to set up non-blocking AND soft-fail sensors.
+                if self.soft_fail and not context['ti'].is_eligible_to_retry():
+                    self._do_skip_downstream_tasks(context)
                     raise AirflowSkipException('Snap. Time is OUT.')
                 else:
                     raise AirflowSensorTimeout('Snap. Time is OUT.')
             sleep(self.poke_interval)
         self.log.info("Success criteria met. Exiting.")
+
+    def _do_skip_downstream_tasks(self, context):
+        downstream_tasks = context['task'].get_flat_relatives(upstream=False)
+        self.log.debug("Downstream task_ids %s", downstream_tasks)
+        if downstream_tasks:
+            self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/tests/sensors/test_base_sensor.py
----------------------------------------------------------------------
diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py
new file mode 100644
index 0000000..adb7a5d
--- /dev/null
+++ b/tests/sensors/test_base_sensor.py
@@ -0,0 +1,160 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+from airflow import DAG, configuration, settings
+from airflow.exceptions import AirflowSensorTimeout
+from airflow.models import DagRun, TaskInstance
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.sensors.base_sensor_operator import BaseSensorOperator
+from airflow.utils import timezone
+from airflow.utils.state import State
+from airflow.utils.timezone import datetime
+from datetime import timedelta
+from time import sleep
+
+configuration.load_test_config()
+
+DEFAULT_DATE = datetime(2015, 1, 1)
+TEST_DAG_ID = 'unit_test_dag'
+DUMMY_OP = 'dummy_op'
+SENSOR_OP = 'sensor_op'
+
+
+class DummySensor(BaseSensorOperator):
+    def __init__(self, return_value=False, **kwargs):
+        super(DummySensor, self).__init__(**kwargs)
+        self.return_value = return_value
+
+    def poke(self, context):
+        return self.return_value
+
+
+class BaseSensorTest(unittest.TestCase):
+    def setUp(self):
+        configuration.load_test_config()
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE
+        }
+        self.dag = DAG(TEST_DAG_ID, default_args=args)
+
+        session = settings.Session()
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
+        session.commit()
+
+    def _make_dag_run(self):
+        return self.dag.create_dagrun(
+            run_id='manual__',
+            start_date=timezone.utcnow(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+    def _make_sensor(self, return_value, **kwargs):
+        poke_interval = 'poke_interval'
+        timeout = 'timeout'
+        if poke_interval not in kwargs:
+            kwargs[poke_interval] = 0
+        if timeout not in kwargs:
+            kwargs[timeout] = 0
+
+        sensor = DummySensor(
+            task_id=SENSOR_OP,
+            return_value=return_value,
+            dag=self.dag,
+            **kwargs
+        )
+
+        dummy_op = DummyOperator(
+            task_id=DUMMY_OP,
+            dag=self.dag
+        )
+        dummy_op.set_upstream(sensor)
+        return sensor
+
+    @classmethod
+    def _run(cls, task):
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    def test_ok(self):
+        sensor = self._make_sensor(True)
+        dr = self._make_dag_run()
+
+        self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.SUCCESS)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_fail(self):
+        sensor = self._make_sensor(False)
+        dr = self._make_dag_run()
+
+        with self.assertRaises(AirflowSensorTimeout):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.FAILED)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+    def test_soft_fail(self):
+        sensor = self._make_sensor(False, soft_fail=True)
+        dr = self._make_dag_run()
+
+        self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            self.assertEquals(ti.state, State.SKIPPED)
+
+    def test_soft_fail_with_retries(self):
+        sensor = self._make_sensor(
+            return_value=False,
+            soft_fail=True,
+            retries=1,
+            retry_delay=timedelta(milliseconds=1))
+        dr = self._make_dag_run()
+
+        # first run fails and task instance is marked up to retry
+        with self.assertRaises(AirflowSensorTimeout):
+            self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            if ti.task_id == SENSOR_OP:
+                self.assertEquals(ti.state, State.UP_FOR_RETRY)
+            if ti.task_id == DUMMY_OP:
+                self.assertEquals(ti.state, State.NONE)
+
+        sleep(0.001)
+        # after retry DAG run is skipped
+        self._run(sensor)
+        tis = dr.get_task_instances()
+        self.assertEquals(len(tis), 2)
+        for ti in tis:
+            self.assertEquals(ti.state, State.SKIPPED)

Reply via email to