Repository: incubator-ariatosca
Updated Branches:
  refs/heads/ARIA-63-runtime-properties-modification e81f42758 -> 9f29d2912 (forced update)


ARIA-63 Implement attribute tracking for subprocesses


Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/9f29d291
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/9f29d291
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/9f29d291

Branch: refs/heads/ARIA-63-runtime-properties-modification
Commit: 9f29d2912142039d1dc8e1398eefc323803a8e73
Parents: dac4da7
Author: Dan Kilman <d...@gigaspaces.com>
Authored: Sun Jan 15 17:42:23 2017 +0200
Committer: Dan Kilman <d...@gigaspaces.com>
Committed: Tue Jan 17 04:36:23 2017 +0200

----------------------------------------------------------------------
 aria/orchestrator/workflows/executor/process.py |  49 ++--
 aria/storage/instrumentation.py                 | 124 +++++++++
 aria/storage/type.py                            |  20 +-
 tests/.pylintrc                                 |   2 +-
 .../test_process_executor_tracked_changes.py    |  95 +++++++
 tests/storage/test_instrumentation.py           | 274 +++++++++++++++++++
 6 files changed, 544 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/orchestrator/workflows/executor/process.py
----------------------------------------------------------------------
diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py
index e0a8aeb..3c86c51 100644
--- a/aria/orchestrator/workflows/executor/process.py
+++ b/aria/orchestrator/workflows/executor/process.py
@@ -42,6 +42,8 @@ import jsonpickle
 from aria.utils import imports
 from aria.orchestrator.workflows.executor import base
 from aria.orchestrator.context import serialization
+from aria.storage import instrumentation
+from aria.storage import type as storage_type
 
 _IS_WIN = os.name == 'nt'
 
@@ -139,10 +141,17 @@ class ProcessExecutor(base.BaseExecutor):
                 if message_type == 'started':
                     self._task_started(self._tasks[task_id])
                 elif message_type == 'succeeded':
-                    self._task_succeeded(self._remove_task(task_id))
+                    task = self._remove_task(task_id)
+                    instrumentation.apply_tracked_changes(
+                        tracked_changes=message['tracked_changes'],
+                        model=task.context.model)
+                    self._task_succeeded(task)
                 elif message_type == 'failed':
-                    self._task_failed(self._remove_task(task_id),
-                                      exception=message['exception'])
+                    task = self._remove_task(task_id)
+                    instrumentation.apply_tracked_changes(
+                        tracked_changes=message['tracked_changes'],
+                        model=task.context.model)
+                    self._task_failed(task, exception=message['exception'])
                 else:
                     raise RuntimeError('Invalid state')
             except BaseException as e:
@@ -227,26 +236,27 @@ class _Messenger(object):
         """Task started message"""
         self._send_message(type='started')
 
-    def succeeded(self):
+    def succeeded(self, tracked_changes):
         """Task succeeded message"""
-        self._send_message(type='succeeded')
+        self._send_message(type='succeeded', tracked_changes=tracked_changes)
 
-    def failed(self, exception):
+    def failed(self, tracked_changes, exception):
         """Task failed message"""
-        self._send_message(type='failed', exception=exception)
+        self._send_message(type='failed', tracked_changes=tracked_changes, exception=exception)
 
     def closed(self):
         """Executor closed message"""
         self._send_message(type='closed')
 
