Repository: incubator-airflow
Updated Branches:
  refs/heads/v1-9-test ef775d4f8 -> 87afe8901


[AIRFLOW-1634] Adds task_concurrency feature

This adds a feature to limit the concurrency of individual tasks.
By default, existing behavior is unchanged.

Closes #2624 from saguziel/aguziel-task-concurrency

(cherry picked from commit cfc2f73c445074e1e09d6ef6a056cd2b33a945da)
Signed-off-by: Bolke de Bruin <[email protected]>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/87afe890
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/87afe890
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/87afe890

Branch: refs/heads/v1-9-test
Commit: 87afe8901559d4aa8b74179e980ca63fd1dedcb5
Parents: ef775d4
Author: Alex Guziel <[email protected]>
Authored: Thu Oct 5 14:37:26 2017 -0700
Committer: Bolke de Bruin <[email protected]>
Committed: Tue Oct 31 19:18:54 2017 +0100

----------------------------------------------------------------------
 airflow/jobs.py                              |  51 +++++++--
 airflow/models.py                            |  21 +++-
 airflow/ti_deps/dep_context.py               |   4 +-
 airflow/ti_deps/deps/task_concurrency_dep.py |  37 +++++++
 airflow/utils/dag_processing.py              |  18 +++-
 tests/jobs.py                                | 124 +++++++++++++++++++---
 tests/models.py                              |  56 ++++++++++
 tests/ti_deps/deps/test_task_concurrency.py  |  51 +++++++++
 8 files changed, 331 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index f92a570..7a7e564 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -999,6 +999,30 @@ class SchedulerJob(BaseJob):
             )
 
     @provide_session
+    def __get_task_concurrency_map(self, states, session=None):
+        """
+        Returns a map from (dag_id, task_id) to the count of TIs in the given states.
+
+        :param states: List of states to query for
+        :type states: List[State]
+        :return: A map from (dag_id, task_id) to count of tasks in states
+        :rtype: Dict[Tuple[String, String], Int]
+
+        """
+        TI = models.TaskInstance
+        ti_concurrency_query = (
+            session
+            .query(TI.task_id, TI.dag_id, func.count('*'))
+            .filter(TI.state.in_(states))
+            .group_by(TI.task_id, TI.dag_id)
+        ).all()
+        task_map = defaultdict(int)
+        for result in ti_concurrency_query:
+            task_id, dag_id, count = result
+            task_map[(dag_id, task_id)] = count
+        return task_map
+
+    @provide_session
     def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
         """
         Finds TIs that are ready for execution with respect to pool limits,
@@ -1013,6 +1037,9 @@ class SchedulerJob(BaseJob):
         :type states: Tuple[State]
         :return: List[TaskInstance]
         """
+        # TODO(saguziel): Change this to include QUEUED, for concurrency
+        # purposes we may want to count queued tasks
+        states_to_count_as_running = [State.RUNNING]
         executable_tis = []
 
         # Get all the queued task instances from associated with scheduled
@@ -1057,6 +1084,8 @@ class SchedulerJob(BaseJob):
         for task_instance in task_instances_to_examine:
             pool_to_task_instances[task_instance.pool].append(task_instance)
 
+        task_concurrency_map = self.__get_task_concurrency_map(states=states_to_count_as_running, session=session)
+
         # Go through each pool, and queue up a task for execution if there are
         # any open slots in the pool.
         for pool, task_instances in pool_to_task_instances.items():
@@ -1094,6 +1123,7 @@ class SchedulerJob(BaseJob):
                 # Check to make sure that the task concurrency of the DAG hasn't been
                 # reached.
                 dag_id = task_instance.dag_id
+                simple_dag = simple_dag_bag.get_dag(dag_id)
 
                 if dag_id not in dag_id_to_possibly_running_task_count:
                     # TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104
@@ -1101,7 +1131,7 @@ class SchedulerJob(BaseJob):
                         DAG.get_num_task_instances(
                             dag_id,
                             simple_dag_bag.get_dag(dag_id).task_ids,
-                            states=[State.RUNNING],
+                            states=states_to_count_as_running,
                             session=session)
 
                 current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
@@ -1118,6 +1148,16 @@ class SchedulerJob(BaseJob):
                     )
                     continue
 
