Repository: incubator-airflow Updated Branches: refs/heads/master 96206b0e5 -> cfc2f73c4
[AIRFLOW-1634] Adds task_concurrency feature This adds a feature to limit the concurrency of individual tasks. The default will be to not change existing behavior. Closes #2624 from saguziel/aguziel-task- concurrency Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/cfc2f73c Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/cfc2f73c Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/cfc2f73c Branch: refs/heads/master Commit: cfc2f73c445074e1e09d6ef6a056cd2b33a945da Parents: 96206b0 Author: Alex Guziel <[email protected]> Authored: Thu Oct 5 14:37:26 2017 -0700 Committer: Alex Guziel <[email protected]> Committed: Thu Oct 5 14:37:26 2017 -0700 ---------------------------------------------------------------------- airflow/jobs.py | 51 +++++++-- airflow/models.py | 21 +++- airflow/ti_deps/dep_context.py | 4 +- airflow/ti_deps/deps/task_concurrency_dep.py | 37 +++++++ airflow/utils/dag_processing.py | 29 ++++- tests/jobs.py | 127 +++++++++++++++++++--- tests/models.py | 56 ++++++++++ tests/ti_deps/deps/test_task_concurrency.py | 51 +++++++++ 8 files changed, 348 insertions(+), 28 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index 8ca81dc..2675bd3 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -1024,6 +1024,30 @@ class SchedulerJob(BaseJob): ) @provide_session + def __get_task_concurrency_map(self, states, session=None): + """ + Returns a map from tasks to number in the states list given. + + :param states: List of states to query for + :type states: List[State] + :return: A map from (dag_id, task_id) to count of tasks in states + :rtype: Dict[[String, String], Int] + + """ + TI = models.TaskInstance + ti_concurrency_query = ( + session + .query(TI.task_id, TI.dag_id, func.count('*')) + .filter(TI.state.in_(states)) + .group_by(TI.task_id, TI.dag_id) + ).all() + task_map = defaultdict(int) + for result in ti_concurrency_query: + task_id, dag_id, count = result + task_map[(dag_id, task_id)] = count + return task_map + + @provide_session def _find_executable_task_instances(self, simple_dag_bag, states, session=None): """ Finds TIs that are ready for execution with respect to pool limits, @@ -1038,6 +1062,9 @@ class SchedulerJob(BaseJob): :type states: Tuple[State] :return: List[TaskInstance] """ + # TODO(saguziel): Change this to include QUEUED, for concurrency + # purposes we may want to count queued tasks + states_to_count_as_running = [State.RUNNING] executable_tis = [] # Get all the queued task instances from associated with scheduled @@ -1082,6 +1109,8 @@ class SchedulerJob(BaseJob): for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) + task_concurrency_map = self.__get_task_concurrency_map(states=states_to_count_as_running, session=session) + # Go through each pool, and queue up a task for execution if there are # any open slots in the pool. for pool, task_instances in pool_to_task_instances.items(): @@ -1119,6 +1148,7 @@ class SchedulerJob(BaseJob): # Check to make sure that the task concurrency of the DAG hasn't been # reached. dag_id = task_instance.dag_id + simple_dag = simple_dag_bag.get_dag(dag_id) if dag_id not in dag_id_to_possibly_running_task_count: # TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104 @@ -1126,7 +1156,7 @@ class SchedulerJob(BaseJob): DAG.get_num_task_instances( dag_id, simple_dag_bag.get_dag(dag_id).task_ids, - states=[State.RUNNING], + states=states_to_count_as_running, session=session) current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id] @@ -1143,6 +1173,16 @@ class SchedulerJob(BaseJob): ) continue + task_concurrency = simple_dag.get_task_special_arg(task_instance.task_id, 'task_concurrency') + if task_concurrency is not None: + num_running = task_concurrency_map[((task_instance.dag_id, task_instance.task_id))] + if num_running >= task_concurrency: + self.logger.info("Not executing %s since the task concurrency for this task" + " has been reached.", task_instance) + continue + else: + task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 + if self.executor.has_task(task_instance): self.log.debug( "Not handling task %s as the executor reports it is running", @@ -1723,16 +1763,9 @@ class SchedulerJob(BaseJob): if pickle_dags: pickle_id = dag.pickle(session).id - task_ids = [task.task_id for task in dag.tasks] - # Only return DAGs that are not paused if dag_id not in paused_dag_ids: - simple_dags.append(SimpleDag(dag.dag_id, - task_ids, - dag.full_filepath, - dag.concurrency, - dag.is_paused, - pickle_id)) + simple_dags.append(SimpleDag(dag, pickle_id=pickle_id)) if len(self.dag_ids) > 0: dags = [dag for dag in dagbag.dags.values() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index e764d85..e3c52b5 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -65,6 +65,7 @@ from airflow.dag.base_dag import BaseDag, BaseDagBag from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS from airflow.utils.dates import cron_presets, date_range as utils_date_range @@ -1835,6 +1836,15 @@ class TaskInstance(Base, LoggingMixin): else: return pull_fn(task_id=task_ids) + @provide_session + def get_num_running_task_instances(self, session): + TI = TaskInstance + return session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.task_id == self.task_id, + TI.state == State.RUNNING + ).count() + class TaskFail(Base): """ @@ -2058,6 +2068,9 @@ class BaseOperator(LoggingMixin): :type resources: dict :param run_as_user: unix username to impersonate while running the task :type run_as_user: str + :param task_concurrency: When set, a task will be able to limit the concurrent + runs across execution_dates + :type task_concurrency: int """ # For derived classes to define which fields will get jinjaified @@ -2100,6 +2113,7 @@ class BaseOperator(LoggingMixin): trigger_rule=TriggerRule.ALL_SUCCESS, resources=None, run_as_user=None, + task_concurrency=None, *args, **kwargs): @@ -2165,6 +2179,7 @@ class BaseOperator(LoggingMixin): self.priority_weight = priority_weight self.resources = Resources(**(resources or {})) self.run_as_user = run_as_user + self.task_concurrency = task_concurrency # Private attributes self._upstream_task_ids = [] @@ -4542,8 +4557,9 @@ class DagRun(Base, LoggingMixin): session=session ) none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) + none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks) # small speed up - if unfinished_tasks and none_depends_on_past: + if unfinished_tasks and none_depends_on_past and none_task_concurrency: # todo: this can actually get pretty slow: one task costs between 0.01-015s no_dependencies_met = True for ut in unfinished_tasks: @@ -4581,7 +4597,8 @@ class DagRun(Base, LoggingMixin): self.state = State.SUCCESS # if *all tasks* are deadlocked, the run failed - elif unfinished_tasks and none_depends_on_past and no_dependencies_met: + elif (unfinished_tasks and none_depends_on_past and + none_task_concurrency and no_dependencies_met): self.log.info('Deadlock; marking run %s failed', self) self.state = State.FAILED http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/dep_context.py ---------------------------------------------------------------------- diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index 01e01dd..f461a81 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -19,6 +19,7 @@ from airflow.ti_deps.deps.not_running_dep import NotRunningDep from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep from airflow.ti_deps.deps.valid_state_dep import ValidStateDep +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep from airflow.utils.state import State @@ -97,7 +98,8 @@ QUEUE_DEPS = { # Dependencies that need to be met for a given task instance to be able to get run by an # executor. This class just extends QueueContext by adding dependencies for resources. RUN_DEPS = QUEUE_DEPS | { - DagTISlotsAvailableDep() + DagTISlotsAvailableDep(), + TaskConcurrencyDep(), } # TODO(aoen): SCHEDULER_DEPS is not coupled to actual execution in any way and http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/deps/task_concurrency_dep.py ---------------------------------------------------------------------- diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py new file mode 100644 index 0000000..99df5ac --- /dev/null +++ b/airflow/ti_deps/deps/task_concurrency_dep.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.utils.db import provide_session + + +class TaskConcurrencyDep(BaseTIDep): + """ + This restricts the number of running task instances for a particular task. + """ + NAME = "Task Concurrency" + IGNOREABLE = True + IS_TASK_DEP = True + + @provide_session + def _get_dep_statuses(self, ti, session, dep_context): + if ti.task.task_concurrency is None: + yield self._passing_status(reason="Task concurrency is not set.") + return + + if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency: + yield self._failing_status(reason="The max task concurrency has been reached.") + return + else: + yield self._passing_status(reason="The max task concurrency has not been reached.") + return http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/utils/dag_processing.py ---------------------------------------------------------------------- diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index d8c13ea..b80f701 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -42,7 +42,8 @@ class SimpleDag(BaseDag): full_filepath, concurrency, is_paused, - pickle_id): + pickle_id, + task_special_args): """ :param dag_id: ID of the DAG :type dag_id: unicode @@ -66,6 +67,22 @@ class SimpleDag(BaseDag): self._is_paused = is_paused self._concurrency = concurrency self._pickle_id = pickle_id + self._task_special_args = task_special_args + + def __init__(self, dag, pickle_id=None): + self._dag_id = dag.dag_id + self._task_ids = [task.task_id for task in dag.tasks] + self._full_filepath = dag.full_filepath + self._is_paused = dag.is_paused + self._concurrency = dag.concurrency + self._pickle_id = pickle_id + self._task_special_args = {} + for task in dag.tasks: + special_args = {} + if task.task_concurrency is not None: + special_args['task_concurrency'] = task.task_concurrency + if len(special_args) > 0: + self._task_special_args[task.task_id] = special_args @property def dag_id(self): @@ -115,6 +132,16 @@ class SimpleDag(BaseDag): """ return self._pickle_id + @property + def task_special_args(self): + return self._task_special_args + + def get_task_special_arg(self, task_id, special_arg_name): + if task_id in self._task_special_args and special_arg_name in self._task_special_args[task_id]: + return self._task_special_args[task_id][special_arg_name] + else: + return None + class SimpleDagBag(BaseDagBag): """ http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/jobs.py ---------------------------------------------------------------------- diff --git a/tests/jobs.py b/tests/jobs.py index 0a7f213..ba08fd6 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -41,7 +41,7 @@ from airflow.utils.dates import days_ago from airflow.utils.db import provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout -from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths from mock import Mock, patch from sqlalchemy.orm.session import make_transient @@ -932,13 +932,16 @@ class SchedulerJobTest(unittest.TestCase): scheduler.heartrate = 0 scheduler.run() + def _make_simple_dag_bag(self, dags): + return SimpleDagBag([SimpleDag(dag) for dag in dags]) + def test_execute_task_instances_is_paused_wont_execute(self): dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute' task_id_1 = 'dummy_task' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -968,7 +971,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -993,7 +996,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1018,7 +1021,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1058,7 +1061,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a') task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b') - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1099,7 +1102,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1117,7 +1120,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1158,6 +1161,98 @@ class SchedulerJobTest(unittest.TestCase): self.assertEqual(0, len(res)) + def test_find_executable_task_instances_task_concurrency(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency' + task_id_1 = 'dummy' + task_id_2 = 'dummy2' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) + task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2) + task2 = DummyOperator(dag=dag, task_id=task_id_2) + dagbag = self._make_simple_dag_bag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr3 = scheduler.create_dag_run(dag) + + ti1_1 = TI(task1, dr1.execution_date) + ti2 = TI(task2, dr1.execution_date) + + ti1_1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(2, len(res)) + + ti1_1.state = State.RUNNING + ti2.state = State.RUNNING + ti1_2 = TI(task1, dr2.execution_date) + ti1_2.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti2) + session.merge(ti1_2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(1, len(res)) + + ti1_2.state = State.RUNNING + ti1_3 = TI(task1, dr3.execution_date) + ti1_3.state = State.SCHEDULED + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(0, len(res)) + + ti1_1.state = State.SCHEDULED + ti1_2.state = State.SCHEDULED + ti1_3.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(2, len(res)) + + ti1_1.state = State.RUNNING + ti1_2.state = State.SCHEDULED + ti1_3.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(1, len(res)) + def test_change_state_for_executable_task_instances_no_tis(self): scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1169,7 +1264,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1201,7 +1296,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1237,7 +1332,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1282,7 +1377,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3) task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1343,7 +1438,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) scheduler.max_tis_per_query = 3 @@ -1410,16 +1505,18 @@ class SchedulerJobTest(unittest.TestCase): ti2.state = State.SCHEDULED session.commit() - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(num_runs=0, run_duration=0) scheduler._change_state_for_tis_without_dagrun(simple_dag_bag=dagbag, old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session) + ti = dr.get_task_instance(task_id='dummy', session=session) ti.refresh_from_db(session=session) self.assertEqual(ti.state, State.SCHEDULED) + ti2 = dr2.get_task_instance(task_id='dummy', session=session) ti2.refresh_from_db(session=session) self.assertEqual(ti2.state, State.SCHEDULED) @@ -2042,7 +2139,7 @@ class SchedulerJobTest(unittest.TestCase): queue = [] scheduler._process_task_instances(dag, queue=queue) self.assertEquals(len(queue), 2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) # Recreated part of the scheduler here, to kick off tasks -> executor for ti_key in queue: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/models.py ---------------------------------------------------------------------- diff --git a/tests/models.py b/tests/models.py index db5beca..a1de17d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -517,6 +517,39 @@ class DagRunTest(unittest.TestCase): dr.update_state() self.assertEqual(dr.state, State.FAILED) + def test_dagrun_no_deadlock(self): + session = settings.Session() + dag = DAG('test_dagrun_no_deadlock', + start_date=DEFAULT_DATE) + with dag: + op1 = DummyOperator(task_id='dop', depends_on_past=True) + op2 = DummyOperator(task_id='tc', task_concurrency=1) + + dag.clear() + dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1', + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE) + dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2', + state=State.RUNNING, + execution_date=DEFAULT_DATE + datetime.timedelta(days=1), + start_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti1_op1 = dr.get_task_instance(task_id='dop') + ti2_op1 = dr2.get_task_instance(task_id='dop') + ti2_op1 = dr.get_task_instance(task_id='tc') + ti2_op2 = dr.get_task_instance(task_id='tc') + ti1_op1.set_state(state=State.RUNNING, session=session) + dr.update_state() + dr2.update_state() + self.assertEqual(dr.state, State.RUNNING) + self.assertEqual(dr2.state, State.RUNNING) + + ti2_op1.set_state(state=State.RUNNING, session=session) + dr.update_state() + dr2.update_state() + self.assertEqual(dr.state, State.RUNNING) + self.assertEqual(dr2.state, State.RUNNING) + def test_get_task_instance_on_empty_dagrun(self): """ Make sure that a proper value is returned when a dagrun has no task instances @@ -1201,6 +1234,29 @@ class TaskInstanceTest(unittest.TestCase): ti = TI( task=task2, execution_date=datetime.datetime.now()) self.assertFalse(ti._check_and_change_state_before_execution()) + + def test_get_num_running_task_instances(self): + session = settings.Session() + + dag = models.DAG(dag_id='test_get_num_running_task_instances') + dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy') + task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) + task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE) + + ti1 = TI(task=task, execution_date=DEFAULT_DATE) + ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti3 = TI(task=task2, execution_date=DEFAULT_DATE) + ti1.state = State.RUNNING + ti2.state = State.QUEUED + ti3.state = State.RUNNING + session.add(ti1) + session.add(ti2) + session.add(ti3) + session.commit() + + self.assertEquals(1, ti1.get_num_running_task_instances(session=session)) + self.assertEquals(1, ti2.get_num_running_task_instances(session=session)) + self.assertEquals(1, ti3.get_num_running_task_instances(session=session)) class ClearTasksTest(unittest.TestCase): http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/ti_deps/deps/test_task_concurrency.py ---------------------------------------------------------------------- diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py new file mode 100644 index 0000000..77a5990 --- /dev/null +++ b/tests/ti_deps/deps/test_task_concurrency.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime +from mock import Mock + +from airflow.models import DAG, BaseOperator +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep +from airflow.utils.state import State + + +class TaskConcurrencyDepTest(unittest.TestCase): + + def _get_task(self, **kwargs): + return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) + + def test_not_task_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1)) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + + def test_not_reached_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + ti.get_num_running_task_instances = lambda x: 0 + self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + + def test_reached_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + ti.get_num_running_task_instances = lambda x: 1 + self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + ti.get_num_running_task_instances = lambda x: 2 + self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) +
