[AIRFLOW-869] Refactor mark success functionality

This refactors the mark success functionality into a
more generic function that can set multiple states
and properly drills down on SubDags.

Closes #2085 from bolkedebruin/AIRFLOW-869

(cherry picked from commit 28cfd2c541c12468b3e4f634545dfa31a77b0091)
Signed-off-by: Bolke de Bruin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/563cc9a3
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/563cc9a3
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/563cc9a3

Branch: refs/heads/v1-8-stable
Commit: 563cc9a3c8414725a615a93d3910e7a2dbb94999
Parents: eddecd5
Author: Bolke de Bruin <[email protected]>
Authored: Fri Feb 17 09:05:41 2017 +0100
Committer: Bolke de Bruin <[email protected]>
Committed: Fri Feb 17 09:11:41 2017 +0100

----------------------------------------------------------------------
 airflow/api/common/experimental/mark_tasks.py | 187 ++++++++++++++++++
 airflow/jobs.py                               |   4 +-
 airflow/models.py                             |  18 +-
 airflow/www/templates/airflow/dag.html        |   5 -
 airflow/www/views.py                          | 119 +++---------
 tests/api/__init__.py                         |   2 +
 tests/api/common/__init__.py                  |  13 ++
 tests/api/common/mark_tasks.py                | 211 +++++++++++++++++++++
 tests/core.py                                 |  46 +++--
 tests/dags/test_example_bash_operator.py      |  55 ++++++
 tests/models.py                               |   2 +-
 11 files changed, 536 insertions(+), 126 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/api/common/experimental/mark_tasks.py
