ARIA-79-concurrent-modifications
Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/4a72e113 Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/4a72e113 Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/4a72e113 Branch: refs/heads/ARIA-79-concurrent-storage-modifications Commit: 4a72e113b48b2aafdf778d641182a3a8a2c60d1a Parents: b619335 Author: Dan Kilman <d...@gigaspaces.com> Authored: Mon Jan 30 16:49:00 2017 +0200 Committer: mxmrlv <mxm...@gmail.com> Committed: Thu Feb 16 15:03:40 2017 +0200 ---------------------------------------------------------------------- aria/orchestrator/workflows/executor/process.py | 156 ++++++++++------- aria/storage/instrumentation.py | 30 +++- aria/storage/sql_mapi.py | 4 + ...process_executor_concurrent_modifications.py | 174 +++++++++++++++++++ tests/requirements.txt | 3 +- 5 files changed, 307 insertions(+), 60 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/orchestrator/workflows/executor/process.py ---------------------------------------------------------------------- diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py index 560ac43..a23e3da 100644 --- a/aria/orchestrator/workflows/executor/process.py +++ b/aria/orchestrator/workflows/executor/process.py @@ -74,6 +74,13 @@ class ProcessExecutor(base.BaseExecutor): # Contains reference to all currently running tasks self._tasks = {} + self._request_handlers = { + 'started': self._handle_task_started_request, + 'succeeded': self._handle_task_succeeded_request, + 'failed': self._handle_task_failed_request, + 'apply_tracked_changes': self._handle_apply_tracked_changes_request + } + # Server socket used to accept task status messages from subprocesses self._server_socket = socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) self._server_socket.bind(('localhost', 0)) @@ -131,58 +138,6 @@ class ProcessExecutor(base.BaseExecutor): def _remove_task(self, task_id): return self._tasks.pop(task_id) - def _listener(self): - # Notify __init__ method this thread has actually started - self._listener_started.put(True) - while not self._stopped: - try: - # Accept messages written to the server socket - with contextlib.closing(self._server_socket.accept()[0]) as connection: - message = self._recv_message(connection) - message_type = message['type'] - if message_type == 'closed': - break - task_id = message['task_id'] - if message_type == 'started': - self._task_started(self._tasks[task_id]) - elif message_type == 'apply_tracked_changes': - task = self._tasks[task_id] - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - elif message_type == 'succeeded': - task = self._remove_task(task_id) - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - self._task_succeeded(task) - elif message_type == 'failed': - task = self._remove_task(task_id) - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - self._task_failed(task, exception=message['exception']) - else: - raise RuntimeError('Invalid state') - except BaseException as e: - self.logger.debug('Error in process executor listener: {0}'.format(e)) - - def _recv_message(self, connection): - message_len, = struct.unpack(_INT_FMT, self._recv_bytes(connection, _INT_SIZE)) - return jsonpickle.loads(self._recv_bytes(connection, message_len)) - - @staticmethod - def _recv_bytes(connection, count): - result = io.BytesIO() - while True: - if not count: - return result.getvalue() - read = connection.recv(count) - if not read: - return result.getvalue() - result.write(read) - count -= len(read) - def _check_closed(self): if self._stopped: raise 
RuntimeError('Executor closed') @@ -231,6 +186,87 @@ class ProcessExecutor(base.BaseExecutor): os.pathsep, env.get('PYTHONPATH', '')) + def _listener(self): + # Notify __init__ method this thread has actually started + self._listener_started.put(True) + while not self._stopped: + try: + with self._accept_request() as (request, response): + request_type = request['type'] + if request_type == 'closed': + break + request_handler = self._request_handlers.get(request_type) + if not request_handler: + raise RuntimeError('Invalid request type: {0}'.format(request_type)) + request_handler(task_id=request['task_id'], request=request, response=response) + except BaseException as e: + self.logger.debug('Error in process executor listener: {0}'.format(e)) + + @contextlib.contextmanager + def _accept_request(self): + with contextlib.closing(self._server_socket.accept()[0]) as connection: + message = _recv_message(connection) + response = {} + yield message, response + _send_message(connection, response) + + def _handle_task_started_request(self, task_id, **kwargs): + self._task_started(self._tasks[task_id]) + + def _handle_task_succeeded_request(self, task_id, request, **kwargs): + task = self._remove_task(task_id) + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + self._task_failed(task, exception=e) + else: + self._task_succeeded(task) + + def _handle_task_failed_request(self, task_id, request, **kwargs): + task = self._remove_task(task_id) + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + self._task_failed(task, exception=e) + else: + self._task_failed(task, exception=request['exception']) + + def _handle_apply_tracked_changes_request(self, task_id, request, response): + task = self._tasks[task_id] + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + response['exception'] = exceptions.wrap_if_needed(e) + + @staticmethod + def _apply_tracked_changes(task, request): + 
instrumentation.apply_tracked_changes( + tracked_changes=request['tracked_changes'], + model=task.context.model) + + +def _send_message(connection, message): + data = jsonpickle.dumps(message) + connection.send(struct.pack(_INT_FMT, len(data))) + connection.sendall(data) + + +def _recv_message(connection): + message_len, = struct.unpack(_INT_FMT, _recv_bytes(connection, _INT_SIZE)) + return jsonpickle.loads(_recv_bytes(connection, message_len)) + + +def _recv_bytes(connection, count): + result = io.BytesIO() + while True: + if not count: + return result.getvalue() + read = connection.recv(count) + if not read: + return result.getvalue() + result.write(read) + count -= len(read) + class _Messenger(object): @@ -261,17 +297,16 @@ class _Messenger(object): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', self.port)) try: - data = jsonpickle.dumps({ + _send_message(sock, { 'type': type, 'task_id': self.task_id, 'exception': exceptions.wrap_if_needed(exception), 'tracked_changes': tracked_changes }) - sock.send(struct.pack(_INT_FMT, len(data))) - sock.sendall(data) - # send message will block until the server side closes the connection socket - # because we want it to be synchronous - sock.recv(1) + response = _recv_message(sock) + response_exception = response.get('exception') + if response_exception: + raise response_exception finally: sock.close() @@ -294,12 +329,17 @@ def _patch_session(ctx, messenger, instrument): messenger.apply_tracked_changes(instrument.tracked_changes) instrument.clear() + def patched_rollback(): + # Rollback is performed on parent process when commit fails + pass + # when autoflush is set to true (the default), refreshing an object will trigger # an auto flush by sqlalchemy, this autoflush will attempt to commit changes made so # far on the session. 
this is not the desired behavior in the subprocess session.autoflush = False session.commit = patched_commit + session.rollback = patched_rollback session.refresh = patched_refresh http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/storage/instrumentation.py ---------------------------------------------------------------------- diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py index 57fe9bd..41818b6 100644 --- a/aria/storage/instrumentation.py +++ b/aria/storage/instrumentation.py @@ -15,10 +15,14 @@ import copy +import sqlalchemy import sqlalchemy.event +from . import exceptions + from .modeling import model as _model +_VERSION_ID_COL = 'version_id' _STUB = object() _INSTRUMENTED = { _model.Node.runtime_properties: dict @@ -92,6 +96,11 @@ class _Instrumentation(object): mapi_name = instrumented_class.__modelname__ tracked_instances = self.tracked_changes.setdefault(mapi_name, {}) tracked_attributes = tracked_instances.setdefault(target.id, {}) + if hasattr(target, _VERSION_ID_COL): + # We want to keep track of the initial version id so it can be compared + # with the committed version id when the tracked changes are applied + tracked_attributes.setdefault(_VERSION_ID_COL, + _Value(_STUB, getattr(target, _VERSION_ID_COL))) for attribute_name, attribute_type in instrumented_attributes.items(): if attribute_name not in tracked_attributes: initial = getattr(target, attribute_name) @@ -143,7 +152,7 @@ class _Value(object): return self.initial == other.initial and self.current == other.current def __hash__(self): - return hash(self.initial) ^ hash(self.current) + return hash((self.initial, self.current)) def apply_tracked_changes(tracked_changes, model): @@ -163,4 +172,23 @@ def apply_tracked_changes(tracked_changes, model): instance = mapi.get(instance_id) setattr(instance, attribute_name, value.current) if instance: + _validate_version_id(instance, mapi) mapi.update(instance) + + +def 
_validate_version_id(instance, mapi): + version_id = sqlalchemy.inspect(instance).committed_state.get(_VERSION_ID_COL) + # There are two version conflict code paths: + # 1. The instance committed state loaded already holds a newer version, + # in this case, we manually raise the error + # 2. The UPDATE statement is executed with version validation and sqlalchemy + # will raise a StaleDataError if there is a version mismatch. + if version_id and getattr(instance, _VERSION_ID_COL) != version_id: + object_version_id = getattr(instance, _VERSION_ID_COL) + mapi._session.rollback() + raise exceptions.StorageError( + 'Version conflict: committed and object {0} differ ' + '[committed {0}={1}, object {0}={2}]' + .format(_VERSION_ID_COL, + version_id, + object_version_id)) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/storage/sql_mapi.py ---------------------------------------------------------------------- diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py index b80ac8e..2711aee 100644 --- a/aria/storage/sql_mapi.py +++ b/aria/storage/sql_mapi.py @@ -23,6 +23,7 @@ from sqlalchemy import ( orm, ) from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm.exc import StaleDataError from aria.utils.collections import OrderedDict from .
import ( @@ -162,6 +163,9 @@ class SQLAlchemyModelAPI(api.ModelAPI): """ try: self._session.commit() + except StaleDataError as e: + self._session.rollback() + raise exceptions.StorageError('Version conflict: {0}'.format(str(e))) except (SQLAlchemyError, ValueError) as e: self._session.rollback() raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py new file mode 100644 index 0000000..7c54bc5 --- /dev/null +++ b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py @@ -0,0 +1,174 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +import json + +import fasteners +import pytest + +from aria.storage.exceptions import StorageError +from aria.orchestrator import events +from aria.orchestrator.workflows.exceptions import ExecutorException +from aria.orchestrator.workflows import api +from aria.orchestrator.workflows.executor import process +from aria.orchestrator import workflow, operation + +import tests +from tests.orchestrator.context import execute as execute_workflow +from tests.orchestrator.workflows.helpers import events_collector +from tests import mock +from tests import storage + + +def test_concurrent_modification_on_task_succeeded(context, executor, shared_file): + _test(context, executor, shared_file, _test_task_succeeded, expected_failure=True) + + +@operation +def _test_task_succeeded(ctx, shared_file, key, first_value, second_value): + _concurrent_update(shared_file, ctx.node_instance, key, first_value, second_value) + + +def test_concurrent_modification_on_task_failed(context, executor, shared_file): + _test(context, executor, shared_file, _test_task_failed, expected_failure=True) + + +@operation +def _test_task_failed(ctx, shared_file, key, first_value, second_value): + first = _concurrent_update(shared_file, ctx.node_instance, key, first_value, second_value) + if not first: + raise RuntimeError('MESSAGE') + + +def test_concurrent_modification_on_update_and_refresh(context, executor, shared_file): + _test(context, executor, shared_file, _test_update_and_refresh, expected_failure=False) + + +@operation +def _test_update_and_refresh(ctx, shared_file, key, first_value, second_value): + node_instance = ctx.node_instance + first = _concurrent_update(shared_file, node_instance, key, first_value, second_value) + if not first: + try: + ctx.model.node_instance.update(node_instance) + except StorageError as e: + assert 'Version conflict' in str(e) + ctx.model.node_instance.refresh(node_instance) + else: + raise RuntimeError('Unexpected') + + +def _test(context, executor, 
shared_file, func, expected_failure): + def _node_instance(ctx): + return ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + + shared_file.write(json.dumps({})) + key = 'key' + first_value = 'value1' + second_value = 'value2' + inputs = { + 'shared_file': str(shared_file), + 'key': key, + 'first_value': first_value, + 'second_value': second_value + } + + @workflow + def mock_workflow(ctx, graph): + op = 'test.op' + node_instance = _node_instance(ctx) + node_instance.node.operations[op] = {'operation': '{0}.{1}'.format(__name__, func.__name__)} + graph.add_tasks( + api.task.OperationTask.node_instance(instance=node_instance, name=op, inputs=inputs), + api.task.OperationTask.node_instance(instance=node_instance, name=op, inputs=inputs)) + + signal = events.on_failure_task_signal + with events_collector(signal) as collected: + try: + execute_workflow(mock_workflow, context, executor) + except ExecutorException: + pass + + props = _node_instance(context).runtime_properties + assert props[key] == first_value + + exceptions = [event['kwargs']['exception'] for event in collected.get(signal, [])] + if expected_failure: + assert exceptions + exception = exceptions[-1] + assert isinstance(exception, StorageError) + assert 'Version conflict' in str(exception) + else: + assert not exceptions + + +@pytest.fixture +def executor(): + result = process.ProcessExecutor(python_path=[tests.ROOT_DIR]) + yield result + result.close() + + +@pytest.fixture +def context(tmpdir): + result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield result + storage.release_sqlite_storage(result.model) + + +@pytest.fixture +def shared_file(tmpdir): + return tmpdir.join('shared_file') + + +def _concurrent_update(shared_file, node_instance, key, first_value, second_value): + def lock(): + return fasteners.InterProcessLock(shared_file) + + def get(key): + with open(shared_file) as f: + return json.load(f).get(key) + + def set(key): + with 
open(shared_file) as f: + content = json.load(f) + content[key] = True + with open(shared_file, 'wb') as f: + json.dump(content, f) + + def wait_for(key): + while True: + time.sleep(0.01) + with lock(): + if get(key): + break + + with lock(): + first = not get('first_in') + set('first_in' if first else 'second_in') + + if first: + wait_for('second_in') + + node_instance.runtime_properties[key] = first_value if first else second_value + + if first: + with lock(): + set('first_out') + else: + wait_for('first_out') + + return first http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/tests/requirements.txt ---------------------------------------------------------------------- diff --git a/tests/requirements.txt b/tests/requirements.txt index 0e4740f..2f0245a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -11,8 +11,9 @@ # limitations under the License. testtools +fasteners==0.13.0 mock==1.0.1 pylint==1.6.4 pytest==3.0.2 pytest-cov==2.3.1 -pytest-mock==1.2 \ No newline at end of file +pytest-mock==1.2