-    def _send_message(self, type, exception=None):
+    def _send_message(self, type, tracked_changes=None, exception=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect(('localhost', self.port))
         try:
             data = jsonpickle.dumps({
                 'type': type,
                 'task_id': self.task_id,
-                'exception': exception
+                'exception': exception,
+                'tracked_changes': tracked_changes
             })
             sock.send(struct.pack(_INT_FMT, len(data)))
             sock.sendall(data)
@@ -271,13 +281,20 @@ def _main():
     operation_mapping = arguments['operation_mapping']
     operation_inputs = arguments['operation_inputs']
     context_dict = arguments['context']
-    try:
-        ctx = serialization.operation_context_from_dict(context_dict)
-        task_func = imports.load_attribute(operation_mapping)
-        task_func(ctx=ctx, **operation_inputs)
-        messenger.succeeded()
-    except BaseException as e:
-        messenger.failed(exception=e)
+
+    # This must happen before any model class is loaded, because that would trigger
+    # the listener we are trying to remove. Once it is triggered, many other listeners
+    # will then be registered. At that point, it is too late.
+    storage_type.remove_mutable_association_listener()
+
+    with instrumentation.track_changes() as instrument:
+        try:
+            ctx = serialization.operation_context_from_dict(context_dict)
+            task_func = imports.load_attribute(operation_mapping)
+            task_func(ctx=ctx, **operation_inputs)
+            messenger.succeeded(tracked_changes=instrument.tracked_changes)
+        except BaseException as e:
+            messenger.failed(exception=e, tracked_changes=instrument.tracked_changes)
 
 if __name__ == '__main__':
     _main()

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/storage/instrumentation.py
----------------------------------------------------------------------
diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py
new file mode 100644
index 0000000..1023b94
--- /dev/null
+++ b/aria/storage/instrumentation.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import sqlalchemy.event
+
+from . import api
+from . import model as _model
+
+_STUB = object()
+_INSTRUMENTED = {
+    _model.NodeInstance.runtime_properties: dict
+}
+
+
+def track_changes(instrumented=None):
+    return _Instrumentation(instrumented or _INSTRUMENTED)
+
+
+class _Instrumentation(object):
+
+    def __init__(self, instrumented):
+        self.tracked_changes = {}
+        self.listeners = []
+        self._track_changes(instrumented)
+
+    def _track_changes(self, instrumented):
+        instrumented_classes = {}
+        for instrumented_attribute, attribute_type in instrumented.items():
+            self._register_set_attribute_listener(
+                instrumented_attribute=instrumented_attribute,
+                attribute_type=attribute_type)
+            instrumented_class = instrumented_attribute.parent.entity
+            instrumented_class_attributes = instrumented_classes.setdefault(instrumented_class, {})
+            instrumented_class_attributes[instrumented_attribute.key] = attribute_type
+        for instrumented_class, instrumented_attributes in instrumented_classes.items():
+            self._register_instance_listeners(
+                instrumented_class=instrumented_class,
+                instrumented_attributes=instrumented_attributes)
+
+    def _register_set_attribute_listener(self, instrumented_attribute, attribute_type):
+        def listener(target, value, *_):
+            mapi_name = api.generate_lower_name(target.__class__)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            if value is None:
+                current = None
+            else:
+                current = copy.deepcopy(attribute_type(value))
+            tracked_attributes[instrumented_attribute.key] = _Value(_STUB, current)
+            return current
+        listener_args = (instrumented_attribute, 'set', listener)
+        sqlalchemy.event.listen(*listener_args, retval=True)
+        self.listeners.append(listener_args)
+
+    def _register_instance_listeners(self, instrumented_class, instrumented_attributes):
+        def listener(target, *_):
+            mapi_name = api.generate_lower_name(instrumented_class)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            for attribute_name, attribute_type in instrumented_attributes.items():
+                if attribute_name not in tracked_attributes:
+                    initial = getattr(target, attribute_name)
+                    if initial is None:
+                        current = None
+                    else:
+                        current = copy.deepcopy(attribute_type(initial))
+                    tracked_attributes[attribute_name] = _Value(initial, current)
+                target.__dict__[attribute_name] = tracked_attributes[attribute_name].current
+        for listener_args in [(instrumented_class, 'load', listener),
+                              (instrumented_class, 'refresh', listener),
+                              (instrumented_class, 'refresh_flush', listener)]:
+            sqlalchemy.event.listen(*listener_args)
+            self.listeners.append(listener_args)
+
+    def restore(self):
+        for listener_args in self.listeners:
+            if sqlalchemy.event.contains(*listener_args):
+                sqlalchemy.event.remove(*listener_args)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
+
+
+class _Value(object):
+    def __init__(self, initial, current):
+        self.initial = initial
+        self.current = current
+
+    def __eq__(self, other):
+        if not isinstance(other, _Value):
+            return False
+        return self.initial == other.initial and self.current == other.current
+
+    def __hash__(self):
+        return hash(self.initial) ^ hash(self.current)
+
+
+def apply_tracked_changes(tracked_changes, model):
+    for mapi_name, tracked_instances in tracked_changes.items():
+        mapi = getattr(model, mapi_name)
+        for instance_id, tracked_attributes in tracked_instances.items():
+            instance = None
+            for attribute_name, value in tracked_attributes.items():
+                if value.initial != value.current:
+                    if not instance:
+                        instance = mapi.get(instance_id)
+                    setattr(instance, attribute_name, value.current)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/aria/storage/type.py
----------------------------------------------------------------------
diff --git a/aria/storage/type.py b/aria/storage/type.py
index ab50b0f..ec81b2c 100644
--- a/aria/storage/type.py
+++ b/aria/storage/type.py
@@ -16,7 +16,8 @@ import json
 
 from sqlalchemy import (
     TypeDecorator,
-    VARCHAR
+    VARCHAR,
+    event
 )
 
 from sqlalchemy.ext import mutable
@@ -84,5 +85,18 @@ class _MutableList(mutable.MutableList):
             raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
 
 
-_MutableList.associate_with(List)
-_MutableDict.associate_with(Dict)
+def _mutable_association_listener(mapper, cls):
+    for prop in mapper.column_attrs:
+        column_type = prop.columns[0].type
+        if isinstance(column_type, Dict):
+            _MutableDict.associate_with_attribute(getattr(cls, prop.key))
+        if isinstance(column_type, List):
+            _MutableList.associate_with_attribute(getattr(cls, prop.key))
+
+
+def remove_mutable_association_listener():
+    if event.contains(*_LISTENER_ARGS):
+        event.remove(*_LISTENER_ARGS)
+
+_LISTENER_ARGS = (mutable.mapper, 'mapper_configured', _mutable_association_listener)
+event.listen(*_LISTENER_ARGS)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/.pylintrc
----------------------------------------------------------------------
diff --git a/tests/.pylintrc b/tests/.pylintrc
index 23251af..5de0691 100644
--- a/tests/.pylintrc
+++ b/tests/.pylintrc
@@ -77,7 +77,7 @@ confidence=
 # --enable=similarities". If you want to run only the classes checker, but have
 # no Warning level messages displayed, use"--disable=all --enable=classes
 # --disable=W"
-disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad-except,protected-access,global-statement,too-many-locals
+disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad-except,protected-access,global-statement,too-many-locals,abstract-method
 
 [REPORTS]
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
new file mode 100644
index 0000000..1564292
--- /dev/null
+++ b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from aria.orchestrator.workflows import api
+from aria.orchestrator.workflows.core import engine
+from aria.orchestrator.workflows.executor import process
+from aria.orchestrator import workflow, operation
+from aria.orchestrator.workflows import exceptions
+
+import tests
+from tests import mock
+from tests import storage
+
+
+_TEST_RUNTIME_PROPERTIES = {
+    'some': 'values', 'that': 'are', 'most': 'likely', 'only': 'set', 'here': 'yo'
+}
+
+
+def test_track_changes_of_successful_operation(context, executor):
+    _run_workflow(context=context, executor=executor, op_func=_mock_success_operation)
+    _assert_tracked_changes_are_applied(context)
+
+
+def test_track_changes_of_failed_operation(context, executor):
+    with pytest.raises(exceptions.ExecutorException):
+        _run_workflow(context=context, executor=executor, op_func=_mock_fail_operation)
+    _assert_tracked_changes_are_applied(context)
+
+
+def _assert_tracked_changes_are_applied(context):
+    instance = context.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+    assert instance.runtime_properties == _TEST_RUNTIME_PROPERTIES
+
+
+def _update_runtime_properties(context):
+    context.node_instance.runtime_properties.clear()
+    context.node_instance.runtime_properties.update(_TEST_RUNTIME_PROPERTIES)
+
+
+def _run_workflow(context, executor, op_func):
+    @workflow
+    def mock_workflow(ctx, graph):
+        node_instance = ctx.model.node_instance.get_by_name(
+            mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+        node_instance.node.operations['test.op'] = {'operation': _operation_mapping(op_func)}
+        task = api.task.OperationTask.node_instance(instance=node_instance, name='test.op')
+        graph.add_tasks(task)
+        return graph
+    graph = mock_workflow(ctx=context)  # pylint: disable=no-value-for-parameter
+    eng = engine.Engine(executor=executor, workflow_context=context, tasks_graph=graph)
+    eng.execute()
+
+
+@operation
+def _mock_success_operation(ctx):
+    _update_runtime_properties(ctx)
+
+
+@operation
+def _mock_fail_operation(ctx):
+    _update_runtime_properties(ctx)
+    raise RuntimeError
+
+
+def _operation_mapping(func):
+    return '{name}.{func.__name__}'.format(name=__name__, func=func)
+
+
+@pytest.fixture
+def executor():
+    result = process.ProcessExecutor(python_path=[tests.ROOT_DIR])
+    yield result
+    result.close()
+
+
+@pytest.fixture
+def context(tmpdir):
+    result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+    yield result
+    storage.release_sqlite_storage(result.model)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/9f29d291/tests/storage/test_instrumentation.py
----------------------------------------------------------------------
diff --git a/tests/storage/test_instrumentation.py b/tests/storage/test_instrumentation.py
new file mode 100644
index 0000000..b00bbd3
--- /dev/null
+++ b/tests/storage/test_instrumentation.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sqlalchemy import Column, Text, Integer, event
+
+from aria.storage import (
+    model,
+    structure,
+    type as aria_type,
+    ModelStorage,
+    sql_mapi,
+    instrumentation
+)
+from ..storage import get_sqlite_api_kwargs, release_sqlite_storage
+
+
+STUB = instrumentation._STUB
+Value = instrumentation._Value
+instruments_holder = []
+
+
+class TestInstrumentation(object):
+
+    def test_track_changes(self, storage):
+        model_kwargs = dict(
+            name='name',
+            dict1={'initial': 'value'},
+            dict2={'initial': 'value'},
+            list1=['initial'],
+            list2=['initial'],
+            int1=0,
+            int2=0,
+            string2='string')
+        model1_instance = MockModel1(**model_kwargs)
+        model2_instance = MockModel2(**model_kwargs)
+        storage.mock_model_1.put(model1_instance)
+        storage.mock_model_2.put(model2_instance)
+
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel1.int1: int,
+            MockModel1.string2: str,
+            MockModel2.dict2: dict,
+            MockModel2.list2: list,
+            MockModel2.int2: int,
+            MockModel2.name: str
+        })
+
+        assert not instrument.tracked_changes
+
+        storage_model1_instance = storage.mock_model_1.get(model1_instance.id)
+        storage_model2_instance = storage.mock_model_2.get(model2_instance.id)
+
+        storage_model1_instance.dict1 = {'hello': 'world'}
+        storage_model1_instance.dict2 = {'should': 'not track'}
+        storage_model1_instance.list1 = ['hello']
+        storage_model1_instance.list2 = ['should not track']
+        storage_model1_instance.int1 = 100
+        storage_model1_instance.int2 = 20000
+        storage_model1_instance.name = 'should not track'
+        storage_model1_instance.string2 = 'new_string'
+
+        storage_model2_instance.dict1.update({'should': 'not track'})
+        storage_model2_instance.dict2.update({'hello': 'world'})
+        storage_model2_instance.list1.append('should not track')
+        storage_model2_instance.list2.append('hello')
+        storage_model2_instance.int1 = 100
+        storage_model2_instance.int2 = 20000
+        storage_model2_instance.name = 'new_name'
+        storage_model2_instance.string2 = 'should not track'
+
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                model1_instance.id: {
+                    'dict1': Value(STUB, {'hello': 'world'}),
+                    'list1': Value(STUB, ['hello']),
+                    'int1': Value(STUB, 100),
+                    'string2': Value(STUB, 'new_string')
+                }
+            },
+            'mock_model_2': {
+                model2_instance.id: {
+                    'dict2': Value({'initial': 'value'}, {'hello': 'world', 'initial': 'value'}),
+                    'list2': Value(['initial'], ['initial', 'hello']),
+                    'int2': Value(STUB, 20000),
+                    'name': Value(STUB, 'new_name'),
+                }
+            }
+        }
+
+    def test_attribute_initial_none_value(self, storage):
+        instance1 = MockModel1(name='name1', dict1=None)
+        instance2 = MockModel1(name='name2', dict1=None)
+        storage.mock_model_1.put(instance1)
+        storage.mock_model_1.put(instance2)
+        instrument = self._track_changes({MockModel1.dict1: dict})
+        instance1 = storage.mock_model_1.get(instance1.id)
+        instance2 = storage.mock_model_1.get(instance2.id)
+        instance1.dict1 = {'new': 'value'}
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                instance1.id: {'dict1': Value(STUB, {'new': 'value'})},
+                instance2.id: {'dict1': Value(None, None)},
+            }
+        }
+
+    def test_attribute_set_none_value(self, storage):
+        instance = MockModel1(name='name')
+        storage.mock_model_1.put(instance)
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel1.string2: str,
+            MockModel1.int1: int
+        })
+        instance = storage.mock_model_1.get(instance.id)
+        instance.dict1 = None
+        instance.list1 = None
+        instance.string2 = None
+        instance.int1 = None
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                instance.id: {
+                    'dict1': Value(STUB, None),
+                    'list1': Value(STUB, None),
+                    'string2': Value(STUB, None),
+                    'int1': Value(STUB, None)
+                }
+            }
+        }
+
+    def test_restore(self):
+        instrument = self._track_changes({MockModel1.dict1: dict})
+        # set instance attribute, load instance, refresh instance and flush_refresh listeners
+        assert len(instrument.listeners) == 4
+        for listener_args in instrument.listeners:
+            assert event.contains(*listener_args)
+        instrument.restore()
+        assert len(instrument.listeners) == 4
+        for listener_args in instrument.listeners:
+            assert not event.contains(*listener_args)
+        return instrument
+
+    def test_restore_twice(self):
+        instrument = self.test_restore()
+        instrument.restore()
+
+    def test_instrumentation_context_manager(self, storage):
+        instance = MockModel1(name='name')
+        storage.mock_model_1.put(instance)
+        with self._track_changes({MockModel1.dict1: dict}) as instrument:
+            instance = storage.mock_model_1.get(instance.id)
+            instance.dict1 = {'new': 'value'}
+            assert instrument.tracked_changes == {
+                'mock_model_1': {instance.id: {'dict1': Value(STUB, {'new': 'value'})}}
+            }
+            assert len(instrument.listeners) == 4
+            for listener_args in instrument.listeners:
+                assert event.contains(*listener_args)
+        for listener_args in instrument.listeners:
+            assert not event.contains(*listener_args)
+
+    def test_apply_tracked_changes(self, storage):
+        initial_values = {'dict1': {'initial': 'value'}, 'list1': ['initial']}
+        instance1_1 = MockModel1(name='instance1_1', **initial_values)
+        instance1_2 = MockModel1(name='instance1_2', **initial_values)
+        instance2_1 = MockModel2(name='instance2_1', **initial_values)
+        instance2_2 = MockModel2(name='instance2_2', **initial_values)
+        storage.mock_model_1.put(instance1_1)
+        storage.mock_model_1.put(instance1_2)
+        storage.mock_model_2.put(instance2_1)
+        storage.mock_model_2.put(instance2_2)
+
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel2.dict1: dict,
+            MockModel2.list1: list
+        })
+
+        def get_instances():
+            return (storage.mock_model_1.get(instance1_1.id),
+                    storage.mock_model_1.get(instance1_2.id),
+                    storage.mock_model_2.get(instance2_1.id),
+                    storage.mock_model_2.get(instance2_2.id))
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        instance1_1.dict1 = {'new': 'value'}
+        instance1_2.list1 = ['new_value']
+        instance2_1.dict1.update({'new': 'value'})
+        instance2_2.list1.append('new_value')
+
+        instrument.restore()
+        storage.mock_model_1._session.expire_all()
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        instance1_1.dict1 = {'overriding': 'value'}
+        instance1_2.list1 = ['overriding_value']
+        instance2_1.dict1 = {'overriding': 'value'}
+        instance2_2.list1 = ['overriding_value']
+        storage.mock_model_1.put(instance1_1)
+        storage.mock_model_1.put(instance1_2)
+        storage.mock_model_2.put(instance2_1)
+        storage.mock_model_2.put(instance2_2)
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        assert instance1_1.dict1 == {'overriding': 'value'}
+        assert instance1_2.list1 == ['overriding_value']
+        assert instance2_1.dict1 == {'overriding': 'value'}
+        assert instance2_2.list1 == ['overriding_value']
+
+        instrumentation.apply_tracked_changes(
+            tracked_changes=instrument.tracked_changes,
+            model=storage)
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        assert instance1_1.dict1 == {'new': 'value'}
+        assert instance1_2.list1 == ['new_value']
+        assert instance2_1.dict1 == {'initial': 'value', 'new': 'value'}
+        assert instance2_2.list1 == ['initial', 'new_value']
+
+    def _track_changes(self, instrumented):
+        instrument = instrumentation.track_changes(instrumented)
+        instruments_holder.append(instrument)
+        return instrument
+
+
+@pytest.fixture(autouse=True)
+def restore_instrumentation():
+    for instrument in instruments_holder:
+        instrument.restore()
+    del instruments_holder[:]
+
+
+@pytest.fixture
+def storage():
+    result = ModelStorage(
+        api_cls=sql_mapi.SQLAlchemyModelAPI,
+        api_kwargs=get_sqlite_api_kwargs(),
+        items=(MockModel1, MockModel2))
+    yield result
+    release_sqlite_storage(result)
+
+
+class _MockModel(structure.ModelMixin):
+    name = Column(Text)
+    dict1 = Column(aria_type.Dict)
+    dict2 = Column(aria_type.Dict)
+    list1 = Column(aria_type.List)
+    list2 = Column(aria_type.List)
+    int1 = Column(Integer)
+    int2 = Column(Integer)
+    string2 = Column(Text)
+
+
+class MockModel1(model.DeclarativeBase, _MockModel):
+    __tablename__ = 'mock_model1'
+
+
+class MockModel2(model.DeclarativeBase, _MockModel):
+    __tablename__ = 'mock_model2'

Reply via email to