Repository: incubator-singa
Updated Branches:
  refs/heads/master e3df3bd76 -> deb187bb8


SINGA-252 Use the snapshot methods to dump and load models for pysinga

Use the snapshot methods to dump and load models for pysinga to make the
model checkpoint slimmer.

Previously we used Pickle to checkpoint in pysinga, which made the
model checkpoints heavier than those written with io/snapshot, which
leverages protobuf to serialize the model parameters.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/538bdac5
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/538bdac5
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/538bdac5

Branch: refs/heads/master
Commit: 538bdac5cf62a9c74750d7940183ea1ed4669c73
Parents: e3df3bd
Author: WANG Ji <[email protected]>
Authored: Tue Sep 27 15:03:56 2016 +0800
Committer: WANG Ji <[email protected]>
Committed: Tue Sep 27 15:03:56 2016 +0800

----------------------------------------------------------------------
 examples/cifar10/predict.py |  2 +-
 examples/cifar10/train.py   |  2 +-
 python/singa/net.py         | 23 ++++++++--------
 python/singa/snapshot.py    | 59 ++++++++++++++++++++++++++++++++++++++++
 src/api/io_snapshot.i       | 46 +++++++++++++++++++++++++++++++
 src/api/singa.i             |  1 +
 6 files changed, 119 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/examples/cifar10/predict.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py
index f2150f4..dca44fe 100644
--- a/examples/cifar10/predict.py
+++ b/examples/cifar10/predict.py
@@ -81,7 +81,7 @@ def compute_image_mean(train_dir):
 
 if __name__ == '__main__':
     model = alexnet.create_net(True)
-    model.load('model.bin')  # the checkpoint from train.py
+    model.load('model', 20)  # the checkpoint from train.py
     dev = device.get_default_device()
     model.to_device(dev)
 

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/examples/cifar10/train.py
----------------------------------------------------------------------
diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py
index d2d70df..671c861 100644
--- a/examples/cifar10/train.py
+++ b/examples/cifar10/train.py
@@ -156,7 +156,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100,
 
         print 'test loss = %f, test accuracy = %f' \
             % (loss / num_test_batch, acc / num_test_batch)
-    net.save('model.bin')  # save model params into checkpoint file
+    net.save('model', 20)  # save model params into checkpoint file
 
 if __name__ == '__main__':
     parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10')

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 61603c6..4a9d650 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -23,6 +23,7 @@ functions for net info, e.g., parameters.
 from .proto.model_pb2 import kTrain, kEval
 import tensor
 import layer
+import snapshot
 import cPickle as pickle
 
 '''For display training information, e.g L1 value of layer data'''
@@ -209,18 +210,16 @@ class FeedForwardNet(object):
             ret.extend(pgrad)
         return ret
 
-    def save(self, f):
-        """Save model parameters using cpickle"""
-        params = {}
+    def save(self, f, buffer_size = 10):
+        """Save model parameters using io/snapshot"""
+        sp = snapshot.Snapshot(f, True, buffer_size)
         for (specs, val) in zip(self.param_specs(), self.param_values()):
             val.to_host()
-            params[specs.name] = tensor.to_numpy(val)
-        with open(f, 'wb') as fd:
-            pickle.dump(params, fd)
-
-    def load(self, f):
-        """Load model parameters using cpickle"""
-        with open(f, 'rb') as fd:
-            params = pickle.load(fd)
+            sp.write(specs.name, val)
+
+    def load(self, f, buffer_size = 10):
+        """Load model parameters using io/snapshot"""
+        sp = snapshot.Snapshot(f, False, buffer_size)
+        params = sp.read()
         for (specs, val) in zip(self.param_specs(), self.param_values()):
-            val.copy_from_numpy(params[specs.name])
+            val.copy_data(params[specs.name])

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/python/singa/snapshot.py
----------------------------------------------------------------------
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
new file mode 100644
index 0000000..c259850
--- /dev/null
+++ b/python/singa/snapshot.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# =============================================================================
+'''
+This script includes io::snapshot class and its methods.
+
+'''
+
+from . import singa_wrap as singa
+import tensor
+
+class Snapshot(object):
+    ''' Class and member functions for singa::Snapshot.
+
+    '''
+    def __init__(self, f, mode, buffer_size = 10):
+        '''Snapshot constructor given file name and R/W mode.
+
+        Args:
+            file (string): snapshot file name.
+            mode (boolean): True for write, False for read
+            buffer_size (int): Buffer size (in MB), default is 10
+        '''
+        self.snapshot = singa.Snapshot(f, mode, buffer_size)
+    
+    def write(self, param_name, param_val):
+        '''Call Write method to write a parameter
+
+        Args:
+            param_name (string): name of the parameter
+            param_val (Tensor): value tensor of the parameter
+        '''
+        self.snapshot.Write(str(param_name), param_val.singa_tensor)
+    def read(self):
+        '''Call read method to load all (param_name, param_val)
+
+        Returns:
+            a dict of (parameter name, parameter Tensor)
+        '''
+        params = {}
+        p = self.snapshot.Read();
+        for (param_name, param_val) in p:
+            print param_name
+            params[param_name] = tensor.from_raw_tensor(param_val)
+        return params

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/src/api/io_snapshot.i
----------------------------------------------------------------------
diff --git a/src/api/io_snapshot.i b/src/api/io_snapshot.i
new file mode 100644
index 0000000..2203295
--- /dev/null
+++ b/src/api/io_snapshot.i
@@ -0,0 +1,46 @@
+/************************************************************
+*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*
+*************************************************************/
+
+/*interface file for swig */
+
+%module io_snapshot
+
+%{
+#include "singa/io/snapshot.h"
+%}
+
+namespace std{
+%template(nametensorPair) std::pair<string, singa::Tensor>;
+%template(nametensorVec) std::vector<std::pair<string, singa::Tensor>>;
+}
+
+namespace singa {
+
+class Snapshot {
+ public:
+  enum Mode { kRead, kWrite };
+  Snapshot(const std::string& prefix, Mode mode, int max_param_size = 10);
+  ~Snapshot() {}
+  std::vector<std::pair<std::string, Tensor>> Read();
+  void Write(const std::string& key, const Tensor& param);
+};
+
+}

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/src/api/singa.i
----------------------------------------------------------------------
diff --git a/src/api/singa.i b/src/api/singa.i
index 12f46f3..3fc3b47 100644
--- a/src/api/singa.i
+++ b/src/api/singa.i
@@ -29,3 +29,4 @@
 %include "model_optimizer.i"
 %include "model_loss.i"
 %include "model_metric.i"
+%include "io_snapshot.i"

Reply via email to