SINGA-252 Use the snapshot methods to dump and load models for pysinga

A use_pickle argument was added to save() and load(); it selects whether
pickle or Snapshot is used.
By default Snapshot is used.
pickle can be turned on to load previous checkpoint files.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/deb187bb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/deb187bb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/deb187bb

Branch: refs/heads/master
Commit: deb187bb8f332370733cb04a7328f705c1dead60
Parents: 7c12f40
Author: Wei Wang <[email protected]>
Authored: Fri Sep 30 19:27:59 2016 +0800
Committer: Wei Wang <[email protected]>
Committed: Fri Sep 30 19:27:59 2016 +0800

----------------------------------------------------------------------
 python/singa/net.py | 57 +++++++++++++++++++++++++++++++++++++-----------
 1 file changed, 44 insertions(+), 13 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/deb187bb/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 4a9d650..caf5732 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -210,16 +210,47 @@ class FeedForwardNet(object):
             ret.extend(pgrad)
         return ret
 
-    def save(self, f, buffer_size = 10):
-        """Save model parameters using io/snapshot"""
-        sp = snapshot.Snapshot(f, True, buffer_size)
-        for (specs, val) in zip(self.param_specs(), self.param_values()):
-            val.to_host()
-            sp.write(specs.name, val)
-
-    def load(self, f, buffer_size = 10):
-        """Load model parameters using io/snapshot"""
-        sp = snapshot.Snapshot(f, False, buffer_size)
-        params = sp.read()
-        for (specs, val) in zip(self.param_specs(), self.param_values()):
-            val.copy_data(params[specs.name])
+    def save(self, f, buffer_size = 10, use_pickle=False):
+        '''Save model parameters using io/snapshot.
+
+        Args:
+            f: file name
+            buffer_size: size (MB) of the IO, default setting is 10MB; Please
+                make sure it is larger than any single parameter object.
+            use_pickle(Boolean): if true, it would use pickle for dumping;
+                otherwise, it would use protobuf for serialization, which uses
+                less space.
+        '''
+        if use_pickle:
+            params = {}
+            for (specs, val) in zip(self.param_specs(), self.param_values()):
+                val.to_host()
+                params[specs.name] = tensor.to_numpy(val)
+                with open(f, 'wb') as fd:
+                    pickle.dump(params, fd)
+        else:
+            sp = snapshot.Snapshot(f, True, buffer_size)
+            for (specs, val) in zip(self.param_specs(), self.param_values()):
+                val.to_host()
+                sp.write(specs.name, val)
+
+    def load(self, f, buffer_size = 10, use_pickle=False):
+        '''Load model parameters using io/snapshot.
+
+        Please refer to the argument description in save().
+        '''
+        if use_pickle:
+            print 'NOTE: If your model was saved using Snapshot, '\
+                    'then set use_pickle=False for loading it'
+            with open(f, 'rb') as fd:
+                params = pickle.load(fd)
+                for (specs, val) in zip(self.param_specs(),
+                                        self.param_values()):
+                    val.copy_from_numpy(params[specs.name])
+        else:
+            print 'NOTE: If your model was saved using pickle, '\
+                    'then set use_pickle=True for loading it'
+            sp = snapshot.Snapshot(f, False, buffer_size)
+            params = sp.read()
+            for (specs, val) in zip(self.param_specs(), self.param_values()):
+                val.copy_data(params[specs.name])

Reply via email to