Wrap all state-dependent methods in job_state with a decorator that
ensures that file access is properly serialized and that the
in-memory and on-disk state are correctly synchronized at the start
and end of each method call.

Note that this specifically does not address:
 1) thread safety
 2) race conditions involving multiple processes trying to alter the
    same state
In both of these cases the offending code needs to do its own locking.
I don't recommend doing either one, in any case.

The main use case this is trying to address is to ensure that state
set in child processes propagates back up to the parent processes.
In particular, we want to make sure that state set in a client test
propagates back up to the client job code, since the client test is
run in a forked subprocess.

Risk: Low
Visibility: Eliminates a variety of cross-process job state issues.

Signed-off-by: John Admanski <[email protected]>

--- autotest/client/common_lib/base_job.py      2010-01-21 11:33:15.000000000 
-0800
+++ autotest/client/common_lib/base_job.py      2010-01-21 11:33:15.000000000 
-0800
@@ -1,4 +1,4 @@
-import os, copy, logging, errno, tempfile, cPickle as pickle, platform
+import os, copy, logging, errno, cPickle as pickle, fcntl
 
 from autotest_lib.client.common_lib import autotemp, error
 
@@ -136,71 +136,99 @@
         """Initialize the job state."""
         self._state = {}
         self._backing_file = None
+        self._backing_file_initialized = False
+        self._backing_file_lock = None
 
 
-    def get(self, namespace, name, default=NO_DEFAULT):
-        """Returns the value associated with a particular name.
+    def _with_backing_file(method):
+        """A decorator to perform a lock-read-*-write-unlock cycle.
 
-        @param namespace The namespace that the property should be stored in.
-        @param name The name the value was saved with.
-        @param default A default value to return if no state is currently
-            associated with var.
+        When applied to a method, this decorator will automatically wrap
+        calls to the method in a lock-and-read before the call followed by a
+        write-and-unlock. Any operation that is reading or writing state
+        should be decorated with this method to ensure that backing file
+        state is consistently maintained.
+        """
+        def wrapped_method(self, *args, **dargs):
+            already_have_lock = self._backing_file_lock is not None
+            if not already_have_lock:
+                self._lock_backing_file()
+            try:
+                self._read_from_backing_file()
+                try:
+                    return method(self, *args, **dargs)
+                finally:
+                    self._write_to_backing_file()
+            finally:
+                if not already_have_lock:
+                    self._unlock_backing_file()
+        wrapped_method.__name__ = method.__name__
+        wrapped_method.__doc__ = method.__doc__
+        return wrapped_method
+
+
+    def _lock_backing_file(self):
+        """Acquire a lock on the backing file."""
+        if self._backing_file:
+            self._backing_file_lock = open(self._backing_file, 'a')
+            fcntl.lockf(self._backing_file_lock, fcntl.LOCK_EX)
 
-        @returns A deep copy of the value associated with name. Note that this
-            explicitly returns a deep copy to avoid problems with mutable
-            values; mutations are not persisted or shared.
-        @raises KeyError raised when no state is associated with var and a
-            default value is not provided.
-        """
-        if self.has(namespace, name):
-            return copy.deepcopy(self._state[namespace][name])
-        elif default is self.NO_DEFAULT:
-            raise KeyError('No key %s in namespace %s' % (name, namespace))
-        else:
-            return default
 
+    def _unlock_backing_file(self):
+        """Release a lock on the backing file."""
+        if self._backing_file_lock:
+            fcntl.lockf(self._backing_file_lock, fcntl.LOCK_UN)
+            self._backing_file_lock.close()
+            self._backing_file_lock = None
 
