Repository: incubator-ariatosca Updated Branches: refs/heads/ARIA-79-concurrent-storage-modifications 67a42409b -> ba5247a39 (forced update)
ARIA-79-concurrent-modifications Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/ba5247a3 Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/ba5247a3 Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/ba5247a3 Branch: refs/heads/ARIA-79-concurrent-storage-modifications Commit: ba5247a398a11fa0556ad9138a9179d0b8b9c713 Parents: 9e62fca Author: Dan Kilman <[email protected]> Authored: Mon Jan 30 16:49:00 2017 +0200 Committer: Dan Kilman <[email protected]> Committed: Wed Feb 1 16:00:15 2017 +0200 ---------------------------------------------------------------------- aria/orchestrator/workflows/executor/process.py | 156 +++++++++------ aria/storage/base_model.py | 3 + aria/storage/instrumentation.py | 30 ++- aria/storage/sql_mapi.py | 4 + ...process_executor_concurrent_modifications.py | 196 +++++++++++++++++++ tests/requirements.txt | 3 +- tests/storage/__init__.py | 5 +- 7 files changed, 336 insertions(+), 61 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/aria/orchestrator/workflows/executor/process.py ---------------------------------------------------------------------- diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py index 7d990fa..319982e 100644 --- a/aria/orchestrator/workflows/executor/process.py +++ b/aria/orchestrator/workflows/executor/process.py @@ -74,6 +74,13 @@ class ProcessExecutor(base.BaseExecutor): # Contains reference to all currently running tasks self._tasks = {} + self._request_handlers = { + 'started': self._handle_task_started_request, + 'succeeded': self._handle_task_succeeded_request, + 'failed': self._handle_task_failed_request, + 'apply_tracked_changes': self._handle_apply_tracked_changes_request + } + # Server socket used to accept task 
status messages from subprocesses self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._server_socket.bind(('localhost', 0)) @@ -131,58 +138,6 @@ class ProcessExecutor(base.BaseExecutor): def _remove_task(self, task_id): return self._tasks.pop(task_id) - def _listener(self): - # Notify __init__ method this thread has actually started - self._listener_started.put(True) - while not self._stopped: - try: - # Accept messages written to the server socket - with contextlib.closing(self._server_socket.accept()[0]) as connection: - message = self._recv_message(connection) - message_type = message['type'] - if message_type == 'closed': - break - task_id = message['task_id'] - if message_type == 'started': - self._task_started(self._tasks[task_id]) - elif message_type == 'apply_tracked_changes': - task = self._tasks[task_id] - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - elif message_type == 'succeeded': - task = self._remove_task(task_id) - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - self._task_succeeded(task) - elif message_type == 'failed': - task = self._remove_task(task_id) - instrumentation.apply_tracked_changes( - tracked_changes=message['tracked_changes'], - model=task.context.model) - self._task_failed(task, exception=message['exception']) - else: - raise RuntimeError('Invalid state') - except BaseException as e: - self.logger.debug('Error in process executor listener: {0}'.format(e)) - - def _recv_message(self, connection): - message_len, = struct.unpack(_INT_FMT, self._recv_bytes(connection, _INT_SIZE)) - return jsonpickle.loads(self._recv_bytes(connection, message_len)) - - @staticmethod - def _recv_bytes(connection, count): - result = io.BytesIO() - while True: - if not count: - return result.getvalue() - read = connection.recv(count) - if not read: - return result.getvalue() - result.write(read) - 
count -= len(read) - def _check_closed(self): if self._stopped: raise RuntimeError('Executor closed') @@ -231,6 +186,87 @@ class ProcessExecutor(base.BaseExecutor): os.pathsep, env.get('PYTHONPATH', '')) + def _listener(self): + # Notify __init__ method this thread has actually started + self._listener_started.put(True) + while not self._stopped: + try: + with self._accept_request() as (request, response): + request_type = request['type'] + if request_type == 'closed': + break + request_handler = self._request_handlers.get(request_type) + if not request_handler: + raise RuntimeError('Invalid request type: {0}'.format(request_type)) + request_handler(task_id=request['task_id'], request=request, response=response) + except BaseException as e: + self.logger.debug('Error in process executor listener: {0}'.format(e)) + + @contextlib.contextmanager + def _accept_request(self): + with contextlib.closing(self._server_socket.accept()[0]) as connection: + message = _recv_message(connection) + response = {} + yield message, response + _send_message(connection, response) + + def _handle_task_started_request(self, task_id, **kwargs): + self._task_started(self._tasks[task_id]) + + def _handle_task_succeeded_request(self, task_id, request, **kwargs): + task = self._remove_task(task_id) + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + self._task_failed(task, exception=e) + else: + self._task_succeeded(task) + + def _handle_task_failed_request(self, task_id, request, **kwargs): + task = self._remove_task(task_id) + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + self._task_failed(task, exception=e) + else: + self._task_failed(task, exception=request['exception']) + + def _handle_apply_tracked_changes_request(self, task_id, request, response): + task = self._tasks[task_id] + try: + self._apply_tracked_changes(task, request) + except BaseException as e: + response['exception'] = exceptions.wrap_if_needed(e) + + 
@staticmethod + def _apply_tracked_changes(task, request): + instrumentation.apply_tracked_changes( + tracked_changes=request['tracked_changes'], + model=task.context.model) + + +def _send_message(connection, message): + data = jsonpickle.dumps(message) + connection.send(struct.pack(_INT_FMT, len(data))) + connection.sendall(data) + + +def _recv_message(connection): + message_len, = struct.unpack(_INT_FMT, _recv_bytes(connection, _INT_SIZE)) + return jsonpickle.loads(_recv_bytes(connection, message_len)) + + +def _recv_bytes(connection, count): + result = io.BytesIO() + while True: + if not count: + return result.getvalue() + read = connection.recv(count) + if not read: + return result.getvalue() + result.write(read) + count -= len(read) + class _Messenger(object): @@ -261,17 +297,16 @@ class _Messenger(object): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(('localhost', self.port)) try: - data = jsonpickle.dumps({ + _send_message(sock, { 'type': type, 'task_id': self.task_id, 'exception': exceptions.wrap_if_needed(exception), 'tracked_changes': tracked_changes }) - sock.send(struct.pack(_INT_FMT, len(data))) - sock.sendall(data) - # send message will block until the server side closes the connection socket - # because we want it to be synchronous - sock.recv(1) + response = _recv_message(sock) + response_exception = response.get('exception') + if response_exception: + raise response_exception finally: sock.close() @@ -294,12 +329,17 @@ def _patch_session(ctx, messenger, instrument): messenger.apply_tracked_changes(instrument.tracked_changes) instrument.clear() + def patched_rollback(): + # Rollback is performed on parent process when commit fails + pass + # when autoflush is set to true (the default), refreshing an object will trigger # an auto flush by sqlalchemy, this autoflush will attempt to commit changes made so # far on the session. 
this is not the desired behavior in the subprocess session.autoflush = False session.commit = patched_commit + session.rollback = patched_rollback session.refresh = patched_refresh http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/aria/storage/base_model.py ---------------------------------------------------------------------- diff --git a/aria/storage/base_model.py b/aria/storage/base_model.py index f7d0e5b..56605fc 100644 --- a/aria/storage/base_model.py +++ b/aria/storage/base_model.py @@ -479,6 +479,7 @@ class NodeInstanceBase(ModelMixin): __tablename__ = 'node_instances' _private_fields = ['node_fk', 'host_fk'] + version_id = Column(Integer, nullable=False) runtime_properties = Column(Dict) scaling_groups = Column(List) state = Column(Text, nullable=False) @@ -528,6 +529,8 @@ class NodeInstanceBase(ModelMixin): return host_node.properties['ip'] return None + __mapper_args__ = {'version_id_col': version_id} + class RelationshipInstanceBase(ModelMixin): """ http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/aria/storage/instrumentation.py ---------------------------------------------------------------------- diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py index 537dbb5..e91fab0 100644 --- a/aria/storage/instrumentation.py +++ b/aria/storage/instrumentation.py @@ -15,11 +15,15 @@ import copy +import sqlalchemy import sqlalchemy.event +from . import exceptions from . import api from . 
import model as _model + +_VERSION_ID_COL = 'version_id' _STUB = object() _INSTRUMENTED = { _model.NodeInstance.runtime_properties: dict @@ -93,6 +97,11 @@ class _Instrumentation(object): mapi_name = self._mapi_name(instrumented_class) tracked_instances = self.tracked_changes.setdefault(mapi_name, {}) tracked_attributes = tracked_instances.setdefault(target.id, {}) + if hasattr(target, _VERSION_ID_COL): + # We want to keep track of the initial version id so it can be compared + # with the committed version id when the tracked changes are applied + tracked_attributes.setdefault(_VERSION_ID_COL, + _Value(_STUB, getattr(target, _VERSION_ID_COL))) for attribute_name, attribute_type in instrumented_attributes.items(): if attribute_name not in tracked_attributes: initial = getattr(target, attribute_name) @@ -148,7 +157,7 @@ class _Value(object): return self.initial == other.initial and self.current == other.current def __hash__(self): - return hash(self.initial) ^ hash(self.current) + return hash((self.initial, self.current)) def apply_tracked_changes(tracked_changes, model): @@ -168,4 +177,23 @@ def apply_tracked_changes(tracked_changes, model): instance = mapi.get(instance_id) setattr(instance, attribute_name, value.current) if instance: + _validate_version_id(instance, mapi) mapi.update(instance) + + +def _validate_version_id(instance, mapi): + version_id = sqlalchemy.inspect(instance).committed_state.get(_VERSION_ID_COL) + # There are two version conflict code paths: + # 1. The instance committed state loaded already holds a newer version, + # in this case, we manually raise the error + # 2. The UPDATE statement is executed with version validation and sqlalchemy + # will raise a StateDataError if there is a version mismatch. 
+ if version_id and getattr(instance, _VERSION_ID_COL) != version_id: + object_version_id = getattr(instance, _VERSION_ID_COL) + mapi._session.rollback() + raise exceptions.StorageError( + 'Version conflict: committed and object {0} differ ' + '[committed {0}={1}, object {0}={2}]' + .format(_VERSION_ID_COL, + version_id, + object_version_id)) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/aria/storage/sql_mapi.py ---------------------------------------------------------------------- diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py index 809f677..0c08e48 100644 --- a/aria/storage/sql_mapi.py +++ b/aria/storage/sql_mapi.py @@ -17,6 +17,7 @@ SQLAlchemy based MAPI """ from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm.exc import StaleDataError from aria.utils.collections import OrderedDict from aria.storage import ( @@ -152,6 +153,9 @@ class SQLAlchemyModelAPI(api.ModelAPI): """ try: self._session.commit() + except StaleDataError as e: + self._session.rollback() + raise exceptions.StorageError('Version conflict: {0}'.format(str(e))) except (SQLAlchemyError, ValueError) as e: self._session.rollback() raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e))) http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py ---------------------------------------------------------------------- diff --git a/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py new file mode 100644 index 0000000..c9bff8d --- /dev/null +++ b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py @@ -0,0 +1,196 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import time +import json + +import fasteners +import pytest + +from aria.storage.exceptions import StorageError +from aria.orchestrator import events +from aria.orchestrator.workflows.exceptions import ExecutorException +from aria.orchestrator.workflows import api +from aria.orchestrator.workflows.executor import process +from aria.orchestrator import workflow, operation + +import tests +from tests.orchestrator.context import execute as execute_workflow +from tests.orchestrator.workflows.helpers import events_collector +from tests import mock +from tests import storage + + +def test_concurrent_modification_on_task_succeeded(context, executor, shared_file): + _test(context, executor, shared_file, _test_task_succeeded, expected_failure=True) + + +@operation +def _test_task_succeeded(ctx, shared_file, first_entry, second_entry): + node_instance = ctx.node_instance + with _sync(shared_file) as first: + entry = first_entry if first else second_entry + key, value = entry + node_instance.runtime_properties[key] = value + + +def test_concurrent_modification_on_task_failed(context, executor, shared_file): + _test(context, executor, shared_file, _test_task_failed, expected_failure=True) + + +@operation +def _test_task_failed(ctx, shared_file, first_entry, second_entry): + 
node_instance = ctx.node_instance + with _sync(shared_file) as first: + entry = first_entry if first else second_entry + key, value = entry + node_instance.runtime_properties[key] = value + if not first: + raise RuntimeError(value) + + +def test_concurrent_modification_on_update_and_refresh(context, executor, shared_file): + _test(context, executor, shared_file, _test_update_and_refresh, expected_failure=False) + + +@operation +def _test_update_and_refresh(ctx, shared_file, first_entry, second_entry): + node_instance = ctx.node_instance + with _sync(shared_file) as first: + entry = first_entry if first else second_entry + key, value = entry + node_instance.runtime_properties[key] = value + if not first: + try: + ctx.model.node_instance.update(node_instance) + except StorageError as e: + assert 'Version conflict' in str(e) + ctx.model.node_instance.refresh(node_instance) + else: + raise RuntimeError('Unexpected') + + +def _test(context, executor, shared_file, func, expected_failure): + with open(shared_file, 'wb') as f: + json.dump({}, f) + key = 'key' + first_value = 'value1' + second_value = 'value2' + inputs = { + 'shared_file': shared_file, + 'first_entry': [key, first_value], + 'second_entry': [key, second_value] + } + props, exceptions = _run_workflow(context, executor, func, inputs) + assert props[key] == first_value + if expected_failure: + assert exceptions + exception = exceptions[-1] + assert isinstance(exception, StorageError) + assert 'Version conflict' in str(exception) + else: + assert not exceptions + + +def _run_workflow(context, executor, func, inputs): + def _node_instance(ctx): + return ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME) + + @workflow + def mock_workflow(ctx, graph): + op = 'test.op' + node_instance = _node_instance(ctx) + node_instance.node.operations[op] = { + 'operation': '{0}.{1}'.format(__name__, func.__name__) + } + graph.add_tasks( + api.task.OperationTask.node_instance(instance=node_instance, 
name=op, inputs=inputs), + api.task.OperationTask.node_instance(instance=node_instance, name=op, inputs=inputs)) + + signal = events.on_failure_task_signal + with events_collector(signal) as collected: + try: + execute_workflow(mock_workflow, context, executor) + except ExecutorException: + pass + if signal in collected: + exceptions = [event['kwargs']['exception'] for event in collected[signal]] + else: + exceptions = [] + + return _node_instance(context).runtime_properties, exceptions + + +@pytest.fixture +def executor(): + result = process.ProcessExecutor(python_path=[tests.ROOT_DIR]) + yield result + result.close() + + +@pytest.fixture +def context(tmpdir): + result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir))) + yield result + storage.release_sqlite_storage(result.model) + + +@pytest.fixture +def shared_file(tmpdir): + return str(tmpdir.join('shared_file')) + + +@contextlib.contextmanager +def _sync(shared_file): + + def lock(): + return fasteners.InterProcessLock('{0}.lock'.format(shared_file)) + + def read(): + with open(shared_file) as f: + return json.load(f) + + def write(content): + with open(shared_file, 'wb') as f: + json.dump(content, f) + + with lock(): + content = read() + first = not content.get('first_in') + key = 'first_in' if first else 'second_in' + content[key] = True + write(content) + + if first: + while True: + time.sleep(0.01) + with lock(): + if read().get('second_in'): + break + + yield first + + if first: + with lock(): + content = read() + content['first_out'] = True + write(content) + else: + while True: + time.sleep(0.01) + with lock(): + if read().get('first_out'): + break http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/tests/requirements.txt ---------------------------------------------------------------------- diff --git a/tests/requirements.txt b/tests/requirements.txt index 0e4740f..2f0245a 100644 --- a/tests/requirements.txt +++ b/tests/requirements.txt @@ -11,8 +11,9 @@ # limitations under
the License. testtools +fasteners==0.13.0 mock==1.0.1 pylint==1.6.4 pytest==3.0.2 pytest-cov==2.3.1 -pytest-mock==1.2 \ No newline at end of file +pytest-mock==1.2 http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/ba5247a3/tests/storage/__init__.py ---------------------------------------------------------------------- diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py index 3b3715e..8fb4f30 100644 --- a/tests/storage/__init__.py +++ b/tests/storage/__init__.py @@ -75,8 +75,11 @@ def get_sqlite_api_kwargs(base_dir=None, filename='db.sqlite'): engine = create_engine(uri, **engine_kwargs) session_factory = orm.sessionmaker(bind=engine) - session = orm.scoped_session(session_factory=session_factory) if base_dir else session_factory() + if base_dir: + session = orm.scoped_session(session_factory=session_factory) + else: + session = session_factory() model.DeclarativeBase.metadata.create_all(bind=engine) return dict(engine=engine, session=session)