----------------------------------------------------------------------
diff --git a/airflow/api/common/experimental/mark_tasks.py 
b/airflow/api/common/experimental/mark_tasks.py
new file mode 100644
index 0000000..0ddbf98
--- /dev/null
+++ b/airflow/api/common/experimental/mark_tasks.py
@@ -0,0 +1,187 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+
+from airflow.jobs import BackfillJob
+from airflow.models import DagRun, TaskInstance
+from airflow.operators.subdag_operator import SubDagOperator
+from airflow.settings import Session
+from airflow.utils.state import State
+
+from sqlalchemy import or_
+
+
+def _create_dagruns(dag, execution_dates, state, run_id_template):
+    """
+    Infers from the dates which dag runs need to be created and does so.
+    :param dag: the dag to create dag runs for
+    :param execution_dates: list of execution dates to evaluate
+    :param state: the state to set the dag run to
+    :param run_id_template: the template for the run id, formatted with the execution 
date
+    :return: newly created and existing dag runs for the execution dates 
supplied
+    """
+    # find out if we need to create any dag runs
+    drs = DagRun.find(dag_id=dag.dag_id, execution_date=execution_dates)
+    dates_to_create = list(set(execution_dates) - set([dr.execution_date for 
dr in drs]))
+
+    for date in dates_to_create:
+        dr = dag.create_dagrun(
+            run_id=run_id_template.format(date.isoformat()),
+            execution_date=date,
+            start_date=datetime.datetime.now(),
+            external_trigger=False,
+            state=state,
+        )
+        drs.append(dr)
+
+    return drs
+
+
+def set_state(task, execution_date, upstream=False, downstream=False,
+              future=False, past=False, state=State.SUCCESS, commit=False):
+    """
+    Set the state of a task instance and if needed its relatives. Can set state
+    for future tasks (calculated from execution_date) and retroactively
+    for past tasks. Will verify integrity of past dag runs in order to create
+    tasks that did not exist. It will not create dag runs that are missing
+    on the schedule (but it will do so for subdag dag runs if needed).
+    :param task: the task from which to work. task.dag needs to be set
+    :param execution_date: the execution date from which to start looking
+    :param upstream: Mark all parents (upstream tasks)
+    :param downstream: Mark all siblings (downstream tasks) of task_id, 
including SubDags
+    :param future: Mark all future tasks on the interval of the dag up until
+        last execution date.
+    :param past: Retroactively mark all tasks starting from start_date of the 
DAG
+    :param state: State to which the tasks need to be set
+    :param commit: Commit tasks to be altered to the database
+    :return: list of tasks that have been created and updated
+    """
+    assert isinstance(execution_date, datetime.datetime)
+
+    # microseconds are supported by the database, but are not handled
+    # correctly by airflow on e.g. the filesystem and in other places
+    execution_date = execution_date.replace(microsecond=0)
+
+    assert task.dag is not None
+    dag = task.dag
+
+    latest_execution_date = dag.latest_execution_date
+    assert latest_execution_date is not None
+
+    # determine date range of dag runs and tasks to consider
+    end_date = latest_execution_date if future else execution_date
+
+    if 'start_date' in dag.default_args:
+        start_date = dag.default_args['start_date']
+    elif dag.start_date:
+        start_date = dag.start_date
+    else:
+        start_date = execution_date
+
+    start_date = execution_date if not past else start_date
+
+    if dag.schedule_interval == '@once':
+        dates = [start_date]
+    else:
+        dates = dag.date_range(start_date=start_date, end_date=end_date)
+
+    # find relatives (siblings = downstream, parents = upstream) if needed
+    task_ids = [task.task_id]
+    if downstream:
+        relatives = task.get_flat_relatives(upstream=False)
+        task_ids += [t.task_id for t in relatives]
+    if upstream:
+        relatives = task.get_flat_relatives(upstream=True)
+        task_ids += [t.task_id for t in relatives]
+
+    # verify the integrity of the dag runs in case a task was added or removed
+    # set the confirmed execution dates as they might be different
+    # from what was provided
+    confirmed_dates = []
+    drs = DagRun.find(dag_id=dag.dag_id, execution_date=dates)
+    for dr in drs:
+        dr.dag = dag
+        dr.verify_integrity()
+        confirmed_dates.append(dr.execution_date)
+
+    # go through subdagoperators and create dag runs. We will only work
+    # within the scope of the subdag. We won't propagate to the parent dag,
+    # but we will propagate from parent to subdag.
+    session = Session()
+    dags = [dag]
+    sub_dag_ids = []
+    while len(dags) > 0:
+        current_dag = dags.pop()
+        for task_id in task_ids:
+            if not current_dag.has_task(task_id):
+                continue
+
+            current_task = current_dag.get_task(task_id)
+            if isinstance(current_task, SubDagOperator):
+                # this works as a kind of integrity check
+                # it creates missing dag runs for subdagoperators,
+                # maybe this should be moved to dagrun.verify_integrity
+                drs = _create_dagruns(current_task.subdag,
+                                      execution_dates=confirmed_dates,
+                                      state=State.RUNNING,
+                                      
run_id_template=BackfillJob.ID_FORMAT_PREFIX)
+
+                for dr in drs:
+                    dr.dag = current_task.subdag
+                    dr.verify_integrity()
+                    if commit:
+                        dr.state = state
+                        session.merge(dr)
+
+                dags.append(current_task.subdag)
+                sub_dag_ids.append(current_task.subdag.dag_id)
+
+    # now look for the task instances that are affected
+    TI = TaskInstance
+
+    # get all tasks of the main dag that will be affected by a state change
+    qry_dag = session.query(TI).filter(
+        TI.dag_id==dag.dag_id,
+        TI.execution_date.in_(confirmed_dates),
+        TI.task_id.in_(task_ids)).filter(
+        or_(TI.state.is_(None),
+            TI.state != state)
+    )
+
+    # get *all* tasks of the sub dags
+    if len(sub_dag_ids) > 0:
+        qry_sub_dag = session.query(TI).filter(
+            TI.dag_id.in_(sub_dag_ids),
+            TI.execution_date.in_(confirmed_dates)).filter(
+            or_(TI.state.is_(None),
+                TI.state != state)
+        )
+
+    if commit:
+        tis_altered = qry_dag.with_for_update().all()
+        if len(sub_dag_ids) > 0:
+            tis_altered += qry_sub_dag.with_for_update().all()
+        for ti in tis_altered:
+            ti.state = state
+        session.commit()
+    else:
+        tis_altered = qry_dag.all()
+        if len(sub_dag_ids) > 0:
+            tis_altered += qry_sub_dag.all()
+
+    session.close()
+
+    return tis_altered
+

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 1362814..3ca0070 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1632,6 +1632,8 @@ class BackfillJob(BaseJob):
     triggers a set of task instance runs, in the right order and lasts for
     as long as it takes for the set of task instance to be completed.
     """
+    ID_PREFIX = 'backfill_'
+    ID_FORMAT_PREFIX = ID_PREFIX + '{0}'
 
     __mapper_args__ = {
         'polymorphic_identity': 'BackfillJob'
@@ -1716,7 +1718,7 @@ class BackfillJob(BaseJob):
 
         active_dag_runs = []
         while next_run_date and next_run_date <= end_date:
-            run_id = 'backfill_' + next_run_date.isoformat()
+            run_id = 
BackfillJob.ID_FORMAT_PREFIX.format(next_run_date.isoformat())
 
             # check if we are scheduling on top of a already existing dag_run
             # we could find a "scheduled" run instead of a "backfill"

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index b9af58e..ba8d051 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -2317,6 +2317,7 @@ class BaseOperator(object):
         qry = qry.filter(TI.task_id.in_(tasks))
 
         count = qry.count()
+
         clear_task_instances(qry, session)
 
         session.commit()
@@ -2931,13 +2932,11 @@ class DAG(BaseDag, LoggingMixin):
     @property
     def latest_execution_date(self):
         """