-    def read_from_file(self, file_path):
+
+    def read_from_file(self, file_path, merge=True):
         """Read in any state from the file at file_path.
 
-        Any state specified only in-memory will be preserved. Any state
-        specified on-disk will be set in-memory, even if an in-memory
-        setting already exists. In the special case that the file does
-        not exist it is treated as empty and not a failure.
-        """
-        # if the file exists, pull out its contents
-        if file_path:
-            try:
-                on_disk_state = pickle.load(open(file_path))
-            except IOError, e:
-                if e.errno == errno.ENOENT:
-                    logging.info('Persistent state file %s does not exist',
-                                 file_path)
-                    return
-                else:
-                    raise
+        When merge=True, any state specified only in-memory will be preserved.
+        Any state specified on-disk will be set in-memory, even if an in-memory
+        setting already exists.
+
+        @param file_path The path where the state should be read from. It must
+            exist but it can be empty.
+        @param merge If true, merge the on-disk state with the in-memory
+            state. If false, replace the in-memory state with the on-disk
+            state.
+
+        @warning This method is intentionally concurrency-unsafe. It makes no
+            attempt to control concurrent access to the file at file_path.
+        """
+
+        # we can assume that the file exists
+        if os.path.getsize(file_path) == 0:
+            on_disk_state = {}
         else:
-            return
+            on_disk_state = pickle.load(open(file_path))
 
-        # merge the on-disk state with the in-memory state
-        for namespace, namespace_dict in on_disk_state.iteritems():
-            in_memory_namespace = self._state.setdefault(namespace, {})
-            for name, value in namespace_dict.iteritems():
-                if name in in_memory_namespace:
-                    if in_memory_namespace[name] != value:
-                        logging.info('Persistent value of %s.%s from %s '
-                                     'overridding existing in-memory value',
-                                     namespace, name, file_path)
-                        in_memory_namespace[name] = value
+        if merge:
+            # merge the on-disk state with the in-memory state
+            for namespace, namespace_dict in on_disk_state.iteritems():
+                in_memory_namespace = self._state.setdefault(namespace, {})
+                for name, value in namespace_dict.iteritems():
+                    if name in in_memory_namespace:
+                        if in_memory_namespace[name] != value:
+                            logging.info('Persistent value of %s.%s from %s '
+                                         'overridding existing in-memory '
+                                         'value', namespace, name, file_path)
+                            in_memory_namespace[name] = value
+                        else:
+                            logging.debug('Value of %s.%s is unchanged, '
+                                          'skipping import', namespace, name)
                     else:
-                        logging.debug('Value of %s.%s is unchanged, skipping'
-                                      'import', namespace, name)
-                else:
-                    logging.debug('Importing %s.%s from state file %s',
-                                  namespace, name, file_path)
-                    in_memory_namespace[name] = value
+                        logging.debug('Importing %s.%s from state file %s',
+                                      namespace, name, file_path)
+                        in_memory_namespace[name] = value
+        else:
+            # just replace the in-memory state with the on-disk state
+            self._state = on_disk_state
+            logging.debug('Replacing in-memory state with on-disk state '
+                          'from %s', file_path)
 
-        # flush the merged state out to disk
         self._write_to_backing_file()
 
 
@@ -209,6 +237,10 @@
 
         @param file_path The path where the state should be written out to.
             Must be writable.
+
+        @warning This method is intentionally concurrency-unsafe. It makes no
+            attempt to control concurrent access to the file at file_path, or
+            to the backing file if one exists.
         """
         outfile = open(file_path, 'w')
         try:
@@ -218,12 +250,32 @@
         logging.debug('Persistent state flushed to %s', file_path)
 
 
+    def _read_from_backing_file(self):
+        """Refresh the current state from the backing file.
+
+        If the backing file has never been read before (indicated by checking
+        self._backing_file_initialized) it will merge the file with the
+        in-memory state, rather than overwriting it.
+        """
+        if self._backing_file:
+            merge_backing_file = not self._backing_file_initialized
+            self.read_from_file(self._backing_file, merge=merge_backing_file)
+            self._backing_file_initialized = True
+
+
     def _write_to_backing_file(self):
         """Flush the current state to the backing file."""
         if self._backing_file:
             self.write_to_file(self._backing_file)
 
 
+    @_with_backing_file
+    def _synchronize_backing_file(self):
+        """Synchronizes the contents of the in-memory and on-disk state."""
+        # state is implicitly synchronized in _with_backing_file methods
+        pass
+
+
     def set_backing_file(self, file_path):
         """Change the path used as the backing file for the persistent state.
 
@@ -236,12 +288,36 @@
         @param file_path A path on the filesystem that can be read from and
             written to, or None to turn off the backing store.
         """
-        self._backing_file = None
-        self.read_from_file(file_path)
+        self._synchronize_backing_file()
         self._backing_file = file_path
-        self._write_to_backing_file()
+        self._backing_file_initialized = False
+        self._synchronize_backing_file()
+
 
+    @_with_backing_file
+    def get(self, namespace, name, default=NO_DEFAULT):
+        """Returns the value associated with a particular name.
+
+        @param namespace The namespace that the property should be stored in.
+        @param name The name the value was saved with.
+        @param default A default value to return if no state is currently
+            associated with var.
 
+        @returns A deep copy of the value associated with name. Note that this
+            explicitly returns a deep copy to avoid problems with mutable
+            values; mutations are not persisted or shared.
+        @raises KeyError raised when no state is associated with var and a
+            default value is not provided.
+        """
+        if self.has(namespace, name):
+            return copy.deepcopy(self._state[namespace][name])
+        elif default is self.NO_DEFAULT:
+            raise KeyError('No key %s in namespace %s' % (name, namespace))
+        else:
+            return default
+
+
+    @_with_backing_file
     def set(self, namespace, name, value):
         """Saves the value given with the provided name.
 
@@ -251,11 +327,11 @@
         """
         namespace_dict = self._state.setdefault(namespace, {})
         namespace_dict[name] = copy.deepcopy(value)
