Repository: incubator-airflow Updated Branches: refs/heads/v1-9-test ef775d4f8 -> 87afe8901
[AIRFLOW-1634] Adds task_concurrency feature. This adds a feature to limit the concurrency of individual tasks. By default, existing behavior is unchanged. Closes #2624 from saguziel/aguziel-task-concurrency (cherry picked from commit cfc2f73c445074e1e09d6ef6a056cd2b33a945da) Signed-off-by: Bolke de Bruin <[email protected]> Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/87afe890 Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/87afe890 Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/87afe890 Branch: refs/heads/v1-9-test Commit: 87afe8901559d4aa8b74179e980ca63fd1dedcb5 Parents: ef775d4 Author: Alex Guziel <[email protected]> Authored: Thu Oct 5 14:37:26 2017 -0700 Committer: Bolke de Bruin <[email protected]> Committed: Tue Oct 31 19:18:54 2017 +0100 ---------------------------------------------------------------------- airflow/jobs.py | 51 +++++++-- airflow/models.py | 21 +++- airflow/ti_deps/dep_context.py | 4 +- airflow/ti_deps/deps/task_concurrency_dep.py | 37 +++++++ airflow/utils/dag_processing.py | 18 +++- tests/jobs.py | 124 +++++++++++++++++++--- tests/models.py | 56 ++++++++++ tests/ti_deps/deps/test_task_concurrency.py | 51 +++++++++ 8 files changed, 331 insertions(+), 31 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/jobs.py ---------------------------------------------------------------------- diff --git a/airflow/jobs.py b/airflow/jobs.py index f92a570..7a7e564 100644 --- a/airflow/jobs.py +++ b/airflow/jobs.py @@ -999,6 +999,30 @@ class SchedulerJob(BaseJob): ) @provide_session + def __get_task_concurrency_map(self, states, session=None): + """ + Returns a map from tasks to number in the states list given.
+ + :param states: List of states to query for + :type states: List[State] + :return: A map from (dag_id, task_id) to count of tasks in states + :rtype: Dict[[String, String], Int] + + """ + TI = models.TaskInstance + ti_concurrency_query = ( + session + .query(TI.task_id, TI.dag_id, func.count('*')) + .filter(TI.state.in_(states)) + .group_by(TI.task_id, TI.dag_id) + ).all() + task_map = defaultdict(int) + for result in ti_concurrency_query: + task_id, dag_id, count = result + task_map[(dag_id, task_id)] = count + return task_map + + @provide_session def _find_executable_task_instances(self, simple_dag_bag, states, session=None): """ Finds TIs that are ready for execution with respect to pool limits, @@ -1013,6 +1037,9 @@ class SchedulerJob(BaseJob): :type states: Tuple[State] :return: List[TaskInstance] """ + # TODO(saguziel): Change this to include QUEUED, for concurrency + # purposes we may want to count queued tasks + states_to_count_as_running = [State.RUNNING] executable_tis = [] # Get all the queued task instances from associated with scheduled @@ -1057,6 +1084,8 @@ class SchedulerJob(BaseJob): for task_instance in task_instances_to_examine: pool_to_task_instances[task_instance.pool].append(task_instance) + task_concurrency_map = self.__get_task_concurrency_map(states=states_to_count_as_running, session=session) + # Go through each pool, and queue up a task for execution if there are # any open slots in the pool. for pool, task_instances in pool_to_task_instances.items(): @@ -1094,6 +1123,7 @@ class SchedulerJob(BaseJob): # Check to make sure that the task concurrency of the DAG hasn't been # reached. 
dag_id = task_instance.dag_id + simple_dag = simple_dag_bag.get_dag(dag_id) if dag_id not in dag_id_to_possibly_running_task_count: # TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104 @@ -1101,7 +1131,7 @@ class SchedulerJob(BaseJob): DAG.get_num_task_instances( dag_id, simple_dag_bag.get_dag(dag_id).task_ids, - states=[State.RUNNING], + states=states_to_count_as_running, session=session) current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id] @@ -1118,6 +1148,16 @@ class SchedulerJob(BaseJob): ) continue + task_concurrency = simple_dag.get_task_special_arg(task_instance.task_id, 'task_concurrency') + if task_concurrency is not None: + num_running = task_concurrency_map[((task_instance.dag_id, task_instance.task_id))] + if num_running >= task_concurrency: + self.logger.info("Not executing %s since the task concurrency for this task" + " has been reached.", task_instance) + continue + else: + task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1 + if self.executor.has_task(task_instance): self.log.debug( "Not handling task %s as the executor reports it is running", @@ -1726,16 +1766,9 @@ class SchedulerJob(BaseJob): if pickle_dags: pickle_id = dag.pickle(session).id - task_ids = [task.task_id for task in dag.tasks] - # Only return DAGs that are not paused if dag_id not in paused_dag_ids: - simple_dags.append(SimpleDag(dag.dag_id, - task_ids, - dag.full_filepath, - dag.concurrency, - dag.is_paused, - pickle_id)) + simple_dags.append(SimpleDag(dag, pickle_id=pickle_id)) if len(self.dag_ids) > 0: dags = [dag for dag in dagbag.dags.values() http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/models.py ---------------------------------------------------------------------- diff --git a/airflow/models.py b/airflow/models.py index 32b7d7e..e5bf857 100755 --- a/airflow/models.py +++ b/airflow/models.py @@ -65,6 +65,7 @@ from airflow.dag.base_dag import BaseDag, BaseDagBag from 
airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS from airflow.utils.dates import cron_presets, date_range as utils_date_range @@ -1835,6 +1836,15 @@ class TaskInstance(Base, LoggingMixin): else: return pull_fn(task_id=task_ids) + @provide_session + def get_num_running_task_instances(self, session): + TI = TaskInstance + return session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.task_id == self.task_id, + TI.state == State.RUNNING + ).count() + class TaskFail(Base): """ @@ -2058,6 +2068,9 @@ class BaseOperator(LoggingMixin): :type resources: dict :param run_as_user: unix username to impersonate while running the task :type run_as_user: str + :param task_concurrency: When set, a task will be able to limit the concurrent + runs across execution_dates + :type task_concurrency: int """ # For derived classes to define which fields will get jinjaified @@ -2100,6 +2113,7 @@ class BaseOperator(LoggingMixin): trigger_rule=TriggerRule.ALL_SUCCESS, resources=None, run_as_user=None, + task_concurrency=None, *args, **kwargs): @@ -2165,6 +2179,7 @@ class BaseOperator(LoggingMixin): self.priority_weight = priority_weight self.resources = Resources(**(resources or {})) self.run_as_user = run_as_user + self.task_concurrency = task_concurrency # Private attributes self._upstream_task_ids = [] @@ -4542,8 +4557,9 @@ class DagRun(Base, LoggingMixin): session=session ) none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks) + none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks) # small speed up - if unfinished_tasks and none_depends_on_past: + if unfinished_tasks and none_depends_on_past and none_task_concurrency: # todo: this can 
actually get pretty slow: one task costs between 0.01-015s no_dependencies_met = True for ut in unfinished_tasks: @@ -4581,7 +4597,8 @@ class DagRun(Base, LoggingMixin): self.state = State.SUCCESS # if *all tasks* are deadlocked, the run failed - elif unfinished_tasks and none_depends_on_past and no_dependencies_met: + elif (unfinished_tasks and none_depends_on_past and + none_task_concurrency and no_dependencies_met): self.log.info('Deadlock; marking run %s failed', self) self.state = State.FAILED http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/ti_deps/dep_context.py ---------------------------------------------------------------------- diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py index 01e01dd..f461a81 100644 --- a/airflow/ti_deps/dep_context.py +++ b/airflow/ti_deps/dep_context.py @@ -19,6 +19,7 @@ from airflow.ti_deps.deps.not_running_dep import NotRunningDep from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep from airflow.ti_deps.deps.valid_state_dep import ValidStateDep +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep from airflow.utils.state import State @@ -97,7 +98,8 @@ QUEUE_DEPS = { # Dependencies that need to be met for a given task instance to be able to get run by an # executor. This class just extends QueueContext by adding dependencies for resources. 
RUN_DEPS = QUEUE_DEPS | { - DagTISlotsAvailableDep() + DagTISlotsAvailableDep(), + TaskConcurrencyDep(), } # TODO(aoen): SCHEDULER_DEPS is not coupled to actual execution in any way and http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/ti_deps/deps/task_concurrency_dep.py ---------------------------------------------------------------------- diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py new file mode 100644 index 0000000..99df5ac --- /dev/null +++ b/airflow/ti_deps/deps/task_concurrency_dep.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.utils.db import provide_session + + +class TaskConcurrencyDep(BaseTIDep): + """ + This restricts the number of running task instances for a particular task. 
+ """ + NAME = "Task Concurrency" + IGNOREABLE = True + IS_TASK_DEP = True + + @provide_session + def _get_dep_statuses(self, ti, session, dep_context): + if ti.task.task_concurrency is None: + yield self._passing_status(reason="Task concurrency is not set.") + return + + if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency: + yield self._failing_status(reason="The max task concurrency has been reached.") + return + else: + yield self._passing_status(reason="The max task concurrency has not been reached.") + return http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/airflow/utils/dag_processing.py ---------------------------------------------------------------------- diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py index 3a6cb98..5e92f0e 100644 --- a/airflow/utils/dag_processing.py +++ b/airflow/utils/dag_processing.py @@ -105,6 +105,16 @@ class SimpleDag(BaseDag): """ return self._pickle_id + @property + def task_special_args(self): + return self._task_special_args + + def get_task_special_arg(self, task_id, special_arg_name): + if task_id in self._task_special_args and special_arg_name in self._task_special_args[task_id]: + return self._task_special_args[task_id][special_arg_name] + else: + return None + class SimpleDagBag(BaseDagBag): """ @@ -366,7 +376,7 @@ class DagFileProcessorManager(LoggingMixin): being processed """ if file_path in self._processors: - return (datetime.utcnow() - self._processors[file_path].start_time)\ + return (datetime.utcnow() - self._processors[file_path].start_time) \ .total_seconds() return None @@ -489,8 +499,8 @@ class DagFileProcessorManager(LoggingMixin): for file_path in self._file_paths: last_finish_time = self.get_last_finish_time(file_path) if (last_finish_time is not None and - (now - last_finish_time).total_seconds() < - self._process_file_interval): + (now - last_finish_time).total_seconds() < + self._process_file_interval): 
file_paths_recently_processed.append(file_path) files_paths_at_run_limit = [file_path @@ -517,7 +527,7 @@ class DagFileProcessorManager(LoggingMixin): # Start more processors if we have enough slots and files to process while (self._parallelism - len(self._processors) > 0 and - len(self._file_path_queue) > 0): + len(self._file_path_queue) > 0): file_path = self._file_path_queue.pop(0) processor = self._processor_factory(file_path) http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/jobs.py ---------------------------------------------------------------------- diff --git a/tests/jobs.py b/tests/jobs.py index f4bbe81..e8fff7e 100644 --- a/tests/jobs.py +++ b/tests/jobs.py @@ -41,7 +41,7 @@ from airflow.utils.dates import days_ago from airflow.utils.db import provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout -from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths +from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths from mock import Mock, patch from sqlalchemy.orm.session import make_transient @@ -935,7 +935,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -965,7 +965,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -990,7 +990,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = 
self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1015,7 +1015,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1055,7 +1055,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a') task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b') - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1096,7 +1096,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1114,7 +1114,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1155,6 +1155,98 @@ class SchedulerJobTest(unittest.TestCase): self.assertEqual(0, len(res)) + def test_find_executable_task_instances_task_concurrency(self): + dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency' + task_id_1 = 'dummy' + task_id_2 = 'dummy2' + dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) + task1 = 
DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2) + task2 = DummyOperator(dag=dag, task_id=task_id_2) + dagbag = self._make_simple_dag_bag([dag]) + + scheduler = SchedulerJob(**self.default_scheduler_args) + session = settings.Session() + + dr1 = scheduler.create_dag_run(dag) + dr2 = scheduler.create_dag_run(dag) + dr3 = scheduler.create_dag_run(dag) + + ti1_1 = TI(task1, dr1.execution_date) + ti2 = TI(task2, dr1.execution_date) + + ti1_1.state = State.SCHEDULED + ti2.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(2, len(res)) + + ti1_1.state = State.RUNNING + ti2.state = State.RUNNING + ti1_2 = TI(task1, dr2.execution_date) + ti1_2.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti2) + session.merge(ti1_2) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(1, len(res)) + + ti1_2.state = State.RUNNING + ti1_3 = TI(task1, dr3.execution_date) + ti1_3.state = State.SCHEDULED + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(0, len(res)) + + ti1_1.state = State.SCHEDULED + ti1_2.state = State.SCHEDULED + ti1_3.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + session=session) + + self.assertEqual(2, len(res)) + + ti1_1.state = State.RUNNING + ti1_2.state = State.SCHEDULED + ti1_3.state = State.SCHEDULED + session.merge(ti1_1) + session.merge(ti1_2) + session.merge(ti1_3) + session.commit() + + res = scheduler._find_executable_task_instances( + dagbag, + states=[State.SCHEDULED], + 
session=session) + + self.assertEqual(1, len(res)) + def test_change_state_for_executable_task_instances_no_tis(self): scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1166,7 +1258,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1198,7 +1290,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1234,7 +1326,7 @@ class SchedulerJobTest(unittest.TestCase): task_id_1 = 'dummy' dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE) task1 = DummyOperator(dag=dag, task_id=task_id_1) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1279,7 +1371,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3) task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(**self.default_scheduler_args) session = settings.Session() @@ -1340,7 +1432,7 @@ class SchedulerJobTest(unittest.TestCase): dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16) task1 = DummyOperator(dag=dag, task_id=task_id_1) task2 = DummyOperator(dag=dag, task_id=task_id_2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = 
SchedulerJob(**self.default_scheduler_args) scheduler.max_tis_per_query = 3 @@ -1407,16 +1499,18 @@ class SchedulerJobTest(unittest.TestCase): ti2.state = State.SCHEDULED session.commit() - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) scheduler = SchedulerJob(num_runs=0, run_duration=0) scheduler._change_state_for_tis_without_dagrun(simple_dag_bag=dagbag, old_states=[State.SCHEDULED, State.QUEUED], new_state=State.NONE, session=session) + ti = dr.get_task_instance(task_id='dummy', session=session) ti.refresh_from_db(session=session) self.assertEqual(ti.state, State.SCHEDULED) + ti2 = dr2.get_task_instance(task_id='dummy', session=session) ti2.refresh_from_db(session=session) self.assertEqual(ti2.state, State.SCHEDULED) @@ -2039,7 +2133,7 @@ class SchedulerJobTest(unittest.TestCase): queue = [] scheduler._process_task_instances(dag, queue=queue) self.assertEquals(len(queue), 2) - dagbag = SimpleDagBag([dag]) + dagbag = self._make_simple_dag_bag([dag]) # Recreated part of the scheduler here, to kick off tasks -> executor for ti_key in queue: http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/models.py ---------------------------------------------------------------------- diff --git a/tests/models.py b/tests/models.py index db5beca..a1de17d 100644 --- a/tests/models.py +++ b/tests/models.py @@ -517,6 +517,39 @@ class DagRunTest(unittest.TestCase): dr.update_state() self.assertEqual(dr.state, State.FAILED) + def test_dagrun_no_deadlock(self): + session = settings.Session() + dag = DAG('test_dagrun_no_deadlock', + start_date=DEFAULT_DATE) + with dag: + op1 = DummyOperator(task_id='dop', depends_on_past=True) + op2 = DummyOperator(task_id='tc', task_concurrency=1) + + dag.clear() + dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1', + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE) + dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2', + state=State.RUNNING, + 
execution_date=DEFAULT_DATE + datetime.timedelta(days=1), + start_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti1_op1 = dr.get_task_instance(task_id='dop') + ti2_op1 = dr2.get_task_instance(task_id='dop') + ti2_op1 = dr.get_task_instance(task_id='tc') + ti2_op2 = dr.get_task_instance(task_id='tc') + ti1_op1.set_state(state=State.RUNNING, session=session) + dr.update_state() + dr2.update_state() + self.assertEqual(dr.state, State.RUNNING) + self.assertEqual(dr2.state, State.RUNNING) + + ti2_op1.set_state(state=State.RUNNING, session=session) + dr.update_state() + dr2.update_state() + self.assertEqual(dr.state, State.RUNNING) + self.assertEqual(dr2.state, State.RUNNING) + def test_get_task_instance_on_empty_dagrun(self): """ Make sure that a proper value is returned when a dagrun has no task instances @@ -1201,6 +1234,29 @@ class TaskInstanceTest(unittest.TestCase): ti = TI( task=task2, execution_date=datetime.datetime.now()) self.assertFalse(ti._check_and_change_state_before_execution()) + + def test_get_num_running_task_instances(self): + session = settings.Session() + + dag = models.DAG(dag_id='test_get_num_running_task_instances') + dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy') + task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE) + task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE) + + ti1 = TI(task=task, execution_date=DEFAULT_DATE) + ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1)) + ti3 = TI(task=task2, execution_date=DEFAULT_DATE) + ti1.state = State.RUNNING + ti2.state = State.QUEUED + ti3.state = State.RUNNING + session.add(ti1) + session.add(ti2) + session.add(ti3) + session.commit() + + self.assertEquals(1, ti1.get_num_running_task_instances(session=session)) + self.assertEquals(1, ti2.get_num_running_task_instances(session=session)) + self.assertEquals(1, ti3.get_num_running_task_instances(session=session)) class ClearTasksTest(unittest.TestCase): 
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/87afe890/tests/ti_deps/deps/test_task_concurrency.py ---------------------------------------------------------------------- diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py new file mode 100644 index 0000000..77a5990 --- /dev/null +++ b/tests/ti_deps/deps/test_task_concurrency.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime +from mock import Mock + +from airflow.models import DAG, BaseOperator +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep +from airflow.utils.state import State + + +class TaskConcurrencyDepTest(unittest.TestCase): + + def _get_task(self, **kwargs): + return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) + + def test_not_task_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1)) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + + def test_not_reached_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + ti.get_num_running_task_instances = lambda x: 0 + 
self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + + def test_reached_concurrency(self): + task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2) + dep_context = DepContext() + ti = Mock(task=task, execution_date=datetime(2016, 1, 1)) + ti.get_num_running_task_instances = lambda x: 1 + self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) + ti.get_num_running_task_instances = lambda x: 2 + self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context)) +
