Repository: incubator-airflow
Updated Branches:
  refs/heads/master ea240cd1d -> a45e2d188
[AIRFLOW-1296] Propagate SKIPPED to all downstream tasks The ShortCircuitOperator and LatestOnlyOperator did not mark all downstream tasks as skipped, but only direct downstream tasks. Closes #2365 from bolkedebruin/AIRFLOW-719-3 Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/a45e2d18 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/a45e2d18 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/a45e2d18 Branch: refs/heads/master Commit: a45e2d1888ffb19dab8401e07b10724090bf20f0 Parents: ea240cd Author: Bolke de Bruin <[email protected]> Authored: Wed Jun 21 10:12:09 2017 +0200 Committer: Bolke de Bruin <[email protected]> Committed: Wed Jun 21 10:12:09 2017 +0200 ---------------------------------------------------------------------- airflow/models.py | 41 ++++++++++++ airflow/operators/latest_only_operator.py | 43 +++---------- airflow/operators/python_operator.py | 88 ++++++-------------------- tests/operators/latest_only_operator.py | 82 ++++++++++++++++++++++++ tests/operators/python_operator.py | 16 +++-- 5 files changed, 161 insertions(+), 109 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index c628958..2c433ad 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -1802,6 +1802,47 @@ class Log(Base): self.owner = owner or task_owner +class SkipMixin(object): + def skip(self, dag_run, execution_date, tasks): + """ + Sets tasks instances to skipped from the same dag run. 
+ :param dag_run: the DagRun for which to set the tasks to skipped + :param execution_date: execution_date + :param tasks: tasks to skip (not task_ids) + """ + if not tasks: + return + + task_ids = [d.task_id for d in tasks] + now = datetime.now() + session = settings.Session() + + if dag_run: + session.query(TaskInstance).filter( + TaskInstance.dag_id == dag_run.dag_id, + TaskInstance.execution_date == dag_run.execution_date, + TaskInstance.task_id.in_(task_ids) + ).update({TaskInstance.state : State.SKIPPED, + TaskInstance.start_date: now, + TaskInstance.end_date: now}, + synchronize_session=False) + session.commit() + else: + assert execution_date is not None, "Execution date is None and no dag run" + + logging.warning("No DAG RUN present this should not happen") + # this is defensive against dag runs that are not complete + for task in tasks: + ti = TaskInstance(task, execution_date=execution_date) + ti.state = State.SKIPPED + ti.start_date = now + ti.end_date = now + session.merge(ti) + + session.commit() + session.close() + + @functools.total_ordering class BaseOperator(object): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py index f1d8085..909a211 100644 --- a/airflow/operators/latest_only_operator.py +++ b/airflow/operators/latest_only_operator.py @@ -15,12 +15,10 @@ import datetime import logging -from airflow.models import BaseOperator, TaskInstance -from airflow.utils.state import State -from airflow import settings +from airflow.models import BaseOperator, SkipMixin -class LatestOnlyOperator(BaseOperator): +class LatestOnlyOperator(BaseOperator, SkipMixin): """ Allows a workflow to skip tasks that are not running during the most recent schedule interval. 
@@ -49,39 +47,14 @@ class LatestOnlyOperator(BaseOperator): if not left_window < now <= right_window: logging.info('Not latest execution, skipping downstream.') - downstream_task_ids = context['task'].downstream_task_ids - if downstream_task_ids: - session = settings.Session() - TI = TaskInstance - tis = session.query(TI).filter( - TI.execution_date == context['ti'].execution_date, - TI.task_id.in_(downstream_task_ids) - ).with_for_update().all() - for ti in tis: - logging.info('Skipping task: %s', ti.task_id) - ti.state = State.SKIPPED - ti.start_date = now - ti.end_date = now - session.merge(ti) + downstream_tasks = context['task'].get_flat_relatives(upstream=False) + logging.debug("Downstream task_ids {}".format(downstream_tasks)) - # this is defensive against dag runs that are not complete - for task in context['task'].downstream_list: - if task.task_id in tis: - continue - - logging.warning("Task {} was not part of a dag run. " - "This should not happen." - .format(task)) - now = datetime.datetime.now() - ti = TaskInstance(task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = now - ti.end_date = now - session.merge(ti) - - session.commit() - session.close() + if downstream_tasks: + self.skip(context['dag_run'], + context['ti'].execution_date, + downstream_tasks) logging.info('Done.') else: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py index cf240f2..bef9bb0 100644 --- a/airflow/operators/python_operator.py +++ b/airflow/operators/python_operator.py @@ -13,14 +13,11 @@ # limitations under the License. 
from builtins import str -from datetime import datetime import logging from airflow.exceptions import AirflowException -from airflow.models import BaseOperator, TaskInstance -from airflow.utils.state import State +from airflow.models import BaseOperator, SkipMixin from airflow.utils.decorators import apply_defaults -from airflow import settings class PythonOperator(BaseOperator): @@ -85,7 +82,7 @@ class PythonOperator(BaseOperator): return return_value -class BranchPythonOperator(PythonOperator): +class BranchPythonOperator(PythonOperator, SkipMixin): """ Allows a workflow to "branch" or follow a single path following the execution of this task. @@ -106,45 +103,20 @@ class BranchPythonOperator(PythonOperator): """ def execute(self, context): branch = super(BranchPythonOperator, self).execute(context) - logging.info("Following branch " + branch) + logging.info("Following branch {}".format(branch)) logging.info("Marking other directly downstream tasks as skipped") - session = settings.Session() - - TI = TaskInstance - tis = session.query(TI).filter( - TI.execution_date == context['ti'].execution_date, - TI.task_id.in_(context['task'].downstream_task_ids), - TI.task_id != branch, - ).with_for_update().all() - - for ti in tis: - logging.info('Skipping task: %s', ti.task_id) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - - # this is defensive against dag runs that are not complete - for task in context['task'].downstream_list: - if task.task_id in tis: - continue - - if task.task_id == branch: - continue - - logging.warning("Task {} was not part of a dag run. This should not happen." 
- .format(task)) - ti = TaskInstance(task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) - - session.commit() - session.close() + + downstream_tasks = context['task'].downstream_list + logging.debug("Downstream task_ids {}".format(downstream_tasks)) + + skip_tasks = [t for t in downstream_tasks if t.task_id != branch] + if downstream_tasks: + self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks) + logging.info("Done.") -class ShortCircuitOperator(PythonOperator): +class ShortCircuitOperator(PythonOperator, SkipMixin): """ Allows a workflow to continue only if a condition is met. Otherwise, the workflow "short-circuits" and downstream tasks are skipped. @@ -165,33 +137,11 @@ class ShortCircuitOperator(PythonOperator): return logging.info('Skipping downstream tasks...') - session = settings.Session() - - TI = TaskInstance - tis = session.query(TI).filter( - TI.execution_date == context['ti'].execution_date, - TI.task_id.in_(context['task'].downstream_task_ids), - ).with_for_update().all() - - for ti in tis: - logging.info('Skipping task: %s', ti.task_id) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - - # this is defensive against dag runs that are not complete - for task in context['task'].downstream_list: - if task.task_id in tis: - continue - - logging.warning("Task {} was not part of a dag run. This should not happen." 
- .format(task)) - ti = TaskInstance(task, execution_date=context['ti'].execution_date) - ti.state = State.SKIPPED - ti.start_date = datetime.now() - ti.end_date = datetime.now() - session.merge(ti) - - session.commit() - session.close() + + downstream_tasks = context['task'].get_flat_relatives(upstream=False) + logging.debug("Downstream task_ids {}".format(downstream_tasks)) + + if downstream_tasks: + self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks) + logging.info("Done.") http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/tests/operators/latest_only_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py index 9137491..225d24f 100644 --- a/tests/operators/latest_only_operator.py +++ b/tests/operators/latest_only_operator.py @@ -23,6 +23,7 @@ from airflow.jobs import BackfillJob from airflow.models import TaskInstance from airflow.operators.latest_only_operator import LatestOnlyOperator from airflow.operators.dummy_operator import DummyOperator +from airflow.utils.state import State from freezegun import freeze_time DEFAULT_DATE = datetime.datetime(2016, 1, 1) @@ -69,10 +70,82 @@ class LatestOnlyOperatorTest(unittest.TestCase): downstream_task = DummyOperator( task_id='downstream', dag=self.dag) + downstream_task2 = DummyOperator( + task_id='downstream_2', + dag=self.dag) + + downstream_task.set_upstream(latest_task) + downstream_task2.set_upstream(downstream_task) + + latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) + + latest_instances = get_task_instances('latest') + exec_date_to_latest_state = { + ti.execution_date: ti.state for ti in latest_instances} + self.assertEqual({ + datetime.datetime(2016, 1, 1): 'success', + datetime.datetime(2016, 1, 1, 
12): 'success', + datetime.datetime(2016, 1, 2): 'success', }, + exec_date_to_latest_state) + + downstream_instances = get_task_instances('downstream') + exec_date_to_downstream_state = { + ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual({ + datetime.datetime(2016, 1, 1): 'skipped', + datetime.datetime(2016, 1, 1, 12): 'skipped', + datetime.datetime(2016, 1, 2): 'success',}, + exec_date_to_downstream_state) + + downstream_instances = get_task_instances('downstream_2') + exec_date_to_downstream_state = { + ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual({ + datetime.datetime(2016, 1, 1): 'skipped', + datetime.datetime(2016, 1, 1, 12): 'skipped', + datetime.datetime(2016, 1, 2): 'success',}, + exec_date_to_downstream_state) + + def test_skipping_dagrun(self): + latest_task = LatestOnlyOperator( + task_id='latest', + dag=self.dag) + downstream_task = DummyOperator( + task_id='downstream', + dag=self.dag) + downstream_task2 = DummyOperator( + task_id='downstream_2', + dag=self.dag) + downstream_task.set_upstream(latest_task) + downstream_task2.set_upstream(downstream_task) + + dr1 = self.dag.create_dagrun( + run_id="manual__1", + start_date=datetime.datetime.now(), + execution_date=DEFAULT_DATE, + state=State.RUNNING + ) + + dr2 = self.dag.create_dagrun( + run_id="manual__2", + start_date=datetime.datetime.now(), + execution_date=datetime.datetime(2016, 1, 1, 12), + state=State.RUNNING + ) + + dr2 = self.dag.create_dagrun( + run_id="manual__3", + start_date=datetime.datetime.now(), + execution_date=END_DATE, + state=State.RUNNING + ) latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE) + downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE) latest_instances = get_task_instances('latest') exec_date_to_latest_state = { @@ -91,3 +164,12 @@ class LatestOnlyOperatorTest(unittest.TestCase): datetime.datetime(2016, 1, 1, 12): 
'skipped', datetime.datetime(2016, 1, 2): 'success',}, exec_date_to_downstream_state) + + downstream_instances = get_task_instances('downstream_2') + exec_date_to_downstream_state = { + ti.execution_date: ti.state for ti in downstream_instances} + self.assertEqual({ + datetime.datetime(2016, 1, 1): 'skipped', + datetime.datetime(2016, 1, 1, 12): 'skipped', + datetime.datetime(2016, 1, 2): 'success',}, + exec_date_to_downstream_state) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/tests/operators/python_operator.py ---------------------------------------------------------------------- diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py index 3aa8b6c..71432af 100644 --- a/tests/operators/python_operator.py +++ b/tests/operators/python_operator.py @@ -26,6 +26,7 @@ from airflow.settings import Session from airflow.utils.state import State from airflow.exceptions import AirflowException +import logging DEFAULT_DATE = datetime.datetime(2016, 1, 1) END_DATE = datetime.datetime(2016, 1, 2) @@ -158,6 +159,8 @@ class ShortCircuitOperatorTest(unittest.TestCase): self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag) self.branch_1.set_upstream(self.short_op) + self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag) + self.branch_2.set_upstream(self.branch_1) self.upstream = DummyOperator(task_id='upstream', dag=self.dag) self.upstream.set_downstream(self.short_op) self.dag.clear() @@ -181,7 +184,7 @@ class ShortCircuitOperatorTest(unittest.TestCase): elif ti.task_id == 'upstream': # should not exist raise - elif ti.task_id == 'branch_1': + elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': self.assertEquals(ti.state, State.SKIPPED) else: raise @@ -196,7 +199,7 @@ class ShortCircuitOperatorTest(unittest.TestCase): elif ti.task_id == 'upstream': # should not exist raise - elif ti.task_id == 'branch_1': + elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': self.assertEquals(ti.state, 
State.NONE) else: raise @@ -205,6 +208,7 @@ class ShortCircuitOperatorTest(unittest.TestCase): def test_with_dag_run(self): self.value = False + logging.error("Tasks {}".format(self.dag.tasks)) dr = self.dag.create_dagrun( run_id="manual__", start_date=datetime.datetime.now(), @@ -216,29 +220,31 @@ class ShortCircuitOperatorTest(unittest.TestCase): self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() + self.assertEqual(len(tis), 4) for ti in tis: if ti.task_id == 'make_choice': self.assertEquals(ti.state, State.SUCCESS) elif ti.task_id == 'upstream': self.assertEquals(ti.state, State.SUCCESS) - elif ti.task_id == 'branch_1': + elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': self.assertEquals(ti.state, State.SKIPPED) else: raise self.value = True self.dag.clear() - + dr.verify_integrity() self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE) tis = dr.get_task_instances() + self.assertEqual(len(tis), 4) for ti in tis: if ti.task_id == 'make_choice': self.assertEquals(ti.state, State.SUCCESS) elif ti.task_id == 'upstream': self.assertEquals(ti.state, State.SUCCESS) - elif ti.task_id == 'branch_1': + elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2': self.assertEquals(ti.state, State.NONE) else: raise