-        self._write_to_backing_file()
         logging.debug('Persistent state %s.%s now set to %r', namespace,
                       name, value)
 
 
+    @_with_backing_file
     def has(self, namespace, name):
         """Return a boolean indicating if namespace.name is defined.
 
@@ -268,6 +344,7 @@
         return namespace in self._state and name in self._state[namespace]
 
 
+    @_with_backing_file
     def discard(self, namespace, name):
         """If namespace.name is a defined value, deletes it.
 
@@ -278,7 +355,6 @@
             del self._state[namespace][name]
             if len(self._state[namespace]) == 0:
                 del self._state[namespace]
-            self._write_to_backing_file()
             logging.debug('Persistent state %s.%s deleted', namespace, name)
         else:
             logging.debug(
@@ -286,6 +362,7 @@
                 namespace, name)
 
 
+    @_with_backing_file
     def discard_namespace(self, namespace):
         """Delete all defined namespace.* names.
 
@@ -293,7 +370,6 @@
         """
         if namespace in self._state:
             del self._state[namespace]
-        self._write_to_backing_file()
         logging.debug('Persistent state %s.* deleted', namespace)
 
 
--- autotest/client/common_lib/base_job_unittest.py     2010-01-21 
11:33:15.000000000 -0800
+++ autotest/client/common_lib/base_job_unittest.py     2010-01-21 
11:33:15.000000000 -0800
@@ -27,10 +27,11 @@
 class stub_job_state(base_job.job_state):
     """
     Stub job state class, for replacing the job._job_state factory.
-    Doesn't actually provide an persistence, just the state handling.
+    Doesn't actually provide any persistence, just the state handling.
     """
     def __init__(self):
         self._state = {}
+        self._backing_file_lock = None
 
     def read_from_file(self, file_path):
         pass
@@ -41,9 +42,18 @@
     def set_backing_file(self, file_path):
         pass
 
+    def _read_from_backing_file(self):
+        pass
+
     def _write_to_backing_file(self):
         pass
 
+    def _lock_backing_file(self):
+        pass
+
+    def _unlock_backing_file(self):
+        pass
+
 
 class test_init(unittest.TestCase):
     class generic_tests(object):
@@ -511,17 +521,6 @@
         shutil.rmtree(self.testdir, ignore_errors=True)
 
 
-    def test_read_missing_file_is_nop(self):
-        self.assert_(not os.path.exists('doesnotexist'))
-        state = base_job.job_state()
-        state.set('namespace', 'var', 'val1')
-        state.set('namespace2', 'var', 'val2')
-        state.write_to_file('initial')
-        state.read_from_file('doesnotexist')
-        state.write_to_file('final')
-        self.assertEqual(open('initial').read(), open('final').read())
-
-
     def test_write_read_transfers_all_state(self):
         state1 = base_job.job_state()
         state1.set('ns1', 'var0', 50)
@@ -560,6 +559,19 @@
         self.assertEqual('value2', state3.get('ns', 'var2'))
 
 
+    def test_read_without_merge(self):
+        state = base_job.job_state()
+        state.set('ns', 'myvar1', 'hello')
+        state.write_to_file('backup')
+        state.discard('ns', 'myvar1')
+        state.set('ns', 'myvar2', 'goodbye')
+        self.assertFalse(state.has('ns', 'myvar1'))
+        self.assertEqual('goodbye', state.get('ns', 'myvar2'))
+        state.read_from_file('backup', merge=False)
+        self.assertEqual('hello', state.get('ns', 'myvar1'))
+        self.assertFalse(state.has('ns', 'myvar2'))
+
+
 class test_job_state_set_backing_file(unittest.TestCase):
     def setUp(self):
         self.testdir = tempfile.mkdtemp(suffix='unittest')
@@ -655,6 +667,142 @@
         self.assertEqual(456, state2.get('n5', 'var2'))
 
 