+                task_concurrency = simple_dag.get_task_special_arg(task_instance.task_id, 'task_concurrency')
+                if task_concurrency is not None:
+                    num_running = task_concurrency_map[(task_instance.dag_id, task_instance.task_id)]
+                    if num_running >= task_concurrency:
+                        self.log.info("Not executing %s since the task concurrency for this task"
+                                      " has been reached.", task_instance)
+                        continue
+                    else:
+                        task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
                 if self.executor.has_task(task_instance):
                     self.log.debug(
                         "Not handling task %s as the executor reports it is running",
@@ -1726,16 +1766,9 @@ class SchedulerJob(BaseJob):
             if pickle_dags:
                 pickle_id = dag.pickle(session).id
 
-            task_ids = [task.task_id for task in dag.tasks]
-
             # Only return DAGs that are not paused
             if dag_id not in paused_dag_ids:
-                simple_dags.append(SimpleDag(dag.dag_id,
-                                             task_ids,
-                                             dag.full_filepath,
-                                             dag.concurrency,
-                                             dag.is_paused,
-                                             pickle_id))
+                simple_dags.append(SimpleDag(dag, pickle_id=pickle_id))
 
         if len(self.dag_ids) > 0:
             dags = [dag for dag in dagbag.dags.values()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 32b7d7e..e5bf857 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -65,6 +65,7 @@ from airflow.dag.base_dag import BaseDag, BaseDagBag
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
 from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
 
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
 from airflow.utils.dates import cron_presets, date_range as utils_date_range
@@ -1835,6 +1836,16 @@ class TaskInstance(Base, LoggingMixin):
         else:
             return pull_fn(task_id=task_ids)
 
+    @provide_session
+    def get_num_running_task_instances(self, session=None):
+        """Returns the count of RUNNING task instances for this dag_id/task_id."""
+        TI = TaskInstance
+        return session.query(TI).filter(
+            TI.dag_id == self.dag_id,
+            TI.task_id == self.task_id,
+            TI.state == State.RUNNING
+        ).count()
+
 
 class TaskFail(Base):
     """
@@ -2058,6 +2069,9 @@ class BaseOperator(LoggingMixin):
     :type resources: dict
     :param run_as_user: unix username to impersonate while running the task
     :type run_as_user: str
+    :param task_concurrency: When set, limits the number of concurrently running
+        instances of this task across execution_dates
+    :type task_concurrency: int
     """
 
     # For derived classes to define which fields will get jinjaified
@@ -2100,6 +2114,7 @@ class BaseOperator(LoggingMixin):
             trigger_rule=TriggerRule.ALL_SUCCESS,
             resources=None,
             run_as_user=None,
+            task_concurrency=None,
             *args,
             **kwargs):
 
@@ -2165,6 +2180,7 @@ class BaseOperator(LoggingMixin):
         self.priority_weight = priority_weight
         self.resources = Resources(**(resources or {}))
         self.run_as_user = run_as_user
+        self.task_concurrency = task_concurrency
 
         # Private attributes
         self._upstream_task_ids = []
@@ -4542,8 +4558,9 @@ class DagRun(Base, LoggingMixin):
             session=session
         )
         none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
+        none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
         # small speed up
-        if unfinished_tasks and none_depends_on_past:
+        if unfinished_tasks and none_depends_on_past and none_task_concurrency:
             # todo: this can actually get pretty slow: one task costs between 0.01-015s
             no_dependencies_met = True
             for ut in unfinished_tasks:
@@ -4581,7 +4598,8 @@ class DagRun(Base, LoggingMixin):
                 self.state = State.SUCCESS
 
             # if *all tasks* are deadlocked, the run failed
-            elif unfinished_tasks and none_depends_on_past and no_dependencies_met:
+            elif (unfinished_tasks and none_depends_on_past and
+                  none_task_concurrency and no_dependencies_met):
                 self.log.info('Deadlock; marking run %s failed', self)
                 self.state = State.FAILED
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/ti_deps/dep_context.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index 01e01dd..f461a81 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -19,6 +19,7 @@ from airflow.ti_deps.deps.not_running_dep import NotRunningDep
 from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep
 from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep
 from airflow.ti_deps.deps.valid_state_dep import ValidStateDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
 from airflow.utils.state import State
 
 
