Repository: incubator-ariatosca Updated Branches: refs/heads/wf-executor 7334cc32b -> 73f327df7 (forced update)
Add basic executor mechanism Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/73f327df Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/73f327df Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/73f327df Branch: refs/heads/wf-executor Commit: 73f327df74d09f3817370b02b71a93d17619511c Parents: f2f4131 Author: Dan Kilman <dankil...@gmail.com> Authored: Thu Oct 13 15:44:58 2016 +0300 Committer: Dan Kilman <dankil...@gmail.com> Committed: Tue Oct 18 21:30:57 2016 +0300 ---------------------------------------------------------------------- aria/events/__init__.py | 12 +- aria/events/builtin_event_handler.py | 84 +++++++++++ aria/events/builtin_event_handlers.py | 44 ------ aria/events/workflow_engine_event_handler.py | 70 ++++----- aria/storage/models.py | 7 +- aria/tools/module.py | 29 ++++ aria/workflows/engine/engine.py | 132 +++-------------- aria/workflows/engine/executor.py | 165 +++++++++++++++++----- tests/workflows/test_executor.py | 118 ++++++++++++++++ 9 files changed, 413 insertions(+), 248 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/events/__init__.py ---------------------------------------------------------------------- diff --git a/aria/events/__init__.py b/aria/events/__init__.py index 70e7e03..c9d7b20 100644 --- a/aria/events/__init__.py +++ b/aria/events/__init__.py @@ -20,25 +20,15 @@ from blinker import signal from ..tools.plugin import plugin_installer -# workflow engine default signals: +# workflow engine task signals: start_task_signal = signal('start_task_signal') -end_task_signal = signal('end_task_signal') on_success_task_signal = signal('success_task_signal') on_failure_task_signal = signal('failure_task_signal') # workflow engine workflow signals: start_workflow_signal = 
signal('start_workflow_signal') -end_workflow_signal = signal('end_workflow_signal') on_success_workflow_signal = signal('on_success_workflow_signal') on_failure_workflow_signal = signal('on_failure_workflow_signal') -start_sub_workflow_signal = signal('start_sub_workflow_signal') -end_sub_workflow_signal = signal('end_sub_workflow_signal') - -# workflow engine operation signals: -start_operation_signal = signal('start_operation_signal') -end_operation_signal = signal('end_operation_signal') -on_success_operation_signal = signal('on_success_operation_signal') -on_failure_operation_signal = signal('on_failure_operation_signal') plugin_installer( path=os.path.dirname(os.path.realpath(__file__)), http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/events/builtin_event_handler.py ---------------------------------------------------------------------- diff --git a/aria/events/builtin_event_handler.py b/aria/events/builtin_event_handler.py new file mode 100644 index 0000000..404cc01 --- /dev/null +++ b/aria/events/builtin_event_handler.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime + +from . 
import ( + start_workflow_signal, + on_success_workflow_signal, + on_failure_workflow_signal, + start_task_signal, + on_success_task_signal, + on_failure_task_signal, +) + + +@start_task_signal.connect +def _task_started(task, *args, **kwargs): + operation_context = task.context + operation = operation_context.operation + operation.started_at = datetime.utcnow() + operation.status = operation.STARTED + operation_context.operation = operation + + +@on_failure_task_signal.connect +def _task_failed(task, *args, **kwargs): + operation_context = task.context + operation = operation_context.operation + operation.ended_at = datetime.utcnow() + operation.status = operation.FAILED + operation_context.operation = operation + + +@on_success_task_signal.connect +def _task_succeeded(task, *args, **kwargs): + operation_context = task.context + operation = operation_context.operation + operation.ended_at = datetime.utcnow() + operation.status = operation.SUCCESS + operation_context.operation = operation + + +@start_workflow_signal.connect +def _workflow_started(workflow_context, *args, **kwargs): + Execution = workflow_context.storage.execution.model_cls + execution = Execution( + id=workflow_context.execution_id, + deployment_id=workflow_context.deployment_id, + workflow_id=workflow_context.workflow_id, + blueprint_id=workflow_context.blueprint_id, + status=Execution.PENDING, + started_at=datetime.utcnow(), + parameters=workflow_context.parameters, + ) + workflow_context.execution = execution + + +@on_failure_workflow_signal.connect +def _workflow_failed(workflow_context, exception, *args, **kwargs): + execution = workflow_context.execution + execution.error = str(exception) + execution.status = execution.FAILED + execution.ended_at = datetime.utcnow() + workflow_context.execution = execution + + +@on_success_workflow_signal.connect +def _workflow_succeeded(workflow_context, *args, **kwargs): + execution = workflow_context.execution + execution.status = execution.TERMINATED +
execution.ended_at = datetime.utcnow() + workflow_context.execution = execution http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/events/builtin_event_handlers.py ---------------------------------------------------------------------- diff --git a/aria/events/builtin_event_handlers.py b/aria/events/builtin_event_handlers.py deleted file mode 100644 index 59f59c1..0000000 --- a/aria/events/builtin_event_handlers.py +++ /dev/null @@ -1,44 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from ..storage.models import NodeInstance -from .
import start_operation_signal - - -class _OperationToNodeInstanceState(dict): - def __missing__(self, key): - for cached_key, value in self.items(): - if key.startswith(cached_key): - return value - raise KeyError(key) - -_operation_to_node_instance_state = _OperationToNodeInstanceState({ - 'cloudify.interfaces.lifecycle.create': NodeInstance.INITIALIZING, - 'cloudify.interfaces.lifecycle.configure': NodeInstance.CONFIGURING, - 'cloudify.interfaces.lifecycle.start': NodeInstance.STARTING, - 'cloudify.interfaces.lifecycle.stop': NodeInstance.STOPPING, - 'cloudify.interfaces.lifecycle.delete': NodeInstance.DELETING -}) - - -@start_operation_signal.connect -def _update_node_instance_state(sender, **kwargs): - try: - next_state = _operation_to_node_instance_state[sender.task_name] - except KeyError: - return - node_instance = sender.context.node_instance - node_instance.state = next_state - sender.context.storage.node_instance.store(node_instance) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/events/workflow_engine_event_handler.py ---------------------------------------------------------------------- diff --git a/aria/events/workflow_engine_event_handler.py b/aria/events/workflow_engine_event_handler.py index 59bed99..6916206 100644 --- a/aria/events/workflow_engine_event_handler.py +++ b/aria/events/workflow_engine_event_handler.py @@ -14,61 +14,47 @@ # limitations under the License. from . 
import ( - start_operation_signal, - end_operation_signal, - on_success_operation_signal, - on_failure_operation_signal, + start_task_signal, + on_success_task_signal, + on_failure_task_signal, start_workflow_signal, - end_workflow_signal, - start_sub_workflow_signal, - end_sub_workflow_signal, + on_success_workflow_signal, + on_failure_workflow_signal ) -@start_operation_signal.connect -def start_operation_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - starting operation: {sender.task_name}'.format(sender=sender)) +@start_task_signal.connect +def start_task_handler(task, **kwargs): + task.logger.debug( + 'Event: Starting task: {task.name}'.format(task=task)) -@end_operation_signal.connect -def end_operation_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - finished operation: {sender.task_name}'.format(sender=sender)) +@on_success_task_signal.connect +def success_task_handler(task, **kwargs): + task.logger.debug( + 'Event: Task success: {task.name}'.format(task=task)) -@on_success_operation_signal.connect -def success_operation_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - operation success: {sender.task_name}'.format(sender=sender)) - - -@on_failure_operation_signal.connect -def failure_operation_handler(sender, **kwargs): - sender.context.logger.error( - 'Event - operation failure: {sender.task_name}'.format(sender=sender), +@on_failure_task_signal.connect +def failure_operation_handler(task, **kwargs): + task.logger.error( + 'Event: Task failure: {task.name}'.format(task=task), exc_info=kwargs.get('exception', True)) @start_workflow_signal.connect -def start_workflow_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - starting workflow: {sender.task_name}'.format(sender=sender)) - - -@end_workflow_signal.connect -def end_workflow_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - finished workflow: {sender.task_name}'.format(sender=sender)) +def 
start_workflow_handler(context, **kwargs): + context.logger.debug( + 'Event: Starting workflow: {context.name}'.format(context=context)) -@start_sub_workflow_signal.connect -def start_sub_workflow_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - starting sub workflow: {sender.task_name}'.format(sender=sender)) +@on_failure_workflow_signal.connect +def failure_workflow_handler(context, **kwargs): + context.logger.error( + 'Event: Workflow failure: {context.name}'.format(context=context)) -@end_sub_workflow_signal.connect -def end_sub_workflow_handler(sender, **kwargs): - sender.context.logger.debug( - 'Event - finished sub workflow: {sender.task_name}'.format(sender=sender)) +@on_success_workflow_signal.connect +def success_workflow_handler(context, **kwargs): + context.logger.debug( + 'Event: Workflow success: {context.name}'.format(context=context)) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/storage/models.py ---------------------------------------------------------------------- diff --git a/aria/storage/models.py b/aria/storage/models.py index d3cb3f7..d96c74a 100644 --- a/aria/storage/models.py +++ b/aria/storage/models.py @@ -191,10 +191,11 @@ class Execution(Model): deployment_id = Field(type=basestring) workflow_id = Field(type=basestring) blueprint_id = Field(type=basestring) - created_at = Field(type=datetime) - error = Field() + started_at = Field(type=datetime) + ended_at = Field(type=datetime, default=None) + error = Field(type=basestring, default=None) parameters = Field() - is_system_workflow = Field(type=bool) + is_system_workflow = Field(type=bool, default=False) class Operation(Model): http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/tools/module.py ---------------------------------------------------------------------- diff --git a/aria/tools/module.py b/aria/tools/module.py new file mode 100644 index 0000000..535f7aa --- /dev/null +++ b/aria/tools/module.py @@ -0,0
+1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib + + +def load_attribute(attribute_path): + module_name, attribute_name = attribute_path.rsplit('.', 1) + try: + module = importlib.import_module(module_name) + return getattr(module, attribute_name) + except ImportError: + # TODO: handle + raise + except AttributeError: + # TODO: handle + raise http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/workflows/engine/engine.py ---------------------------------------------------------------------- diff --git a/aria/workflows/engine/engine.py b/aria/workflows/engine/engine.py index 508ae3b..7cc4781 100644 --- a/aria/workflows/engine/engine.py +++ b/aria/workflows/engine/engine.py @@ -14,57 +14,42 @@ # limitations under the License. import time -from datetime import datetime -from contextlib import contextmanager -from networkx import DiGraph +import networkx -from aria.events import ( - start_workflow_signal, - on_success_workflow_signal, - on_failure_workflow_signal, - start_task_signal, - on_success_task_signal, - on_failure_task_signal, -) -from aria.logger import LoggerMixin +from aria import events, logger -from .translation import build_execution_graph +from . 
import translation -from ...storage import Model - - -class Engine(LoggerMixin): +class Engine(logger.LoggerMixin): def __init__(self, executor, workflow_context, tasks_graph, **kwargs): super(Engine, self).__init__(**kwargs) self._workflow_context = workflow_context self._tasks_graph = tasks_graph - self._execution_graph = DiGraph() + self._execution_graph = networkx.DiGraph() self._executor = executor - build_execution_graph(task_graph=self._tasks_graph, - workflow_context=workflow_context, - execution_graph=self._execution_graph) + translation.build_execution_graph(task_graph=self._tasks_graph, + workflow_context=workflow_context, + execution_graph=self._execution_graph) def execute(self): - execution_id = self._workflow_context.execution_id - with self._connect_signals(): - try: - start_workflow_signal.send(self, execution_id=execution_id) - while True: - for task in self._ended_tasks(): - self._handle_ended_tasks(task) - for task in self._executable_tasks(): - self._handle_executable_task(task) - if self._all_tasks_consumed(): - break - else: - time.sleep(0.1) - on_success_workflow_signal.send(self, execution_id=execution_id) - except BaseException as e: - on_failure_workflow_signal.send(self, execution_id=execution_id, exception=e) - raise + try: + events.start_workflow_signal.send(self._workflow_context) + while True: + for task in self._ended_tasks(): + self._handle_ended_tasks(task) + for task in self._executable_tasks(): + self._handle_executable_task(task) + if self._all_tasks_consumed(): + break + else: + time.sleep(0.1) + events.on_success_workflow_signal.send(self._workflow_context) + except BaseException as e: + events.on_failure_workflow_signal.send(self._workflow_context, exception=e) + raise def _executable_tasks(self): now = time.time() @@ -85,9 +70,6 @@ class Engine(LoggerMixin): def _tasks_iter(self): return (data['task'] for _, data in self._execution_graph.nodes_iter(data=True)) - def _get_task(self, task_id): - return 
self._execution_graph.node[task_id]['task'] - def _handle_executable_task(self, task): self._executor.execute(task) @@ -96,71 +78,3 @@ class Engine(LoggerMixin): raise RuntimeError('Workflow failed') else: self._execution_graph.remove_node(task.id) - - def _task_started_receiver(self, task_id, *args, **kwargs): - task = self._get_task(task_id) - operation_context = task.operation_context - operation = operation_context.operation - operation.started_at = datetime.utcnow() - operation.status = operation.STARTED - operation_context.operation = operation - - def _task_failed_receiver(self, task_id, *args, **kwargs): - task = self._get_task(task_id) - operation_context = task.operation_context - operation = operation_context.operation - operation.ended_at = datetime.utcnow() - operation.status = operation.FAILED - operation_context.operation = operation - - def _task_succeeded_receiver(self, task_id, *args, **kwargs): - task = self._get_task(task_id) - operation_context = task.operation_context - operation = operation_context.operation - operation.ended_at = datetime.utcnow() - operation.status = operation.SUCCESS - operation_context.operation = operation - - def _start_workflow_receiver(self, *args, **kwargs): - Execution = self._workflow_context.storage.execution.model_cls - execution = Execution( - id=self._workflow_context.execution_id, - deployment_id=self._workflow_context.deployment_id, - workflow_id=self._workflow_context.workflow_id, - blueprint_id=self._workflow_context.blueprint_id, - status=Execution.PENDING, - created_at=datetime.utcnow(), - error='', - parameters=self._workflow_context.parameters, - is_system_workflow=False - ) - self._workflow_context.execution = execution - - def _workflow_failed_receiver(self, exception, *args, **kwargs): - execution = self._workflow_context.execution - execution.error = str(exception) - execution.status = execution.FAILED - self._workflow_context.execution = execution - - def _workflow_succeeded_receiver(self, *args, 
**kwargs): - execution = self._workflow_context.execution - execution.status = execution.TERMINATED - self._workflow_context.execution = execution - - @contextmanager - def _connect_signals(self): - start_workflow_signal.connect(self._start_workflow_receiver) - on_success_workflow_signal.connect(self._workflow_succeeded_receiver) - on_failure_workflow_signal.connect(self._workflow_failed_receiver) - start_task_signal.connect(self._task_started_receiver) - on_success_task_signal.connect(self._task_succeeded_receiver) - on_failure_task_signal.connect(self._task_failed_receiver) - try: - yield - finally: - start_workflow_signal.disconnect(self._start_workflow_receiver) - on_success_workflow_signal.disconnect(self._workflow_succeeded_receiver) - on_failure_workflow_signal.disconnect(self._workflow_failed_receiver) - start_task_signal.disconnect(self._task_started_receiver) - on_success_task_signal.disconnect(self._task_succeeded_receiver) - on_failure_task_signal.disconnect(self._task_failed_receiver) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/aria/workflows/engine/executor.py ---------------------------------------------------------------------- diff --git a/aria/workflows/engine/executor.py b/aria/workflows/engine/executor.py index dacfc15..4a70920 100644 --- a/aria/workflows/engine/executor.py +++ b/aria/workflows/engine/executor.py @@ -14,74 +14,161 @@ # limitations under the License. 
import threading +import multiprocessing import Queue -from importlib import import_module -from aria.events import ( - start_task_signal, - on_success_task_signal, - on_failure_task_signal, -) +import jsonpickle + +from aria import events +from aria.tools import module class Executor(object): + def __init__(self, *args, **kwargs): + pass + def execute(self, task): raise NotImplementedError - def task_started(self, task_id): - start_task_signal.send(self, task_id=task_id) + def close(self): + pass + + @staticmethod + def _task_started(task): + events.start_task_signal.send(task) + + @staticmethod + def _task_failed(task, exception): + events.on_failure_task_signal.send(task, exception=exception) + + @staticmethod + def _task_succeeded(task): + events.on_success_task_signal.send(task) + - def task_failed(self, task_id, exception): - on_failure_task_signal.send(self, task_id=task_id, exception=exception) +class CurrentThreadBlockingExecutor(Executor): - def task_succeeded(self, task_id): - on_success_task_signal.send(self, task_id=task_id) + def execute(self, task): + self._task_started(task) + try: + operation_context = task.context + task_func = module.load_attribute(operation_context.operation_details['operation']) + task_func(**operation_context.inputs) + self._task_succeeded(task) + except BaseException as e: + self._task_failed(task, exception=e) -class LocalThreadExecutor(Executor): +class ThreadExecutor(Executor): - def __init__(self, pool_size=1): - self.stopped = False - self.queue = Queue.Queue() - self.pool = [] + def __init__(self, pool_size=1, *args, **kwargs): + super(ThreadExecutor, self).__init__(*args, **kwargs) + self._stopped = False + self._queue = Queue.Queue() + self._pool = [] for i in range(pool_size): - name = 'LocalThreadExecutor-{index}'.format(index=i+1) + name = 'ThreadExecutor-{index}'.format(index=i+1) thread = threading.Thread(target=self._processor, name=name) thread.daemon = True thread.start() - self.pool.append(thread) + 
self._pool.append(thread) def execute(self, task): - self.queue.put(task) + self._queue.put(task) def close(self): - self.stopped = True + self._stopped = True + for thread in self._pool: + thread.join() def _processor(self): - while not self.stopped: + while not self._stopped: try: - task = self.queue.get(timeout=1) - self.task_started(task.id) + task = self._queue.get(timeout=1) + self._task_started(task) try: - operation_context = task.operation_context - task_func = self._load_task(operation_context.operation_details['operation']) + operation_context = task.context + task_func = module.load_attribute( + operation_context.operation_details['operation']) task_func(**operation_context.inputs) - self.task_succeeded(task.id) + self._task_succeeded(task) except BaseException as e: - self.task_failed(task.id, exception=e) + self._task_failed(task, exception=e) # Daemon threads except: pass - def _load_task(self, handler_path): - module_name, spec_handler_name = handler_path.rsplit('.', 1) - try: - module = import_module(module_name) - return getattr(module, spec_handler_name) - except ImportError: - # TODO: handle - raise - except AttributeError: - # TODO: handle - raise + +class MultiprocessExecutor(Executor): + + def __init__(self, pool_size=1, *args, **kwargs): + super(MultiprocessExecutor, self).__init__(*args, **kwargs) + self._stopped = False + self._manager = multiprocessing.Manager() + self._queue = self._manager.Queue() + self._tasks = {} + self._listener_thread = threading.Thread(target=self._listener) + self._listener_thread.daemon = True + self._listener_thread.start() + self._pool = multiprocessing.Pool(processes=pool_size, + maxtasksperchild=1) + + def execute(self, task): + self._tasks[task.id] = task + self._pool.apply_async(_multiprocess_handler, args=( + self._queue, + task.id, + task.context.operation_details, + task.context.inputs)) + + def close(self): + self._pool.close() + self._pool.join() + self._stopped = True + self._listener_thread.join() + + def _listener(self): +
while not self._stopped or not self._queue.empty(): + try: + message = self._queue.get(timeout=1) + if message.type == 'task_started': + self._task_started(self._tasks[message.task_id]) + elif message.type == 'task_succeeded': + self._task_succeeded(self._remove_task(message.task_id)) + elif message.type == 'task_failed': + self._task_failed(self._remove_task(message.task_id), + exception=jsonpickle.loads(message.exception)) + else: + # TODO: something + raise RuntimeError() + # Daemon threads + except: + pass + + def _remove_task(self, task_id): + return self._tasks.pop(task_id) + + +class _MultiprocessMessage(object): + + def __init__(self, type, task_id, exception): + self.type = type + self.task_id = task_id + self.exception = exception + + +def _multiprocess_handler(queue, task_id, operation_details, operation_inputs): + queue.put(_MultiprocessMessage(type='task_started', + task_id=task_id, + exception=None)) + try: + task_func = module.load_attribute(operation_details['operation']) + task_func(**operation_inputs) + queue.put(_MultiprocessMessage(type='task_succeeded', + task_id=task_id, + exception=None)) + except BaseException as e: + queue.put(_MultiprocessMessage(type='task_failed', + task_id=task_id, + exception=jsonpickle.dumps(e))) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/73f327df/tests/workflows/test_executor.py ---------------------------------------------------------------------- diff --git a/tests/workflows/test_executor.py b/tests/workflows/test_executor.py new file mode 100644 index 0000000..7419f2f --- /dev/null +++ b/tests/workflows/test_executor.py @@ -0,0 +1,118 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import uuid + +import pytest +import retrying + +from aria import events +from aria.storage import models +from aria.workflows.engine import executor + + +class TestExecutor(object): + + @pytest.mark.parametrize('pool_size,executor_cls', [ + (1, executor.ThreadExecutor), + (2, executor.ThreadExecutor), + (1, executor.MultiprocessExecutor), + (2, executor.MultiprocessExecutor), + (0, executor.CurrentThreadBlockingExecutor) + ]) + def test_execute(self, pool_size, executor_cls): + self.executor = executor_cls(pool_size) + expected_value = 'value' + successful_task = MockTask(mock_successful_task) + failing_task = MockTask(mock_failing_task) + task_with_inputs = MockTask(mock_task_with_input, inputs={'input': expected_value}) + + for task in [successful_task, failing_task, task_with_inputs]: + self.executor.execute(task) + + @retrying.retry(stop_max_delay=10000, wait_fixed=100) + def assertion(): + assert successful_task.states == ['start', 'success'] + assert failing_task.states == ['start', 'failure'] + assert task_with_inputs.states == ['start', 'failure'] + assert isinstance(failing_task.exception, TestException) + assert isinstance(task_with_inputs.exception, TestException) + assert task_with_inputs.exception.message == expected_value + assertion() + + def setup_method(self): + self.executor = None + events.start_task_signal.connect(start_handler) + events.on_success_task_signal.connect(success_handler) + events.on_failure_task_signal.connect(failure_handler) + + def teardown_method(self): + 
events.start_task_signal.disconnect(start_handler) + events.on_success_task_signal.disconnect(success_handler) + events.on_failure_task_signal.disconnect(failure_handler) + if self.executor: + self.executor.close() + + +def mock_successful_task(): + pass + + +def mock_failing_task(): + raise TestException + + +def mock_task_with_input(input): + raise TestException(input) + + +class TestException(Exception): + pass + + +class MockContext(object): + + def __init__(self, operation_details, inputs): + self.operation_details = operation_details + self.inputs = inputs + self.operation = models.Operation(execution_id='') + + +class MockTask(object): + + def __init__(self, func, inputs=None): + self.states = [] + self.exception = None + self.id = str(uuid.uuid4()) + name = func.__name__ + operation = 'tests.workflows.test_executor.{name}'.format(name=name) + self.context = MockContext(operation_details={'operation': operation}, + inputs=inputs or {}) + self.logger = logging.getLogger() + self.name = name + + +def start_handler(task, *args, **kwargs): + task.states.append('start') + + +def success_handler(task, *args, **kwargs): + task.states.append('success') + + +def failure_handler(task, exception, *args, **kwargs): + task.states.append('failure') + task.exception = exception