Repository: incubator-airflow Updated Branches: refs/heads/master 2090011bb -> eaa03dbc7
[AIRFLOW-1786] Enforce correct behavior for soft-fail sensors Soft-fail sensor failure causes all downstream tasks to be skipped. It also enables the ability to set up non-blocking and soft-fail sensors in the same way as for regular sensors. Closes #3509 from artem-kirillov/AIRFLOW-1786 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/eaa03dbc Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/eaa03dbc Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/eaa03dbc Branch: refs/heads/master Commit: eaa03dbc7404b1afbba2e232f31c1efb735a6f3b Parents: 2090011 Author: Artem Kirillov <[email protected]> Authored: Sun Jun 17 21:47:55 2018 +0200 Committer: Fokko Driesprong <[email protected]> Committed: Sun Jun 17 21:47:55 2018 +0200 ---------------------------------------------------------------------- airflow/models.py | 6 +- airflow/sensors/base_sensor_operator.py | 16 ++- tests/sensors/test_base_sensor.py | 160 +++++++++++++++++++++++++++ 3 files changed, 178 insertions(+), 4 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index eda1511..4706c2d 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -1742,7 +1742,7 @@ class TaskInstance(Base, LoggingMixin): # try_number contains the current try_number (not the next). We # only mark task instance as FAILED if the next task instance # try_number exceeds the max_tries. 
- if task.retries and self.try_number <= self.max_tries: + if self.is_eligible_to_retry(): self.state = State.UP_FOR_RETRY self.log.info('Marking task as UP_FOR_RETRY') if task.email_on_retry and task.email: @@ -1773,6 +1773,10 @@ class TaskInstance(Base, LoggingMixin): session.merge(self) session.commit() + def is_eligible_to_retry(self): + """Is task instance is eligible for retry""" + return self.task.retries and self.try_number <= self.max_tries + @provide_session def get_template_context(self, session=None): task = self.task http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/airflow/sensors/base_sensor_operator.py ---------------------------------------------------------------------- diff --git a/airflow/sensors/base_sensor_operator.py b/airflow/sensors/base_sensor_operator.py index 45e4a52..74b0e0f 100644 --- a/airflow/sensors/base_sensor_operator.py +++ b/airflow/sensors/base_sensor_operator.py @@ -22,12 +22,12 @@ from time import sleep from airflow.exceptions import AirflowException, AirflowSensorTimeout, \ AirflowSkipException -from airflow.models import BaseOperator +from airflow.models import BaseOperator, SkipMixin from airflow.utils import timezone from airflow.utils.decorators import apply_defaults -class BaseSensorOperator(BaseOperator): +class BaseSensorOperator(BaseOperator, SkipMixin): """ Sensor operators are derived from this class an inherit these attributes. @@ -67,9 +67,19 @@ class BaseSensorOperator(BaseOperator): started_at = timezone.utcnow() while not self.poke(context): if (timezone.utcnow() - started_at).total_seconds() > self.timeout: - if self.soft_fail: + # If sensor is in soft fail mode but will be retried then + # give it a chance and fail with timeout. + # This gives the ability to set up non-blocking AND soft-fail sensors. + if self.soft_fail and not context['ti'].is_eligible_to_retry(): + self._do_skip_downstream_tasks(context) raise AirflowSkipException('Snap. 
Time is OUT.') else: raise AirflowSensorTimeout('Snap. Time is OUT.') sleep(self.poke_interval) self.log.info("Success criteria met. Exiting.") + + def _do_skip_downstream_tasks(self, context): + downstream_tasks = context['task'].get_flat_relatives(upstream=False) + self.log.debug("Downstream task_ids %s", downstream_tasks) + if downstream_tasks: + self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/eaa03dbc/tests/sensors/test_base_sensor.py ---------------------------------------------------------------------- diff --git a/tests/sensors/test_base_sensor.py b/tests/sensors/test_base_sensor.py new file mode 100644 index 0000000..adb7a5d --- /dev/null +++ b/tests/sensors/test_base_sensor.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import unittest + +from airflow import DAG, configuration, settings +from airflow.exceptions import AirflowSensorTimeout +from airflow.models import DagRun, TaskInstance +from airflow.operators.dummy_operator import DummyOperator +from airflow.sensors.base_sensor_operator import BaseSensorOperator +from airflow.utils import timezone +from airflow.utils.state import State +from airflow.utils.timezone import datetime +from datetime import timedelta +from time import sleep + +configuration.load_test_config() + +DEFAULT_DATE = datetime(2015, 1, 1) +TEST_DAG_ID = 'unit_test_dag' +DUMMY_OP = 'dummy_op' +SENSOR_OP = 'sensor_op' + + +class DummySensor(BaseSensorOperator): + def __init__(self, return_value=False, **kwargs): + super(DummySensor, self).__init__(**kwargs) + self.return_value = return_value + + def poke(self, context): + return self.return_value + + +class BaseSensorTest(unittest.TestCase): + def setUp(self): + configuration.load_test_config() + args = { + 'owner': 'airflow', + 'start_date': DEFAULT_DATE + } + self.dag = DAG(TEST_DAG_ID, default_args=args) + + session = settings.Session() + session.query(DagRun).delete() + session.query(TaskInstance).delete() + session.commit() + + def _make_dag_run(self): + return self.dag.create_dagrun( + run_id='manual__', + start_date=timezone.utcnow(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + def _make_sensor(self, return_value, **kwargs): + poke_interval = 'poke_interval' + timeout = 'timeout' + if poke_interval not in kwargs: + kwargs[poke_interval] = 0 + if timeout not in kwargs: + kwargs[timeout] = 0 + + sensor = DummySensor( + task_id=SENSOR_OP, + return_value=return_value, + dag=self.dag, + **kwargs + ) + + dummy_op = DummyOperator( + task_id=DUMMY_OP, + dag=self.dag + ) + dummy_op.set_upstream(sensor) + return sensor + + @classmethod + def _run(cls, task): + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) + + def test_ok(self): + sensor = self._make_sensor(True) 
+ dr = self._make_dag_run() + + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.SUCCESS) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_fail(self): + sensor = self._make_sensor(False) + dr = self._make_dag_run() + + with self.assertRaises(AirflowSensorTimeout): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.FAILED) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + def test_soft_fail(self): + sensor = self._make_sensor(False, soft_fail=True) + dr = self._make_dag_run() + + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + self.assertEquals(ti.state, State.SKIPPED) + + def test_soft_fail_with_retries(self): + sensor = self._make_sensor( + return_value=False, + soft_fail=True, + retries=1, + retry_delay=timedelta(milliseconds=1)) + dr = self._make_dag_run() + + # first run fails and task instance is marked up to retry + with self.assertRaises(AirflowSensorTimeout): + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + if ti.task_id == SENSOR_OP: + self.assertEquals(ti.state, State.UP_FOR_RETRY) + if ti.task_id == DUMMY_OP: + self.assertEquals(ti.state, State.NONE) + + sleep(0.001) + # after retry DAG run is skipped + self._run(sensor) + tis = dr.get_task_instances() + self.assertEquals(len(tis), 2) + for ti in tis: + self.assertEquals(ti.state, State.SKIPPED)