@@ -97,7 +98,8 @@ QUEUE_DEPS = {
 # Dependencies that need to be met for a given task instance to be able to get run by an
 # executor. This class just extends QueueContext by adding dependencies for resources.
 RUN_DEPS = QUEUE_DEPS | {
-    DagTISlotsAvailableDep()
+    DagTISlotsAvailableDep(),
+    TaskConcurrencyDep(),
 }
 
 # TODO(aoen): SCHEDULER_DEPS is not coupled to actual execution in any way and

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/ti_deps/deps/task_concurrency_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py
new file mode 100644
index 0000000..99df5ac
--- /dev/null
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -0,0 +1,38 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.utils.db import provide_session
+
+
+class TaskConcurrencyDep(BaseTIDep):
+    """
+    This restricts the number of running task instances for a particular task.
+
+    The limit comes from the task's ``task_concurrency`` attribute; when that is
+    None the dependency always passes.
+    """
+    NAME = "Task Concurrency"
+    IGNOREABLE = True
+    IS_TASK_DEP = True
+
+    @provide_session
+    def _get_dep_statuses(self, ti, session, dep_context):
+        if ti.task.task_concurrency is None:
+            yield self._passing_status(reason="Task concurrency is not set.")
+            return
+
+        if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency:
+            yield self._failing_status(reason="The max task concurrency has been reached.")
+        else:
+            yield self._passing_status(reason="The max task concurrency has not been reached.")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/utils/dag_processing.py
----------------------------------------------------------------------
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index 3a6cb98..5e92f0e 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -105,6 +105,17 @@ class SimpleDag(BaseDag):
         """
         return self._pickle_id
 
+    @property
+    def task_special_args(self):
+        """Per-task special args, keyed by task_id."""
+        return self._task_special_args
+
+    def get_task_special_arg(self, task_id, special_arg_name):
+        """
+        Returns the value of special_arg_name for task_id, or None if it is not set.
+        """
+        return self._task_special_args.get(task_id, {}).get(special_arg_name)
+
 
 class SimpleDagBag(BaseDagBag):
     """
@@ -366,7 +377,7 @@ class DagFileProcessorManager(LoggingMixin):
         being processed
         """
         if file_path in self._processors:
-            return (datetime.utcnow() - self._processors[file_path].start_time)\
+            return (datetime.utcnow() - self._processors[file_path].start_time) \
                 .total_seconds()
         return None
 
@@ -489,8 +500,8 @@ class DagFileProcessorManager(LoggingMixin):
             for file_path in self._file_paths:
                 last_finish_time = self.get_last_finish_time(file_path)
                 if (last_finish_time is not None and
-                    (now - last_finish_time).total_seconds() <
-                        self._process_file_interval):
+                            (now - last_finish_time).total_seconds() <
+                            self._process_file_interval):
                     file_paths_recently_processed.append(file_path)
 
             files_paths_at_run_limit = [file_path
@@ -517,7 +528,7 @@ class DagFileProcessorManager(LoggingMixin):
 
         # Start more processors if we have enough slots and files to process
         while (self._parallelism - len(self._processors) > 0 and
-               len(self._file_path_queue) > 0):
+                       len(self._file_path_queue) > 0):
             file_path = self._file_path_queue.pop(0)
             processor = self._processor_factory(file_path)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index f4bbe81..e8fff7e 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -41,7 +41,7 @@ from airflow.utils.dates import days_ago
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
-from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths
+from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths
 
 from mock import Mock, patch
 from sqlalchemy.orm.session import make_transient
@@ -935,7 +935,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -965,7 +965,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -990,7 +990,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1015,7 +1015,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1055,7 +1055,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a')
         task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b')
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1096,7 +1096,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1114,7 +1114,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1155,6 +1155,98 @@ class SchedulerJobTest(unittest.TestCase):
 
         self.assertEqual(0, len(res))
 
+    def test_find_executable_task_instances_task_concurrency(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency'
+        task_id_1 = 'dummy'
+        task_id_2 = 'dummy2'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2)
+        task2 = DummyOperator(dag=dag, task_id=task_id_2)
+        dagbag = self._make_simple_dag_bag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr3 = scheduler.create_dag_run(dag)
+
+        ti1_1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task2, dr1.execution_date)
+
+        ti1_1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(2, len(res))
+
+        ti1_1.state = State.RUNNING
+        ti2.state = State.RUNNING
+        ti1_2 = TI(task1, dr2.execution_date)
+        ti1_2.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti2)
+        session.merge(ti1_2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(1, len(res))
+
+        ti1_2.state = State.RUNNING
+        ti1_3 = TI(task1, dr3.execution_date)
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(0, len(res))
+
+        ti1_1.state = State.SCHEDULED
+        ti1_2.state = State.SCHEDULED
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(2, len(res))
+
+        ti1_1.state = State.RUNNING
+        ti1_2.state = State.SCHEDULED
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(1, len(res))
+
     def test_change_state_for_executable_task_instances_no_tis(self):
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1166,7 +1258,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1198,7 +1290,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1234,7 +1326,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1279,7 +1371,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
         task2 = DummyOperator(dag=dag, task_id=task_id_2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1340,7 +1432,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
         task2 = DummyOperator(dag=dag, task_id=task_id_2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         scheduler.max_tis_per_query = 3
@@ -1407,16 +1499,18 @@ class SchedulerJobTest(unittest.TestCase):
         ti2.state = State.SCHEDULED
         session.commit()
 
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
         scheduler = SchedulerJob(num_runs=0, run_duration=0)
         scheduler._change_state_for_tis_without_dagrun(simple_dag_bag=dagbag,
                                                        old_states=[State.SCHEDULED, State.QUEUED],
                                                        new_state=State.NONE,
                                                        session=session)
 
+        ti = dr.get_task_instance(task_id='dummy', session=session)
         ti.refresh_from_db(session=session)
         self.assertEqual(ti.state, State.SCHEDULED)
 
+        ti2 = dr2.get_task_instance(task_id='dummy', session=session)
         ti2.refresh_from_db(session=session)
         self.assertEqual(ti2.state, State.SCHEDULED)
 
@@ -2039,7 +2133,7 @@ class SchedulerJobTest(unittest.TestCase):
         queue = []
         scheduler._process_task_instances(dag, queue=queue)
         self.assertEquals(len(queue), 2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         # Recreated part of the scheduler here, to kick off tasks -> executor
         for ti_key in queue:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index db5beca..a1de17d 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -517,6 +517,39 @@ class DagRunTest(unittest.TestCase):
         dr.update_state()
         self.assertEqual(dr.state, State.FAILED)
 
+    def test_dagrun_no_deadlock(self):
+        session = settings.Session()
+        dag = DAG('test_dagrun_no_deadlock',
+                  start_date=DEFAULT_DATE)
+        with dag:
+            op1 = DummyOperator(task_id='dop', depends_on_past=True)
+            op2 = DummyOperator(task_id='tc', task_concurrency=1)
+
+        dag.clear()
+        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
+                               state=State.RUNNING,
+                               execution_date=DEFAULT_DATE,
+                               start_date=DEFAULT_DATE)
+        dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
+                                state=State.RUNNING,
+                                execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+                                start_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti1_op1 = dr.get_task_instance(task_id='dop')
+        ti2_op1 = dr2.get_task_instance(task_id='dop')
+        ti1_op2 = dr.get_task_instance(task_id='tc')
+        ti2_op2 = dr2.get_task_instance(task_id='tc')
+        ti1_op1.set_state(state=State.RUNNING, session=session)
+        dr.update_state()
+        dr2.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
+        self.assertEqual(dr2.state, State.RUNNING)
+
+        ti1_op2.set_state(state=State.RUNNING, session=session)
+        dr.update_state()
+        dr2.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
+        self.assertEqual(dr2.state, State.RUNNING)
+
     def test_get_task_instance_on_empty_dagrun(self):
         """
         Make sure that a proper value is returned when a dagrun has no task instances
@@ -1201,6 +1234,29 @@ class TaskInstanceTest(unittest.TestCase):
         ti = TI(
             task=task2, execution_date=datetime.datetime.now())
         self.assertFalse(ti._check_and_change_state_before_execution())
+
+    def test_get_num_running_task_instances(self):
+        session = settings.Session()
+
+        dag = models.DAG(dag_id='test_get_num_running_task_instances')
+        dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE)
+
+        ti1 = TI(task=task, execution_date=DEFAULT_DATE)
+        ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
+        ti1.state = State.RUNNING
+        ti2.state = State.QUEUED
+        ti3.state = State.RUNNING
+        session.add(ti1)
+        session.add(ti2)
+        session.add(ti3)
+        session.commit()
+
+        self.assertEqual(1, ti1.get_num_running_task_instances(session=session))
+        self.assertEqual(1, ti2.get_num_running_task_instances(session=session))
+        self.assertEqual(1, ti3.get_num_running_task_instances(session=session))
         
 
 class ClearTasksTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/ti_deps/deps/test_task_concurrency.py
----------------------------------------------------------------------
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
new file mode 100644
index 0000000..77a5990
--- /dev/null
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -0,0 +1,49 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from mock import Mock
+
+from airflow.models import DAG, BaseOperator
+from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
+
+
+class TaskConcurrencyDepTest(unittest.TestCase):
+
+    def _get_task(self, **kwargs):
+        return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
+
+    def test_not_task_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1))
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+    def test_not_reached_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1)
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        ti.get_num_running_task_instances = lambda x: 0
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+    def test_reached_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2)
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        ti.get_num_running_task_instances = lambda x: 1
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+        ti.get_num_running_task_instances = lambda x: 2
+        self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))

Reply via email to