ARIA-79-concurrent-modifications

Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/4a72e113
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/4a72e113
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/4a72e113

Branch: refs/heads/ARIA-79-concurrent-storage-modifications
Commit: 4a72e113b48b2aafdf778d641182a3a8a2c60d1a
Parents: b619335
Author: Dan Kilman <d...@gigaspaces.com>
Authored: Mon Jan 30 16:49:00 2017 +0200
Committer: mxmrlv <mxm...@gmail.com>
Committed: Thu Feb 16 15:03:40 2017 +0200

----------------------------------------------------------------------
 aria/orchestrator/workflows/executor/process.py | 156 ++++++++++-------
 aria/storage/instrumentation.py                 |  30 +++-
 aria/storage/sql_mapi.py                        |   4 +
 ...process_executor_concurrent_modifications.py | 174 +++++++++++++++++++
 tests/requirements.txt                          |   3 +-
 5 files changed, 307 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/orchestrator/workflows/executor/process.py
----------------------------------------------------------------------
diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py
index 560ac43..a23e3da 100644
--- a/aria/orchestrator/workflows/executor/process.py
+++ b/aria/orchestrator/workflows/executor/process.py
@@ -74,6 +74,13 @@ class ProcessExecutor(base.BaseExecutor):
         # Contains reference to all currently running tasks
         self._tasks = {}
 
+        self._request_handlers = {
+            'started': self._handle_task_started_request,
+            'succeeded': self._handle_task_succeeded_request,
+            'failed': self._handle_task_failed_request,
+            'apply_tracked_changes': self._handle_apply_tracked_changes_request
+        }
+
         # Server socket used to accept task status messages from subprocesses
         self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self._server_socket.bind(('localhost', 0))
@@ -131,58 +138,6 @@ class ProcessExecutor(base.BaseExecutor):
     def _remove_task(self, task_id):
         return self._tasks.pop(task_id)
 
-    def _listener(self):
-        # Notify __init__ method this thread has actually started
-        self._listener_started.put(True)
-        while not self._stopped:
-            try:
-                # Accept messages written to the server socket
-                with contextlib.closing(self._server_socket.accept()[0]) as connection:
-                    message = self._recv_message(connection)
-                    message_type = message['type']
-                    if message_type == 'closed':
-                        break
-                    task_id = message['task_id']
-                    if message_type == 'started':
-                        self._task_started(self._tasks[task_id])
-                    elif message_type == 'apply_tracked_changes':
-                        task = self._tasks[task_id]
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                    elif message_type == 'succeeded':
-                        task = self._remove_task(task_id)
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                        self._task_succeeded(task)
-                    elif message_type == 'failed':
-                        task = self._remove_task(task_id)
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                        self._task_failed(task, exception=message['exception'])
-                    else:
-                        raise RuntimeError('Invalid state')
-            except BaseException as e:
-                self.logger.debug('Error in process executor listener: {0}'.format(e))
-
-    def _recv_message(self, connection):
-        message_len, = struct.unpack(_INT_FMT, self._recv_bytes(connection, _INT_SIZE))
-        return jsonpickle.loads(self._recv_bytes(connection, message_len))
-
-    @staticmethod
-    def _recv_bytes(connection, count):
-        result = io.BytesIO()
-        while True:
-            if not count:
-                return result.getvalue()
-            read = connection.recv(count)
-            if not read:
-                return result.getvalue()
-            result.write(read)
-            count -= len(read)
-
     def _check_closed(self):
         if self._stopped:
             raise RuntimeError('Executor closed')
@@ -231,6 +186,87 @@ class ProcessExecutor(base.BaseExecutor):
                 os.pathsep,
                 env.get('PYTHONPATH', ''))
 
+    def _listener(self):
+        # Notify __init__ method this thread has actually started
+        self._listener_started.put(True)
+        while not self._stopped:
+            try:
+                with self._accept_request() as (request, response):
+                    request_type = request['type']
+                    if request_type == 'closed':
+                        break
+                    request_handler = self._request_handlers.get(request_type)
+                    if not request_handler:
+                        raise RuntimeError('Invalid request type: {0}'.format(request_type))
+                    request_handler(task_id=request['task_id'], request=request, response=response)
+            except BaseException as e:
+                self.logger.debug('Error in process executor listener: {0}'.format(e))
+
+    @contextlib.contextmanager
+    def _accept_request(self):
+        with contextlib.closing(self._server_socket.accept()[0]) as connection:
+            message = _recv_message(connection)
+            response = {}
+            yield message, response
+            _send_message(connection, response)
+
+    def _handle_task_started_request(self, task_id, **kwargs):
+        self._task_started(self._tasks[task_id])
+
+    def _handle_task_succeeded_request(self, task_id, request, **kwargs):
+        task = self._remove_task(task_id)
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            self._task_failed(task, exception=e)
+        else:
+            self._task_succeeded(task)
+
+    def _handle_task_failed_request(self, task_id, request, **kwargs):
+        task = self._remove_task(task_id)
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            self._task_failed(task, exception=e)
+        else:
+            self._task_failed(task, exception=request['exception'])
+
+    def _handle_apply_tracked_changes_request(self, task_id, request, response):
+        task = self._tasks[task_id]
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            response['exception'] = exceptions.wrap_if_needed(e)
+
+    @staticmethod
+    def _apply_tracked_changes(task, request):
+        instrumentation.apply_tracked_changes(
+            tracked_changes=request['tracked_changes'],
+            model=task.context.model)
+
+
+def _send_message(connection, message):
+    data = jsonpickle.dumps(message)
+    connection.send(struct.pack(_INT_FMT, len(data)))
+    connection.sendall(data)
+
+
+def _recv_message(connection):
+    message_len, = struct.unpack(_INT_FMT, _recv_bytes(connection, _INT_SIZE))
+    return jsonpickle.loads(_recv_bytes(connection, message_len))
+
+
+def _recv_bytes(connection, count):
+    result = io.BytesIO()
+    while True:
+        if not count:
+            return result.getvalue()
+        read = connection.recv(count)
+        if not read:
+            return result.getvalue()
+        result.write(read)
+        count -= len(read)
+
 
 class _Messenger(object):
 
@@ -261,17 +297,16 @@ class _Messenger(object):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect(('localhost', self.port))
         try:
-            data = jsonpickle.dumps({
+            _send_message(sock, {
                 'type': type,
                 'task_id': self.task_id,
                 'exception': exceptions.wrap_if_needed(exception),
                 'tracked_changes': tracked_changes
             })
-            sock.send(struct.pack(_INT_FMT, len(data)))
-            sock.sendall(data)
-            # send message will block until the server side closes the connection socket
-            # because we want it to be synchronous
-            sock.recv(1)
+            response = _recv_message(sock)
+            response_exception = response.get('exception')
+            if response_exception:
+                raise response_exception
         finally:
             sock.close()
 
@@ -294,12 +329,17 @@ def _patch_session(ctx, messenger, instrument):
         messenger.apply_tracked_changes(instrument.tracked_changes)
         instrument.clear()
 
+    def patched_rollback():
+        # Rollback is performed on parent process when commit fails
+        pass
+
     # when autoflush is set to true (the default), refreshing an object will trigger
     # an auto flush by sqlalchemy, this autoflush will attempt to commit changes made so
     # far on the session. this is not the desired behavior in the subprocess
     session.autoflush = False
 
     session.commit = patched_commit
+    session.rollback = patched_rollback
     session.refresh = patched_refresh
 
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/storage/instrumentation.py
----------------------------------------------------------------------
diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py
index 57fe9bd..41818b6 100644
--- a/aria/storage/instrumentation.py
+++ b/aria/storage/instrumentation.py
@@ -15,10 +15,14 @@
 
 import copy
 
+import sqlalchemy
 import sqlalchemy.event
 
+from . import exceptions
+
 from .modeling import model as _model
 
+_VERSION_ID_COL = 'version_id'
 _STUB = object()
 _INSTRUMENTED = {
     _model.Node.runtime_properties: dict
@@ -92,6 +96,11 @@ class _Instrumentation(object):
             mapi_name = instrumented_class.__modelname__
             tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
             tracked_attributes = tracked_instances.setdefault(target.id, {})
+            if hasattr(target, _VERSION_ID_COL):
+                # We want to keep track of the initial version id so it can be compared
+                # with the committed version id when the tracked changes are applied
+                tracked_attributes.setdefault(_VERSION_ID_COL,
+                                              _Value(_STUB, getattr(target, _VERSION_ID_COL)))
             for attribute_name, attribute_type in instrumented_attributes.items():
                 if attribute_name not in tracked_attributes:
                     initial = getattr(target, attribute_name)
@@ -143,7 +152,7 @@ class _Value(object):
         return self.initial == other.initial and self.current == other.current
 
     def __hash__(self):
-        return hash(self.initial) ^ hash(self.current)
+        return hash((self.initial, self.current))
 
 
 def apply_tracked_changes(tracked_changes, model):
@@ -163,4 +172,23 @@ def apply_tracked_changes(tracked_changes, model):
                         instance = mapi.get(instance_id)
                     setattr(instance, attribute_name, value.current)
             if instance:
+                _validate_version_id(instance, mapi)
                 mapi.update(instance)
+
+
+def _validate_version_id(instance, mapi):
+    version_id = sqlalchemy.inspect(instance).committed_state.get(_VERSION_ID_COL)
+    # There are two version conflict code paths:
+    # 1. The instance committed state loaded already holds a newer version,
+    #    in this case, we manually raise the error
+    # 2. The UPDATE statement is executed with version validation and sqlalchemy
+    #    will raise a StaleDataError if there is a version mismatch.
+    if version_id and getattr(instance, _VERSION_ID_COL) != version_id:
+        object_version_id = getattr(instance, _VERSION_ID_COL)
+        mapi._session.rollback()
+        raise exceptions.StorageError(
+            'Version conflict: committed and object {0} differ '
+            '[committed {0}={1}, object {0}={2}]'
+            .format(_VERSION_ID_COL,
+                    version_id,
+                    object_version_id))
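
_validate_version_id covers conflict path 1: the version the parent session
loaded from the database is available via SQLAlchemy's committed_state, while
the instance attribute holds the version carried over from the subprocess, so
a mismatch means another writer committed in between. A small sketch of that
check, assuming a hypothetical mapped instance with a version_id column:

    import sqlalchemy

    def is_stale(instance):
        # committed_state maps each *modified* attribute to the value that was
        # loaded from the database before the in-memory modification.
        committed = sqlalchemy.inspect(instance).committed_state.get('version_id')
        # Applying tracked changes setattr()s the subprocess's version id onto
        # the instance; a mismatch means a concurrent write won the race.
        return committed is not None and instance.version_id != committed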

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/aria/storage/sql_mapi.py
----------------------------------------------------------------------
diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py
index b80ac8e..2711aee 100644
--- a/aria/storage/sql_mapi.py
+++ b/aria/storage/sql_mapi.py
@@ -23,6 +23,7 @@ from sqlalchemy import (
     orm,
 )
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm.exc import StaleDataError
 
 from aria.utils.collections import OrderedDict
 from . import (
@@ -162,6 +163,9 @@ class SQLAlchemyModelAPI(api.ModelAPI):
         """
         try:
             self._session.commit()
+        except StaleDataError as e:
+            self._session.rollback()
+            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
         except (SQLAlchemyError, ValueError) as e:
             self._session.rollback()
             raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
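
StaleDataError is what SQLAlchemy raises for conflict path 2: when a mapper is
configured with a version_id_col, every UPDATE carries the expected version in
its WHERE clause, and zero matched rows aborts the flush. The handler above
maps that to the same 'Version conflict' StorageError raised by the manual
check. A self-contained sketch of the mechanism, using a hypothetical Node
model and a throwaway SQLite file:

    import os
    import tempfile

    import sqlalchemy as sa
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import sessionmaker
    from sqlalchemy.orm.exc import StaleDataError

    Base = declarative_base()

    class Node(Base):
        __tablename__ = 'node'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.Text)
        version_id = sa.Column(sa.Integer, nullable=False)
        # SQLAlchemy bumps version_id on every flush and adds it to the
        # UPDATE's WHERE clause.
        __mapper_args__ = {'version_id_col': version_id}

    engine = sa.create_engine('sqlite:///' + os.path.join(tempfile.mkdtemp(), 'demo.db'))
    Base.metadata.create_all(engine)
    Session = sessionmaker(bind=engine)

    setup = Session()
    setup.add(Node(id=1, name='initial'))
    setup.commit()
    setup.close()

    s1, s2 = Session(), Session()
    n1 = s1.query(Node).get(1)   # both sessions load version_id == 1
    n2 = s2.query(Node).get(1)

    n2.name = 'winner'
    s2.commit()                  # UPDATE succeeds, version_id becomes 2

    n1.name = 'loser'
    try:
        s1.commit()              # UPDATE ... WHERE version_id = 1 -> 0 rows
    except StaleDataError:
        s1.rollback()            # same recovery update() performs above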

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
new file mode 100644
index 0000000..7c54bc5
--- /dev/null
+++ b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import json
+
+import fasteners
+import pytest
+
+from aria.storage.exceptions import StorageError
+from aria.orchestrator import events
+from aria.orchestrator.workflows.exceptions import ExecutorException
+from aria.orchestrator.workflows import api
+from aria.orchestrator.workflows.executor import process
+from aria.orchestrator import workflow, operation
+
+import tests
+from tests.orchestrator.context import execute as execute_workflow
+from tests.orchestrator.workflows.helpers import events_collector
+from tests import mock
+from tests import storage
+
+
+def test_concurrent_modification_on_task_succeeded(context, executor, shared_file):
+    _test(context, executor, shared_file, _test_task_succeeded, expected_failure=True)
+
+
+@operation
+def _test_task_succeeded(ctx, shared_file, key, first_value, second_value):
+    _concurrent_update(shared_file, ctx.node_instance, key, first_value, second_value)
+
+
+def test_concurrent_modification_on_task_failed(context, executor, shared_file):
+    _test(context, executor, shared_file, _test_task_failed, expected_failure=True)
+
+
+@operation
+def _test_task_failed(ctx, shared_file, key, first_value, second_value):
+    first = _concurrent_update(shared_file, ctx.node_instance, key, first_value, second_value)
+    if not first:
+        raise RuntimeError('MESSAGE')
+
+
+def test_concurrent_modification_on_update_and_refresh(context, executor, shared_file):
+    _test(context, executor, shared_file, _test_update_and_refresh, expected_failure=False)
+
+
+@operation
+def _test_update_and_refresh(ctx, shared_file, key, first_value, second_value):
+    node_instance = ctx.node_instance
+    first = _concurrent_update(shared_file, node_instance, key, first_value, second_value)
+    if not first:
+        try:
+            ctx.model.node_instance.update(node_instance)
+        except StorageError as e:
+            assert 'Version conflict' in str(e)
+            ctx.model.node_instance.refresh(node_instance)
+        else:
+            raise RuntimeError('Unexpected')
+
+
+def _test(context, executor, shared_file, func, expected_failure):
+    def _node_instance(ctx):
+        return ctx.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+
+    shared_file.write(json.dumps({}))
+    key = 'key'
+    first_value = 'value1'
+    second_value = 'value2'
+    inputs = {
+        'shared_file': str(shared_file),
+        'key': key,
+        'first_value': first_value,
+        'second_value': second_value
+    }
+
+    @workflow
+    def mock_workflow(ctx, graph):
+        op = 'test.op'
+        node_instance = _node_instance(ctx)
+        node_instance.node.operations[op] = {'operation': '{0}.{1}'.format(__name__, func.__name__)}
+        graph.add_tasks(
+            api.task.OperationTask.node_instance(instance=node_instance, name=op, inputs=inputs),
+            api.task.OperationTask.node_instance(instance=node_instance, name=op, inputs=inputs))
+
+    signal = events.on_failure_task_signal
+    with events_collector(signal) as collected:
+        try:
+            execute_workflow(mock_workflow, context, executor)
+        except ExecutorException:
+            pass
+
+    props = _node_instance(context).runtime_properties
+    assert props[key] == first_value
+
+    exceptions = [event['kwargs']['exception'] for event in collected.get(signal, [])]
+    if expected_failure:
+        assert exceptions
+        exception = exceptions[-1]
+        assert isinstance(exception, StorageError)
+        assert 'Version conflict' in str(exception)
+    else:
+        assert not exceptions
+
+
+@pytest.fixture
+def executor():
+    result = process.ProcessExecutor(python_path=[tests.ROOT_DIR])
+    yield result
+    result.close()
+
+
+@pytest.fixture
+def context(tmpdir):
+    result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+    yield result
+    storage.release_sqlite_storage(result.model)
+
+
+@pytest.fixture
+def shared_file(tmpdir):
+    return tmpdir.join('shared_file')
+
+
+def _concurrent_update(shared_file, node_instance, key, first_value, second_value):
+    def lock():
+        return fasteners.InterProcessLock(shared_file)
+
+    def get(key):
+        with open(shared_file) as f:
+            return json.load(f).get(key)
+
+    def set(key):
+        with open(shared_file) as f:
+            content = json.load(f)
+        content[key] = True
+        with open(shared_file, 'wb') as f:
+            json.dump(content, f)
+
+    def wait_for(key):
+        while True:
+            time.sleep(0.01)
+            with lock():
+                if get(key):
+                    break
+
+    with lock():
+        first = not get('first_in')
+        set('first_in' if first else 'second_in')
+
+    if first:
+        wait_for('second_in')
+
+    node_instance.runtime_properties[key] = first_value if first else second_value
+
+    if first:
+        with lock():
+            set('first_out')
+    else:
+        wait_for('first_out')
+
+    return first
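
The test drives two subprocess operations into a deliberate write/write race,
coordinated through a JSON state file guarded by fasteners.InterProcessLock
(hence the new entry in tests/requirements.txt). A minimal usage sketch of the
lock, with a hypothetical state-file path; processes that agree on the lock
path serialize their read-modify-write cycles:

    import json

    import fasteners

    STATE_FILE = '/tmp/state.json'   # hypothetical path shared by processes

    with fasteners.InterProcessLock(STATE_FILE + '.lock'):
        try:
            with open(STATE_FILE) as f:
                state = json.load(f)
        except IOError:
            state = {}
        state['arrivals'] = state.get('arrivals', 0) + 1
        with open(STATE_FILE, 'w') as f:
            json.dump(state, f)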

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/4a72e113/tests/requirements.txt
----------------------------------------------------------------------
diff --git a/tests/requirements.txt b/tests/requirements.txt
index 0e4740f..2f0245a 100644
--- a/tests/requirements.txt
+++ b/tests/requirements.txt
@@ -11,8 +11,9 @@
 # limitations under the License.
 
 testtools
+fasteners==0.13.0
 mock==1.0.1
 pylint==1.6.4
 pytest==3.0.2
 pytest-cov==2.3.1
-pytest-mock==1.2
\ No newline at end of file
+pytest-mock==1.2
