[AIRFLOW-869] Refactor mark success functionality. This refactors the mark success functionality into a more generic function that can set multiple states and properly drills down into SubDags.
Closes #2085 from bolkedebruin/AIRFLOW-869 (cherry picked from commit 28cfd2c541c12468b3e4f634545dfa31a77b0091) Signed-off-by: Bolke de Bruin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/563cc9a3 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/563cc9a3 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/563cc9a3 Branch: refs/heads/v1-8-stable Commit: 563cc9a3c8414725a615a93d3910e7a2dbb94999 Parents: eddecd5 Author: Bolke de Bruin <[email protected]> Authored: Fri Feb 17 09:05:41 2017 +0100 Committer: Bolke de Bruin <[email protected]> Committed: Fri Feb 17 09:11:41 2017 +0100 ---------------------------------------------------------------------- airflow/api/common/experimental/mark_tasks.py | 187 ++++++++++++++++++ airflow/jobs.py | 4 +- airflow/models.py | 18 +- airflow/www/templates/airflow/dag.html | 5 - airflow/www/views.py | 119 +++--------- tests/api/__init__.py | 2 + tests/api/common/__init__.py | 13 ++ tests/api/common/mark_tasks.py | 211 +++++++++++++++++++++ tests/core.py | 46 +++-- tests/dags/test_example_bash_operator.py | 55 ++++++ tests/models.py | 2 +- 11 files changed, 536 insertions(+), 126 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/api/common/experimental/mark_tasks.py ---------------------------------------------------------------------- diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py new file mode 100644 index 0000000..0ddbf98 --- /dev/null +++ b/airflow/api/common/experimental/mark_tasks.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime + +from airflow.jobs import BackfillJob +from airflow.models import DagRun, TaskInstance +from airflow.operators.subdag_operator import SubDagOperator +from airflow.settings import Session +from airflow.utils.state import State + +from sqlalchemy import or_ + + +def _create_dagruns(dag, execution_dates, state, run_id_template): + """ + Infers from the dates which dag runs need to be created and does so. + :param dag: the dag to create dag runs for + :param execution_dates: list of execution dates to evaluate + :param state: the state to set the dag run to + :param run_id_template:the template for run id to be with the execution date + :return: newly created and existing dag runs for the execution dates supplied + """ + # find out if we need to create any dag runs + drs = DagRun.find(dag_id=dag.dag_id, execution_date=execution_dates) + dates_to_create = list(set(execution_dates) - set([dr.execution_date for dr in drs])) + + for date in dates_to_create: + dr = dag.create_dagrun( + run_id=run_id_template.format(date.isoformat()), + execution_date=date, + start_date=datetime.datetime.now(), + external_trigger=False, + state=state, + ) + drs.append(dr) + + return drs + + +def set_state(task, execution_date, upstream=False, downstream=False, + future=False, past=False, state=State.SUCCESS, commit=False): + """ + Set the state of a task instance and if needed its relatives. Can set state + for future tasks (calculated from execution_date) and retroactively + for past tasks. 
Will verify integrity of past dag runs in order to create + tasks that did not exist. It will not create dag runs that are missing + on the schedule (but it will as for subdag dag runs if needed). + :param task: the task from which to work. task.task.dag needs to be set + :param execution_date: the execution date from which to start looking + :param upstream: Mark all parents (upstream tasks) + :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags + :param future: Mark all future tasks on the interval of the dag up until + last execution date. + :param past: Retroactively mark all tasks starting from start_date of the DAG + :param state: State to which the tasks need to be set + :param commit: Commit tasks to be altered to the database + :return: list of tasks that have been created and updated + """ + assert isinstance(execution_date, datetime.datetime) + + # microseconds are supported by the database, but is not handled + # correctly by airflow on e.g. the filesystem and in other places + execution_date = execution_date.replace(microsecond=0) + + assert task.dag is not None + dag = task.dag + + latest_execution_date = dag.latest_execution_date + assert latest_execution_date is not None + + # determine date range of dag runs and tasks to consider + end_date = latest_execution_date if future else execution_date + + if 'start_date' in dag.default_args: + start_date = dag.default_args['start_date'] + elif dag.start_date: + start_date = dag.start_date + else: + start_date = execution_date + + start_date = execution_date if not past else start_date + + if dag.schedule_interval == '@once': + dates = [start_date] + else: + dates = dag.date_range(start_date=start_date, end_date=end_date) + + # find relatives (siblings = downstream, parents = upstream) if needed + task_ids = [task.task_id] + if downstream: + relatives = task.get_flat_relatives(upstream=False) + task_ids += [t.task_id for t in relatives] + if upstream: + relatives = 
task.get_flat_relatives(upstream=True) + task_ids += [t.task_id for t in relatives] + + # verify the integrity of the dag runs in case a task was added or removed + # set the confirmed execution dates as they might be different + # from what was provided + confirmed_dates = [] + drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates) + for dr in drs: + dr.dag = dag + dr.verify_integrity() + confirmed_dates.append(dr.execution_date) + + # go through subdagoperators and create dag runs. We will only work + # within the scope of the subdag. We wont propagate to the parent dag, + # but we will propagate from parent to subdag. + session = Session() + dags = [dag] + sub_dag_ids = [] + while len(dags) > 0: + current_dag = dags.pop() + for task_id in task_ids: + if not current_dag.has_task(task_id): + continue + + current_task = current_dag.get_task(task_id) + if isinstance(current_task, SubDagOperator): + # this works as a kind of integrity check + # it creates missing dag runs for subdagoperators, + # maybe this should be moved to dagrun.verify_integrity + drs = _create_dagruns(current_task.subdag, + execution_dates=confirmed_dates, + state=State.RUNNING, + run_id_template=BackfillJob.ID_FORMAT_PREFIX) + + for dr in drs: + dr.dag = current_task.subdag + dr.verify_integrity() + if commit: + dr.state = state + session.merge(dr) + + dags.append(current_task.subdag) + sub_dag_ids.append(current_task.subdag.dag_id) + + # now look for the task instances that are affected + TI = TaskInstance + + # get all tasks of the main dag that will be affected by a state change + qry_dag = session.query(TI).filter( + TI.dag_id==dag.dag_id, + TI.execution_date.in_(confirmed_dates), + TI.task_id.in_(task_ids)).filter( + or_(TI.state.is_(None), + TI.state != state) + ) + + # get *all* tasks of the sub dags + if len(sub_dag_ids) > 0: + qry_sub_dag = session.query(TI).filter( + TI.dag_id.in_(sub_dag_ids), + TI.execution_date.in_(confirmed_dates)).filter( + or_(TI.state.is_(None), + TI.state 
!= state) + ) + + if commit: + tis_altered = qry_dag.with_for_update().all() + if len(sub_dag_ids) > 0: + tis_altered += qry_sub_dag.with_for_update().all() + for ti in tis_altered: + ti.state = state + session.commit() + else: + tis_altered = qry_dag.all() + if len(sub_dag_ids) > 0: + tis_altered += qry_sub_dag.all() + + session.close() + + return tis_altered + http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index 1362814..3ca0070 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -1632,6 +1632,8 @@ class BackfillJob(BaseJob): triggers a set of task instance runs, in the right order and lasts for as long as it takes for the set of task instance to be completed. """ + ID_PREFIX = 'backfill_' + ID_FORMAT_PREFIX = ID_PREFIX + '{0}' __mapper_args__ = { 'polymorphic_identity': 'BackfillJob' @@ -1716,7 +1718,7 @@ class BackfillJob(BaseJob): active_dag_runs = [] while next_run_date and next_run_date <= end_date: - run_id = 'backfill_' + next_run_date.isoformat() + run_id = BackfillJob.ID_FORMAT_PREFIX.format(next_run_date.isoformat()) # check if we are scheduling on top of a already existing dag_run # we could find a "scheduled" run instead of a "backfill" http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index b9af58e..ba8d051 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -2317,6 +2317,7 @@ class BaseOperator(object): qry = qry.filter(TI.task_id.in_(tasks)) count = qry.count() + clear_task_instances(qry, session) session.commit() @@ -2931,13 +2932,11 @@ class DAG(BaseDag, LoggingMixin): @property def latest_execution_date(self): """ - Returns the latest date for which at least one task instance exists + Returns the latest date for which at 
least one dag run exists """ - TI = TaskInstance session = settings.Session() - execution_date = session.query(func.max(TI.execution_date)).filter( - TI.dag_id == self.dag_id, - TI.task_id.in_(self.task_ids) + execution_date = session.query(func.max(DagRun.execution_date)).filter( + DagRun.dag_id == self.dag_id ).scalar() session.commit() session.close() @@ -3330,7 +3329,7 @@ class DAG(BaseDag, LoggingMixin): # add a placeholder row into DagStat table if not session.query(DagStat).filter(DagStat.dag_id == self.dag_id).first(): - session.add(DagStat(dag_id=self.dag_id, state=State.RUNNING, count=0, dirty=True)) + session.add(DagStat(dag_id=self.dag_id, state=state, count=0, dirty=True)) session.commit() return run @@ -3801,6 +3800,8 @@ class DagRun(Base): def set_state(self, state): if self._state != state: self._state = state + # something really weird goes on here: if you try to close the session + # dag runs will end up detached session = settings.Session() DagStat.set_dirty(self.dag_id, session=session) @@ -3859,7 +3860,10 @@ class DagRun(Base): if run_id: qry = qry.filter(DR.run_id == run_id) if execution_date: - qry = qry.filter(DR.execution_date == execution_date) + if isinstance(execution_date, list): + qry = qry.filter(DR.execution_date.in_(execution_date)) + else: + qry = qry.filter(DR.execution_date == execution_date) if state: qry = qry.filter(DR.state == state) if external_trigger: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/www/templates/airflow/dag.html ---------------------------------------------------------------------- diff --git a/airflow/www/templates/airflow/dag.html b/airflow/www/templates/airflow/dag.html index b9b1afa..8a4793d 100644 --- a/airflow/www/templates/airflow/dag.html +++ b/airflow/www/templates/airflow/dag.html @@ -206,10 +206,6 @@ type="button" class="btn" data-toggle="button"> Downstream </button> - <button id="btn_success_recursive" - type="button" class="btn" data-toggle="button"> - Recursive 
- </button> </span> </div> <div class="modal-footer"> @@ -340,7 +336,6 @@ function updateQueryStringParameter(uri, key, value) { "&downstream=" + $('#btn_success_downstream').hasClass('active') + "&future=" + $('#btn_success_future').hasClass('active') + "&past=" + $('#btn_success_past').hasClass('active') + - "&recursive=" + $('#btn_success_recursive').hasClass('active') + "&execution_date=" + execution_date + "&origin=" + encodeURIComponent(window.location); http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/www/views.py ---------------------------------------------------------------------- diff --git a/airflow/www/views.py b/airflow/www/views.py index b80d83e..b98bd74 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -23,7 +23,6 @@ from functools import wraps from datetime import datetime, timedelta import dateutil.parser import copy -from itertools import chain, product import json import inspect @@ -1085,6 +1084,7 @@ class Airflow(BaseView): origin = request.args.get('origin') dag = dagbag.get_dag(dag_id) task = dag.get_task(task_id) + task.dag = dag execution_date = request.args.get('execution_date') execution_date = dateutil.parser.parse(execution_date) @@ -1093,110 +1093,39 @@ class Airflow(BaseView): downstream = request.args.get('downstream') == "true" future = request.args.get('future') == "true" past = request.args.get('past') == "true" - recursive = request.args.get('recursive') == "true" - MAX_PERIODS = 5000 - # Flagging tasks as successful - session = settings.Session() - task_ids = [task_id] - dag_ids = [dag_id] - task_id_to_dag = { - task_id: dag - } - end_date = ((dag.latest_execution_date or datetime.now()) - if future else execution_date) + if not dag: + flash("Cannot find DAG: {}".format(dag_id)) + return redirect(origin) - if 'start_date' in dag.default_args: - start_date = dag.default_args['start_date'] - elif dag.start_date: - start_date = dag.start_date - else: - start_date = execution_date + if not 
task: + flash("Cannot find task {} in DAG {}".format(task_id, dag.dag_id)) + return redirect(origin) - start_date = execution_date if not past else start_date + from airflow.api.common.experimental.mark_tasks import set_state - if recursive: - recurse_tasks(task, task_ids, dag_ids, task_id_to_dag) - - if downstream: - relatives = task.get_flat_relatives(upstream=False) - task_ids += [t.task_id for t in relatives] - if recursive: - recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag) - if upstream: - relatives = task.get_flat_relatives(upstream=False) - task_ids += [t.task_id for t in relatives] - if recursive: - recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag) - TI = models.TaskInstance + if confirmed: + altered = set_state(task=task, execution_date=execution_date, + upstream=upstream, downstream=downstream, + future=future, past=past, state=State.SUCCESS, + commit=True) - if dag.schedule_interval == '@once': - dates = [start_date] - else: - dates = dag.date_range(start_date, end_date=end_date) - - tis = session.query(TI).filter( - TI.dag_id.in_(dag_ids), - TI.execution_date.in_(dates), - TI.task_id.in_(task_ids)).all() - tis_to_change = session.query(TI).filter( - TI.dag_id.in_(dag_ids), - TI.execution_date.in_(dates), - TI.task_id.in_(task_ids), - TI.state != State.SUCCESS).all() - tasks = list(product(task_ids, dates)) - tis_to_create = list( - set(tasks) - - set([(ti.task_id, ti.execution_date) for ti in tis])) - - tis_all_altered = list(chain( - [(ti.task_id, ti.execution_date) for ti in tis_to_change], - tis_to_create)) - - if len(tis_all_altered) > MAX_PERIODS: - flash("Too many tasks at once (>{0})".format( - MAX_PERIODS), 'error') + flash("Marked success on {} task instances".format(len(altered))) return redirect(origin) - if confirmed: - for ti in tis_to_change: - ti.state = State.SUCCESS - session.commit() + to_be_altered = set_state(task=task, execution_date=execution_date, + upstream=upstream, downstream=downstream, + future=future, 
past=past, state=State.SUCCESS, + commit=False) - for task_id, task_execution_date in tis_to_create: - ti = TI( - task=task_id_to_dag[task_id].get_task(task_id), - execution_date=task_execution_date, - state=State.SUCCESS) - session.add(ti) - session.commit() + details = "\n".join([str(t) for t in to_be_altered]) - session.commit() - session.close() - flash("Marked success on {} task instances".format( - len(tis_all_altered))) - - return redirect(origin) - else: - if not tis_all_altered: - flash("No task instances to mark as successful", 'error') - response = redirect(origin) - else: - tis = [] - for task_id, task_execution_date in tis_all_altered: - tis.append(TI( - task=task_id_to_dag[task_id].get_task(task_id), - execution_date=task_execution_date, - state=State.SUCCESS)) - details = "\n".join([str(t) for t in tis]) + response = self.render("airflow/confirm.html", + message=("Here's the list of task instances you are " + "about to mark as successful:"), + details=details) - response = self.render( - 'airflow/confirm.html', - message=( - "Here's the list of task instances you are about " - "to mark as successful:"), - details=details,) - return response + return response @expose('/tree') @login_required http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/__init__.py ---------------------------------------------------------------------- diff --git a/tests/api/__init__.py b/tests/api/__init__.py index 2db97ad..37d59f0 100644 --- a/tests/api/__init__.py +++ b/tests/api/__init__.py @@ -15,3 +15,5 @@ from __future__ import absolute_import from .client import * +from .common import * + http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/common/__init__.py ---------------------------------------------------------------------- diff --git a/tests/api/common/__init__.py b/tests/api/common/__init__.py new file mode 100644 index 0000000..9d7677a --- /dev/null +++ b/tests/api/common/__init__.py @@ -0,0 +1,13 @@ +# -*- 
coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/common/mark_tasks.py ---------------------------------------------------------------------- diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/mark_tasks.py new file mode 100644 index 0000000..e01f3ad --- /dev/null +++ b/tests/api/common/mark_tasks.py @@ -0,0 +1,211 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + +import unittest + +from airflow import models +from airflow.api.common.experimental.mark_tasks import set_state, _create_dagruns +from airflow.settings import Session +from airflow.utils.dates import days_ago +from airflow.utils.state import State + + +DEV_NULL = "/dev/null" + + +class TestMarkTasks(unittest.TestCase): + def setUp(self): + self.dagbag = models.DagBag(include_examples=True) + self.dag1 = self.dagbag.dags['test_example_bash_operator'] + self.dag2 = self.dagbag.dags['example_subdag_operator'] + + self.execution_dates = [days_ago(2), days_ago(1)] + + drs = _create_dagruns(self.dag1, self.execution_dates, + state=State.RUNNING, + run_id_template="scheduled__{}") + for dr in drs: + dr.dag = self.dag1 + dr.verify_integrity() + + drs = _create_dagruns(self.dag2, + [self.dag2.default_args['start_date']], + state=State.RUNNING, + run_id_template="scheduled__{}") + + for dr in drs: + dr.dag = self.dag2 + dr.verify_integrity() + + self.session = Session() + + def snapshot_state(self, dag, execution_dates): + TI = models.TaskInstance + tis = self.session.query(TI).filter( + TI.dag_id==dag.dag_id, + TI.execution_date.in_(execution_dates) + ).all() + + self.session.expunge_all() + + return tis + + def verify_state(self, dag, task_ids, execution_dates, state, old_tis): + TI = models.TaskInstance + + tis = self.session.query(TI).filter( + TI.dag_id==dag.dag_id, + TI.execution_date.in_(execution_dates) + ).all() + + self.assertTrue(len(tis) > 0) + + for ti in tis: + if ti.task_id in task_ids and ti.execution_date in execution_dates: + self.assertEqual(ti.state, state) + else: + for old_ti in old_tis: + if (old_ti.task_id == ti.task_id + and old_ti.execution_date == ti.execution_date): + self.assertEqual(ti.state, old_ti.state) + + def test_mark_tasks_now(self): + # set one task to success but do not commit + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, 
execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=False) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + None, snapshot) + + # set one and only one task to success + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + # set no tasks + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 0) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + # set task to other than success + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.FAILED, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.FAILED, snapshot) + + # dont alter other tasks + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_0") + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 1) + self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_downstream(self): + # test downstream + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + relatives = task.get_flat_relatives(upstream=False) + task_ids = [t.task_id for t in relatives] + 
task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=True, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 3) + self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_upstream(self): + # test upstream + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("run_after_loop") + relatives = task.get_flat_relatives(upstream=True) + task_ids = [t.task_id for t in relatives] + task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=True, downstream=False, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 4) + self.verify_state(self.dag1, task_ids, [self.execution_dates[0]], + State.SUCCESS, snapshot) + + def test_mark_tasks_future(self): + # set one task to success towards end of scheduled dag runs + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=False, future=True, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 2) + self.verify_state(self.dag1, [task.task_id], self.execution_dates, + State.SUCCESS, snapshot) + + def test_mark_tasks_past(self): + # set one task to success towards end of scheduled dag runs + snapshot = self.snapshot_state(self.dag1, self.execution_dates) + task = self.dag1.get_task("runme_1") + altered = set_state(task=task, execution_date=self.execution_dates[1], + upstream=False, downstream=False, future=False, + past=True, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 2) + self.verify_state(self.dag1, [task.task_id], self.execution_dates, + State.SUCCESS, snapshot) + + def test_mark_tasks_subdag(self): + # set one task to 
success towards end of scheduled dag runs + task = self.dag2.get_task("section-1") + relatives = task.get_flat_relatives(upstream=False) + task_ids = [t.task_id for t in relatives] + task_ids.append(task.task_id) + + altered = set_state(task=task, execution_date=self.execution_dates[0], + upstream=False, downstream=True, future=False, + past=False, state=State.SUCCESS, commit=True) + self.assertEqual(len(altered), 14) + + # cannot use snapshot here as that will require drilling down the + # the sub dag tree essentially recreating the same code as in the + # tested logic. + self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], + State.SUCCESS, []) + + def tearDown(self): + self.dag1.clear() + self.dag2.clear() + + # just to make sure we are fully cleaned up + self.session.query(models.DagRun).delete() + self.session.query(models.TaskInstance).delete() + self.session.commit() + + self.session.close() + +if __name__ == '__main__': + unittest.main() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/core.py ---------------------------------------------------------------------- diff --git a/tests/core.py b/tests/core.py index 0f7e41d..e35809d 100644 --- a/tests/core.py +++ b/tests/core.py @@ -1374,6 +1374,7 @@ class CliTests(unittest.TestCase): os.remove('variables1.json') os.remove('variables2.json') + class WebUiTests(unittest.TestCase): def setUp(self): configuration.load_test_config() @@ -1383,11 +1384,26 @@ class WebUiTests(unittest.TestCase): app.config['TESTING'] = True self.app = app.test_client() - self.dagbag = models.DagBag( - dag_folder=DEV_NULL, include_examples=True) + self.dagbag = models.DagBag(include_examples=True) self.dag_bash = self.dagbag.dags['example_bash_operator'] + self.dag_bash2 = self.dagbag.dags['test_example_bash_operator'] + self.sub_dag = self.dagbag.dags['example_subdag_operator'] self.runme_0 = self.dag_bash.get_task('runme_0') + self.dag_bash2.create_dagrun( + 
run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())), + execution_date=DEFAULT_DATE, + start_date=datetime.now(), + state=State.RUNNING + ) + + self.sub_dag.create_dagrun( + run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())), + execution_date=DEFAULT_DATE, + start_date=datetime.now(), + state=State.RUNNING + ) + def test_index(self): response = self.app.get('/', follow_redirects=True) assert "DAGs" in response.data.decode('utf-8') @@ -1470,7 +1486,7 @@ class WebUiTests(unittest.TestCase): assert "example_bash_operator" in response.data.decode('utf-8') url = ( "/admin/airflow/success?task_id=run_this_last&" - "dag_id=example_bash_operator&upstream=false&downstream=false&" + "dag_id=test_example_bash_operator&upstream=false&downstream=false&" "future=false&past=false&execution_date={}&" "origin=/admin".format(DEFAULT_DATE_DS)) response = self.app.get(url) @@ -1478,7 +1494,7 @@ class WebUiTests(unittest.TestCase): response = self.app.get(url + "&confirmed=true") response = self.app.get( '/admin/airflow/clear?task_id=run_this_last&' - 'dag_id=example_bash_operator&future=true&past=false&' + 'dag_id=test_example_bash_operator&future=true&past=false&' 'upstream=true&downstream=false&' 'execution_date={}&' 'origin=/admin'.format(DEFAULT_DATE_DS)) @@ -1486,7 +1502,7 @@ class WebUiTests(unittest.TestCase): url = ( "/admin/airflow/success?task_id=section-1&" "dag_id=example_subdag_operator&upstream=true&downstream=true&" - "recursive=true&future=false&past=false&execution_date={}&" + "future=false&past=false&execution_date={}&" "origin=/admin".format(DEFAULT_DATE_DS)) response = self.app.get(url) assert "Wait a minute" in response.data.decode('utf-8') @@ -1498,7 +1514,7 @@ class WebUiTests(unittest.TestCase): response = self.app.get(url + "&confirmed=true") url = ( "/admin/airflow/clear?task_id=runme_1&" - "dag_id=example_bash_operator&future=false&past=false&" + "dag_id=test_example_bash_operator&future=false&past=false&" 
"upstream=false&downstream=true&" "execution_date={}&" "origin=/admin".format(DEFAULT_DATE_DS)) @@ -1542,23 +1558,19 @@ class WebUiTests(unittest.TestCase): def test_fetch_task_instance(self): url = ( "/admin/airflow/object/task_instances?" - "dag_id=example_bash_operator&" + "dag_id=test_example_bash_operator&" "execution_date={}".format(DEFAULT_DATE_DS)) response = self.app.get(url) - assert "{}" in response.data.decode('utf-8') - - TI = models.TaskInstance - ti = TI( - task=self.runme_0, execution_date=DEFAULT_DATE) - job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True) - job.run() - - response = self.app.get(url) - assert "runme_0" in response.data.decode('utf-8') + self.assertIn("run_this_last", response.data.decode('utf-8')) def tearDown(self): configuration.conf.set("webserver", "expose_config", "False") self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=datetime.now()) + session = Session() + session.query(models.DagRun).delete() + session.query(models.TaskInstance).delete() + session.commit() + session.close() class WebPasswordAuthTest(unittest.TestCase): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/dags/test_example_bash_operator.py ---------------------------------------------------------------------- diff --git a/tests/dags/test_example_bash_operator.py b/tests/dags/test_example_bash_operator.py new file mode 100644 index 0000000..ad03353 --- /dev/null +++ b/tests/dags/test_example_bash_operator.py @@ -0,0 +1,55 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import airflow +from builtins import range +from airflow.operators.bash_operator import BashOperator +from airflow.operators.dummy_operator import DummyOperator +from airflow.models import DAG +from datetime import timedelta + + +args = { + 'owner': 'airflow', + 'start_date': airflow.utils.dates.days_ago(2) +} + +dag = DAG( + dag_id='test_example_bash_operator', default_args=args, + schedule_interval='0 0 * * *', + dagrun_timeout=timedelta(minutes=60)) + +cmd = 'ls -l' +run_this_last = DummyOperator(task_id='run_this_last', dag=dag) + +run_this = BashOperator( + task_id='run_after_loop', bash_command='echo 1', dag=dag) +run_this.set_downstream(run_this_last) + +for i in range(3): + i = str(i) + task = BashOperator( + task_id='runme_'+i, + bash_command='echo "{{ task_instance_key_str }}" && sleep 1', + dag=dag) + task.set_downstream(run_this) + +task = BashOperator( + task_id='also_run_this', + bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"', + dag=dag) +task.set_downstream(run_this_last) + +if __name__ == "__main__": + dag.cli() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/models.py ---------------------------------------------------------------------- diff --git a/tests/models.py b/tests/models.py index 003fb21..868ea36 100644 --- a/tests/models.py +++ b/tests/models.py @@ -188,7 +188,7 @@ class DagBagTest(unittest.TestCase): class TestDagBag(models.DagBag): process_file_calls = 0 def process_file(self, filepath, only_if_updated=True, safe_mode=True): - if 'example_bash_operator.py' in filepath: + if 'example_bash_operator.py' == os.path.basename(filepath): TestDagBag.process_file_calls += 1 super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode)