+    def test_shared_backing_file_propagates_state_to_get(self):
+        state1 = base_job.job_state()
+        state1.set_backing_file('outfile6')
+        state2 = base_job.job_state()
+        state2.set_backing_file('outfile6')
+        self.assertRaises(KeyError, state1.get, 'n6', 'shared1')
+        self.assertRaises(KeyError, state2.get, 'n6', 'shared1')
+        state1.set('n6', 'shared1', 345)
+        self.assertEqual(345, state1.get('n6', 'shared1'))
+        self.assertEqual(345, state2.get('n6', 'shared1'))
+
+
+    def test_shared_backing_file_propagates_state_to_has(self):
+        state1 = base_job.job_state()
+        state1.set_backing_file('outfile7')
+        state2 = base_job.job_state()
+        state2.set_backing_file('outfile7')
+        self.assertFalse(state1.has('n6', 'shared2'))
+        self.assertFalse(state2.has('n6', 'shared2'))
+        state1.set('n6', 'shared2', 'hello')
+        self.assertTrue(state1.has('n6', 'shared2'))
+        self.assertTrue(state2.has('n6', 'shared2'))
+
+
+    def test_shared_backing_file_propagates_state_from_discard(self):
+        state1 = base_job.job_state()
+        state1.set_backing_file('outfile8')
+        state1.set('n6', 'shared3', 10000)
+        state2 = base_job.job_state()
+        state2.set_backing_file('outfile8')
+        self.assertEqual(10000, state1.get('n6', 'shared3'))
+        self.assertEqual(10000, state2.get('n6', 'shared3'))
+        state1.discard('n6', 'shared3')
+        self.assertRaises(KeyError, state1.get, 'n6', 'shared3')
+        self.assertRaises(KeyError, state2.get, 'n6', 'shared3')
+
+
+    def test_shared_backing_file_propagates_state_from_discard_namespace(self):
+        state1 = base_job.job_state()
+        state1.set_backing_file('outfile9')
+        state1.set('n7', 'shared4', -1)
+        state1.set('n7', 'shared5', -2)
+        state2 = base_job.job_state()
+        state2.set_backing_file('outfile9')
+        self.assertEqual(-1, state1.get('n7', 'shared4'))
+        self.assertEqual(-1, state2.get('n7', 'shared4'))
+        self.assertEqual(-2, state1.get('n7', 'shared5'))
+        self.assertEqual(-2, state2.get('n7', 'shared5'))
+        state1.discard_namespace('n7')
+        self.assertRaises(KeyError, state1.get, 'n7', 'shared4')
+        self.assertRaises(KeyError, state2.get, 'n7', 'shared4')
+        self.assertRaises(KeyError, state1.get, 'n7', 'shared5')
+        self.assertRaises(KeyError, state2.get, 'n7', 'shared5')
+
+
+class test_job_state_backing_file_locking(unittest.TestCase):
+    def setUp(self):
+        self.testdir = tempfile.mkdtemp(suffix='unittest')
+        self.original_wd = os.getcwd()
+        os.chdir(self.testdir)
+
+        # create a job_state object with stub read_* and write_* methods
+        # to check that a lock is always held during a call to them
+        ut_self = self
+        class mocked_job_state(base_job.job_state):
+            def read_from_file(self, file_path, merge=True):
+                if self._backing_file:
+                    ut_self.assertNotEqual(None, self._backing_file_lock)
+                return super(mocked_job_state, self).read_from_file(
+                    file_path, merge=True)
+            def write_to_file(self, file_path):
+                if self._backing_file:
+                    ut_self.assertNotEqual(None, self._backing_file_lock)
+                return super(mocked_job_state, self).write_to_file(file_path)
+        self.state = mocked_job_state()
+        self.state.set_backing_file('backing_file')
+
+
+    def tearDown(self):
+        os.chdir(self.original_wd)
+        shutil.rmtree(self.testdir, ignore_errors=True)
+
+
+    def test_set(self):
+        self.state.set('ns1', 'var1', 100)
+
+
+    def test_get_missing(self):
+        self.assertRaises(KeyError, self.state.get, 'ns2', 'var2')
+
+
+    def test_get_present(self):
+        self.state.set('ns3', 'var3', 333)
+        self.assertEqual(333, self.state.get('ns3', 'var3'))
+
+
+    def test_set_backing_file(self):
+        self.state.set_backing_file('some_other_file')
+
+
+    def test_has_missing(self):
+        self.assertFalse(self.state.has('ns4', 'var4'))
+
+
+    def test_has_present(self):
+        self.state.set('ns5', 'var5', 55555)
+        self.assertTrue(self.state.has('ns5', 'var5'))
+
+
+    def test_discard_missing(self):
+        self.state.discard('ns6', 'var6')
+
+
+    def test_discard_present(self):
+        self.state.set('ns7', 'var7', -777)
+        self.state.discard('ns7', 'var7')
+
+
+    def test_discard_missing_namespace(self):
+        self.state.discard_namespace('ns8')
+
+
+    def test_discard_present_namespace(self):
+        self.state.set('ns8', 'var8', 80)
+        self.state.set('ns8', 'var8.1', 81)
+        self.state.discard_namespace('ns8')
+
+
+    def test_disable_backing_file(self):
+        self.state.set_backing_file(None)
+
+
+    def test_change_backing_file(self):
+        self.state.set_backing_file('another_backing_file')
+
+
 class test_job_state_property_factory(unittest.TestCase):
     def setUp(self):
         class job_stub(object):
_______________________________________________
Autotest mailing list
[email protected]
http://test.kernel.org/cgi-bin/mailman/listinfo/autotest

Reply via email to