-        Returns the latest date for which at least one task instance exists
+        Returns the latest date for which at least one dag run exists
         """
-        TI = TaskInstance
         session = settings.Session()
-        execution_date = session.query(func.max(TI.execution_date)).filter(
-            TI.dag_id == self.dag_id,
-            TI.task_id.in_(self.task_ids)
+        execution_date = session.query(func.max(DagRun.execution_date)).filter(
+            DagRun.dag_id == self.dag_id
         ).scalar()
         session.commit()
         session.close()
@@ -3330,7 +3329,7 @@ class DAG(BaseDag, LoggingMixin):
 
         # add a placeholder row into DagStat table
         if not session.query(DagStat).filter(DagStat.dag_id == 
self.dag_id).first():
-            session.add(DagStat(dag_id=self.dag_id, state=State.RUNNING, 
count=0, dirty=True))
+            session.add(DagStat(dag_id=self.dag_id, state=state, count=0, 
dirty=True))
         session.commit()
         return run
 
@@ -3801,6 +3800,8 @@ class DagRun(Base):
     def set_state(self, state):
         if self._state != state:
             self._state = state
+            # something really weird goes on here: if you try to close the 
session
+            # dag runs will end up detached
             session = settings.Session()
             DagStat.set_dirty(self.dag_id, session=session)
 
@@ -3859,7 +3860,10 @@ class DagRun(Base):
         if run_id:
             qry = qry.filter(DR.run_id == run_id)
         if execution_date:
-            qry = qry.filter(DR.execution_date == execution_date)
+            if isinstance(execution_date, list):
+                qry = qry.filter(DR.execution_date.in_(execution_date))
+            else:
+                qry = qry.filter(DR.execution_date == execution_date)
         if state:
             qry = qry.filter(DR.state == state)
         if external_trigger:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/www/templates/airflow/dag.html
----------------------------------------------------------------------
diff --git a/airflow/www/templates/airflow/dag.html 
b/airflow/www/templates/airflow/dag.html
index b9b1afa..8a4793d 100644
--- a/airflow/www/templates/airflow/dag.html
+++ b/airflow/www/templates/airflow/dag.html
@@ -206,10 +206,6 @@
               type="button" class="btn" data-toggle="button">
               Downstream
             </button>
-            <button id="btn_success_recursive"
-              type="button" class="btn" data-toggle="button">
-              Recursive
-            </button>
           </span>
         </div>
         <div class="modal-footer">
@@ -340,7 +336,6 @@ function updateQueryStringParameter(uri, key, value) {
         "&downstream=" + $('#btn_success_downstream').hasClass('active') +
         "&future=" + $('#btn_success_future').hasClass('active') +
         "&past=" + $('#btn_success_past').hasClass('active') +
-        "&recursive=" + $('#btn_success_recursive').hasClass('active') +
         "&execution_date=" + execution_date +
         "&origin=" + encodeURIComponent(window.location);
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b80d83e..b98bd74 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -23,7 +23,6 @@ from functools import wraps
 from datetime import datetime, timedelta
 import dateutil.parser
 import copy
-from itertools import chain, product
 import json
 
 import inspect
@@ -1085,6 +1084,7 @@ class Airflow(BaseView):
         origin = request.args.get('origin')
         dag = dagbag.get_dag(dag_id)
         task = dag.get_task(task_id)
+        task.dag = dag
 
         execution_date = request.args.get('execution_date')
         execution_date = dateutil.parser.parse(execution_date)
@@ -1093,110 +1093,39 @@ class Airflow(BaseView):
         downstream = request.args.get('downstream') == "true"
         future = request.args.get('future') == "true"
         past = request.args.get('past') == "true"
-        recursive = request.args.get('recursive') == "true"
-        MAX_PERIODS = 5000
 
-        # Flagging tasks as successful
-        session = settings.Session()
-        task_ids = [task_id]
-        dag_ids = [dag_id]
-        task_id_to_dag = {
-            task_id: dag
-        }
-        end_date = ((dag.latest_execution_date or datetime.now())
-                    if future else execution_date)
+        if not dag:
+            flash("Cannot find DAG: {}".format(dag_id))
+            return redirect(origin)
 
-        if 'start_date' in dag.default_args:
-            start_date = dag.default_args['start_date']
-        elif dag.start_date:
-            start_date = dag.start_date
-        else:
-            start_date = execution_date
+        if not task:
+            flash("Cannot find task {} in DAG {}".format(task_id, dag.dag_id))
+            return redirect(origin)
 
-        start_date = execution_date if not past else start_date
+        from airflow.api.common.experimental.mark_tasks import set_state
 
-        if recursive:
-            recurse_tasks(task, task_ids, dag_ids, task_id_to_dag)
-
-        if downstream:
-            relatives = task.get_flat_relatives(upstream=False)
-            task_ids += [t.task_id for t in relatives]
-            if recursive:
-                recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag)
-        if upstream:
-            relatives = task.get_flat_relatives(upstream=False)
-            task_ids += [t.task_id for t in relatives]
-            if recursive:
-                recurse_tasks(relatives, task_ids, dag_ids, task_id_to_dag)
-        TI = models.TaskInstance
+        if confirmed:
+            altered = set_state(task=task, execution_date=execution_date,
+                                upstream=upstream, downstream=downstream,
+                                future=future, past=past, state=State.SUCCESS,
+                                commit=True)
 
-        if dag.schedule_interval == '@once':
-            dates = [start_date]
-        else:
-            dates = dag.date_range(start_date, end_date=end_date)
-
-        tis = session.query(TI).filter(
-            TI.dag_id.in_(dag_ids),
-            TI.execution_date.in_(dates),
-            TI.task_id.in_(task_ids)).all()
-        tis_to_change = session.query(TI).filter(
-            TI.dag_id.in_(dag_ids),
-            TI.execution_date.in_(dates),
-            TI.task_id.in_(task_ids),
-            TI.state != State.SUCCESS).all()
-        tasks = list(product(task_ids, dates))
-        tis_to_create = list(
-            set(tasks) -
-            set([(ti.task_id, ti.execution_date) for ti in tis]))
-
-        tis_all_altered = list(chain(
-            [(ti.task_id, ti.execution_date) for ti in tis_to_change],
-            tis_to_create))
-
-        if len(tis_all_altered) > MAX_PERIODS:
-            flash("Too many tasks at once (>{0})".format(
-                MAX_PERIODS), 'error')
+            flash("Marked success on {} task instances".format(len(altered)))
             return redirect(origin)
 
-        if confirmed:
-            for ti in tis_to_change:
-                ti.state = State.SUCCESS
-            session.commit()
+        to_be_altered = set_state(task=task, execution_date=execution_date,
+                                  upstream=upstream, downstream=downstream,
+                                  future=future, past=past, 
state=State.SUCCESS,
+                                  commit=False)
 
-            for task_id, task_execution_date in tis_to_create:
-                ti = TI(
-                    task=task_id_to_dag[task_id].get_task(task_id),
-                    execution_date=task_execution_date,
-                    state=State.SUCCESS)
-                session.add(ti)
-                session.commit()
+        details = "\n".join([str(t) for t in to_be_altered])
 
-            session.commit()
-            session.close()
-            flash("Marked success on {} task instances".format(
-                len(tis_all_altered)))
-
-            return redirect(origin)
-        else:
-            if not tis_all_altered:
-                flash("No task instances to mark as successful", 'error')
-                response = redirect(origin)
-            else:
-                tis = []
-                for task_id, task_execution_date in tis_all_altered:
-                    tis.append(TI(
-                        task=task_id_to_dag[task_id].get_task(task_id),
-                        execution_date=task_execution_date,
-                        state=State.SUCCESS))
-                details = "\n".join([str(t) for t in tis])
+        response = self.render("airflow/confirm.html",
+                               message=("Here's the list of task instances you 
are "
+                                        "about to mark as successful:"),
+                               details=details)
 
-                response = self.render(
-                    'airflow/confirm.html',
-                    message=(
-                        "Here's the list of task instances you are about "
-                        "to mark as successful:"),
-                    details=details,)
-            return response
+        return response
 
     @expose('/tree')
     @login_required

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/__init__.py
----------------------------------------------------------------------
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 2db97ad..37d59f0 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -15,3 +15,5 @@
 from __future__ import absolute_import
 
 from .client import *
+from .common import *
+

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/common/__init__.py
----------------------------------------------------------------------
diff --git a/tests/api/common/__init__.py b/tests/api/common/__init__.py
new file mode 100644
index 0000000..9d7677a
--- /dev/null
+++ b/tests/api/common/__init__.py
@@ -0,0 +1,13 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/api/common/mark_tasks.py
----------------------------------------------------------------------
diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/mark_tasks.py
new file mode 100644
index 0000000..e01f3ad
--- /dev/null
+++ b/tests/api/common/mark_tasks.py
@@ -0,0 +1,211 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import models
+from airflow.api.common.experimental.mark_tasks import set_state, 
_create_dagruns
+from airflow.settings import Session
+from airflow.utils.dates import days_ago
+from airflow.utils.state import State
+
+
+DEV_NULL = "/dev/null"
+
+
+class TestMarkTasks(unittest.TestCase):
+    def setUp(self):
+        self.dagbag = models.DagBag(include_examples=True)
+        self.dag1 = self.dagbag.dags['test_example_bash_operator']
+        self.dag2 = self.dagbag.dags['example_subdag_operator']
+
+        self.execution_dates = [days_ago(2), days_ago(1)]
+
+        drs = _create_dagruns(self.dag1, self.execution_dates,
+                              state=State.RUNNING,
+                              run_id_template="scheduled__{}")
+        for dr in drs:
+            dr.dag = self.dag1
+            dr.verify_integrity()
+
+        drs = _create_dagruns(self.dag2,
+                              [self.dag2.default_args['start_date']],
+                              state=State.RUNNING,
+                              run_id_template="scheduled__{}")
+
+        for dr in drs:
+            dr.dag = self.dag2
+            dr.verify_integrity()
+
+        self.session = Session()
+
+    def snapshot_state(self, dag, execution_dates):
+        TI = models.TaskInstance
+        tis = self.session.query(TI).filter(
+            TI.dag_id==dag.dag_id,
+            TI.execution_date.in_(execution_dates)
+        ).all()
+
+        self.session.expunge_all()
+
+        return tis
+
+    def verify_state(self, dag, task_ids, execution_dates, state, old_tis):
+        TI = models.TaskInstance
+
+        tis = self.session.query(TI).filter(
+            TI.dag_id==dag.dag_id,
+            TI.execution_date.in_(execution_dates)
+        ).all()
+
+        self.assertTrue(len(tis) > 0)
+
+        for ti in tis:
+            if ti.task_id in task_ids and ti.execution_date in execution_dates:
+                self.assertEqual(ti.state, state)
+            else:
+                for old_ti in old_tis:
+                    if (old_ti.task_id == ti.task_id
+                            and old_ti.execution_date == ti.execution_date):
+                            self.assertEqual(ti.state, old_ti.state)
+
+    def test_mark_tasks_now(self):
+        # set one task to success but do not commit
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("runme_1")
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=False,
+                            past=False, state=State.SUCCESS, commit=False)
+        self.assertEqual(len(altered), 1)
+        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+                          None, snapshot)
+
+        # set one and only one task to success
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 1)
+        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+                          State.SUCCESS, snapshot)
+
+        # set no tasks
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 0)
+        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+                          State.SUCCESS, snapshot)
+
+        # set task to other than success
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=False,
+                            past=False, state=State.FAILED, commit=True)
+        self.assertEqual(len(altered), 1)
+        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+                          State.FAILED, snapshot)
+
+        # don't alter other tasks
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("runme_0")
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 1)
+        self.verify_state(self.dag1, [task.task_id], [self.execution_dates[0]],
+                          State.SUCCESS, snapshot)
+
+    def test_mark_downstream(self):
+        # test downstream
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("runme_1")
+        relatives = task.get_flat_relatives(upstream=False)
+        task_ids = [t.task_id for t in relatives]
+        task_ids.append(task.task_id)
+
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=True, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 3)
+        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
+                          State.SUCCESS, snapshot)
+
+    def test_mark_upstream(self):
+        # test upstream
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("run_after_loop")
+        relatives = task.get_flat_relatives(upstream=True)
+        task_ids = [t.task_id for t in relatives]
+        task_ids.append(task.task_id)
+
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=True, downstream=False, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 4)
+        self.verify_state(self.dag1, task_ids, [self.execution_dates[0]],
+                          State.SUCCESS, snapshot)
+
+    def test_mark_tasks_future(self):
+        # set one task to success towards end of scheduled dag runs
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("runme_1")
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=False, future=True,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 2)
+        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
+                          State.SUCCESS, snapshot)
+
+    def test_mark_tasks_past(self):
+        # set one task to success retroactively for past dag runs
+        snapshot = self.snapshot_state(self.dag1, self.execution_dates)
+        task = self.dag1.get_task("runme_1")
+        altered = set_state(task=task, execution_date=self.execution_dates[1],
+                            upstream=False, downstream=False, future=False,
+                            past=True, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 2)
+        self.verify_state(self.dag1, [task.task_id], self.execution_dates,
+                          State.SUCCESS, snapshot)
+
+    def test_mark_tasks_subdag(self):
+        # set a task of the subdag example dag and its downstream tasks to success
+        task = self.dag2.get_task("section-1")
+        relatives = task.get_flat_relatives(upstream=False)
+        task_ids = [t.task_id for t in relatives]
+        task_ids.append(task.task_id)
+
+        altered = set_state(task=task, execution_date=self.execution_dates[0],
+                            upstream=False, downstream=True, future=False,
+                            past=False, state=State.SUCCESS, commit=True)
+        self.assertEqual(len(altered), 14)
+
+        # cannot use snapshot here as that will require drilling down the
+        # sub dag tree essentially recreating the same code as in the
+        # tested logic.
+        self.verify_state(self.dag2, task_ids, [self.execution_dates[0]],
+                          State.SUCCESS, [])
+
+    def tearDown(self):
+        self.dag1.clear()
+        self.dag2.clear()
+
+        # just to make sure we are fully cleaned up
+        self.session.query(models.DagRun).delete()
+        self.session.query(models.TaskInstance).delete()
+        self.session.commit()
+
+        self.session.close()
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index 0f7e41d..e35809d 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -1374,6 +1374,7 @@ class CliTests(unittest.TestCase):
         os.remove('variables1.json')
         os.remove('variables2.json')
 
+
 class WebUiTests(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
@@ -1383,11 +1384,26 @@ class WebUiTests(unittest.TestCase):
         app.config['TESTING'] = True
         self.app = app.test_client()
 
-        self.dagbag = models.DagBag(
-            dag_folder=DEV_NULL, include_examples=True)
+        self.dagbag = models.DagBag(include_examples=True)
         self.dag_bash = self.dagbag.dags['example_bash_operator']
+        self.dag_bash2 = self.dagbag.dags['test_example_bash_operator']
+        self.sub_dag = self.dagbag.dags['example_subdag_operator']
         self.runme_0 = self.dag_bash.get_task('runme_0')
 
+        self.dag_bash2.create_dagrun(
+            run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())),
+            execution_date=DEFAULT_DATE,
+            start_date=datetime.now(),
+            state=State.RUNNING
+        )
+
+        self.sub_dag.create_dagrun(
+            run_id="test_{}".format(models.DagRun.id_for_date(datetime.now())),
+            execution_date=DEFAULT_DATE,
+            start_date=datetime.now(),
+            state=State.RUNNING
+        )
+
     def test_index(self):
         response = self.app.get('/', follow_redirects=True)
         assert "DAGs" in response.data.decode('utf-8')
@@ -1470,7 +1486,7 @@ class WebUiTests(unittest.TestCase):
         assert "example_bash_operator" in response.data.decode('utf-8')
         url = (
             "/admin/airflow/success?task_id=run_this_last&"
-            "dag_id=example_bash_operator&upstream=false&downstream=false&"
+            "dag_id=test_example_bash_operator&upstream=false&downstream=false&"
             "future=false&past=false&execution_date={}&"
             "origin=/admin".format(DEFAULT_DATE_DS))
         response = self.app.get(url)
@@ -1478,7 +1494,7 @@ class WebUiTests(unittest.TestCase):
         response = self.app.get(url + "&confirmed=true")
         response = self.app.get(
             '/admin/airflow/clear?task_id=run_this_last&'
-            'dag_id=example_bash_operator&future=true&past=false&'
+            'dag_id=test_example_bash_operator&future=true&past=false&'
             'upstream=true&downstream=false&'
             'execution_date={}&'
             'origin=/admin'.format(DEFAULT_DATE_DS))
@@ -1486,7 +1502,7 @@ class WebUiTests(unittest.TestCase):
         url = (
             "/admin/airflow/success?task_id=section-1&"
             "dag_id=example_subdag_operator&upstream=true&downstream=true&"
-            "recursive=true&future=false&past=false&execution_date={}&"
+            "future=false&past=false&execution_date={}&"
             "origin=/admin".format(DEFAULT_DATE_DS))
         response = self.app.get(url)
         assert "Wait a minute" in response.data.decode('utf-8')
@@ -1498,7 +1514,7 @@ class WebUiTests(unittest.TestCase):
         response = self.app.get(url + "&confirmed=true")
         url = (
             "/admin/airflow/clear?task_id=runme_1&"
-            "dag_id=example_bash_operator&future=false&past=false&"
+            "dag_id=test_example_bash_operator&future=false&past=false&"
             "upstream=false&downstream=true&"
             "execution_date={}&"
             "origin=/admin".format(DEFAULT_DATE_DS))
@@ -1542,23 +1558,19 @@ class WebUiTests(unittest.TestCase):
     def test_fetch_task_instance(self):
         url = (
             "/admin/airflow/object/task_instances?"
-            "dag_id=example_bash_operator&"
+            "dag_id=test_example_bash_operator&"
             "execution_date={}".format(DEFAULT_DATE_DS))
         response = self.app.get(url)
-        assert "{}" in response.data.decode('utf-8')
-
-        TI = models.TaskInstance
-        ti = TI(
-            task=self.runme_0, execution_date=DEFAULT_DATE)
-        job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True)
-        job.run()
-
-        response = self.app.get(url)
-        assert "runme_0" in response.data.decode('utf-8')
+        self.assertIn("run_this_last", response.data.decode('utf-8'))
 
     def tearDown(self):
         configuration.conf.set("webserver", "expose_config", "False")
         self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=datetime.now())
+        session = Session()
+        session.query(models.DagRun).delete()
+        session.query(models.TaskInstance).delete()
+        session.commit()
+        session.close()
 
 
 class WebPasswordAuthTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/dags/test_example_bash_operator.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_example_bash_operator.py b/tests/dags/test_example_bash_operator.py
new file mode 100644
index 0000000..ad03353
--- /dev/null
+++ b/tests/dags/test_example_bash_operator.py
@@ -0,0 +1,55 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import airflow
+from builtins import range
+from airflow.operators.bash_operator import BashOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.models import DAG
+from datetime import timedelta
+
+
+args = {
+    'owner': 'airflow',
+    'start_date': airflow.utils.dates.days_ago(2)
+}
+
+dag = DAG(
+    dag_id='test_example_bash_operator', default_args=args,
+    schedule_interval='0 0 * * *',
+    dagrun_timeout=timedelta(minutes=60))
+
+cmd = 'ls -l'
+run_this_last = DummyOperator(task_id='run_this_last', dag=dag)
+
+run_this = BashOperator(
+    task_id='run_after_loop', bash_command='echo 1', dag=dag)
+run_this.set_downstream(run_this_last)
+
+for i in range(3):
+    i = str(i)
+    task = BashOperator(
+        task_id='runme_'+i,
+        bash_command='echo "{{ task_instance_key_str }}" && sleep 1',
+        dag=dag)
+    task.set_downstream(run_this)
+
+task = BashOperator(
+    task_id='also_run_this',
+    bash_command='echo "run_id={{ run_id }} | dag_run={{ dag_run }}"',
+    dag=dag)
+task.set_downstream(run_this_last)
+
+if __name__ == "__main__":
+    dag.cli()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/563cc9a3/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 003fb21..868ea36 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -188,7 +188,7 @@ class DagBagTest(unittest.TestCase):
         class TestDagBag(models.DagBag):
             process_file_calls = 0
             def process_file(self, filepath, only_if_updated=True, safe_mode=True):
-                if 'example_bash_operator.py' in filepath:
+                if 'example_bash_operator.py' == os.path.basename(filepath):
                     TestDagBag.process_file_calls += 1
                 super(TestDagBag, self).process_file(filepath, only_if_updated, safe_mode)
 


Reply via email to