Repository: incubator-singa Updated Branches: refs/heads/master e3df3bd76 -> deb187bb8
SINGA-252 Use the snapshot methods to dump and load models for pysinga Use the snapshot methods to dump and load models for pysinga to make the model checkpoint slimmer. Previously we used Pickle to checkpoint in pysinga, which made the model checkpoints heavier than using io/snapshot, which leverages protobuf to serialize parameters of models. Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/538bdac5 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/538bdac5 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/538bdac5 Branch: refs/heads/master Commit: 538bdac5cf62a9c74750d7940183ea1ed4669c73 Parents: e3df3bd Author: WANG Ji <[email protected]> Authored: Tue Sep 27 15:03:56 2016 +0800 Committer: WANG Ji <[email protected]> Committed: Tue Sep 27 15:03:56 2016 +0800 ---------------------------------------------------------------------- examples/cifar10/predict.py | 2 +- examples/cifar10/train.py | 2 +- python/singa/net.py | 23 ++++++++-------- python/singa/snapshot.py | 59 ++++++++++++++++++++++++++++++++++++++++ src/api/io_snapshot.i | 46 +++++++++++++++++++++++++++++++ src/api/singa.i | 1 + 6 files changed, 119 insertions(+), 14 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/examples/cifar10/predict.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py index f2150f4..dca44fe 100644 --- a/examples/cifar10/predict.py +++ b/examples/cifar10/predict.py @@ -81,7 +81,7 @@ def compute_image_mean(train_dir): if __name__ == '__main__': model = alexnet.create_net(True) - model.load('model.bin') # the checkpoint from train.py + model.load('model', 20) # the checkpoint from train.py dev = device.get_default_device() model.to_device(dev) 
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/examples/cifar10/train.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py index d2d70df..671c861 100644 --- a/examples/cifar10/train.py +++ b/examples/cifar10/train.py @@ -156,7 +156,7 @@ def train(data, net, max_epoch, get_lr, weight_decay, batch_size=100, print 'test loss = %f, test accuracy = %f' \ % (loss / num_test_batch, acc / num_test_batch) - net.save('model.bin') # save model params into checkpoint file + net.save('model', 20) # save model params into checkpoint file if __name__ == '__main__': parser = argparse.ArgumentParser(description='Train vgg/alexnet for cifar10') http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/python/singa/net.py ---------------------------------------------------------------------- diff --git a/python/singa/net.py b/python/singa/net.py index 61603c6..4a9d650 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -23,6 +23,7 @@ functions for net info, e.g., parameters. 
from .proto.model_pb2 import kTrain, kEval import tensor import layer +import snapshot import cPickle as pickle '''For display training information, e.g L1 value of layer data''' @@ -209,18 +210,16 @@ class FeedForwardNet(object): ret.extend(pgrad) return ret - def save(self, f): - """Save model parameters using cpickle""" - params = {} + def save(self, f, buffer_size = 10): + """Save model parameters using io/snapshot""" + sp = snapshot.Snapshot(f, True, buffer_size) for (specs, val) in zip(self.param_specs(), self.param_values()): val.to_host() - params[specs.name] = tensor.to_numpy(val) - with open(f, 'wb') as fd: - pickle.dump(params, fd) - - def load(self, f): - """Load model parameters using cpickle""" - with open(f, 'rb') as fd: - params = pickle.load(fd) + sp.write(specs.name, val) + + def load(self, f, buffer_size = 10): + """Load model parameters using io/snapshot""" + sp = snapshot.Snapshot(f, False, buffer_size) + params = sp.read() for (specs, val) in zip(self.param_specs(), self.param_values()): - val.copy_from_numpy(params[specs.name]) + val.copy_data(params[specs.name]) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/python/singa/snapshot.py ---------------------------------------------------------------------- diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py new file mode 100644 index 0000000..c259850 --- /dev/null +++ b/python/singa/snapshot.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# ============================================================================= +''' +This script includes io::snapshot class and its methods. + +''' + +from . import singa_wrap as singa +import tensor + +class Snapshot(object): + ''' Class and member functions for singa::Snapshot. + + ''' + def __init__(self, f, mode, buffer_size = 10): + '''Snapshot constructor given file name and R/W mode. + + Args: + file (string): snapshot file name. + mode (boolean): True for write, False for read + buffer_size (int): Buffer size (in MB), default is 10 + ''' + self.snapshot = singa.Snapshot(f, mode, buffer_size) + + def write(self, param_name, param_val): + '''Call Write method to write a parameter + + Args: + param_name (string): name of the parameter + param_val (Tensor): value tensor of the parameter + ''' + self.snapshot.Write(str(param_name), param_val.singa_tensor) + def read(self): + '''Call read method to load all (param_name, param_val) + + Returns: + a dict of (parameter name, parameter Tensor) + ''' + params = {} + p = self.snapshot.Read(); + for (param_name, param_val) in p: + print param_name + params[param_name] = tensor.from_raw_tensor(param_val) + return params http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/src/api/io_snapshot.i ---------------------------------------------------------------------- diff --git a/src/api/io_snapshot.i b/src/api/io_snapshot.i new file mode 100644 index 0000000..2203295 --- /dev/null +++ b/src/api/io_snapshot.i @@ -0,0 +1,46 @@ +/************************************************************ +* +* 
Licensed to the Apache Software Foundation (ASF) under one +* or more contributor license agreements. See the NOTICE file +* distributed with this work for additional information +* regarding copyright ownership. The ASF licenses this file +* to you under the Apache License, Version 2.0 (the +* "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, +* software distributed under the License is distributed on an +* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +* KIND, either express or implied. See the License for the +* specific language governing permissions and limitations +* under the License. +* +*************************************************************/ + +/*interface file for swig */ + +%module io_snapshot + +%{ +#include "singa/io/snapshot.h" +%} + +namespace std{ +%template(nametensorPair) std::pair<string, singa::Tensor>; +%template(nametensorVec) std::vector<std::pair<string, singa::Tensor>>; +} + +namespace singa { + +class Snapshot { + public: + enum Mode { kRead, kWrite }; + Snapshot(const std::string& prefix, Mode mode, int max_param_size = 10); + ~Snapshot() {} + std::vector<std::pair<std::string, Tensor>> Read(); + void Write(const std::string& key, const Tensor& param); +}; + +} http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/538bdac5/src/api/singa.i ---------------------------------------------------------------------- diff --git a/src/api/singa.i b/src/api/singa.i index 12f46f3..3fc3b47 100644 --- a/src/api/singa.i +++ b/src/api/singa.i @@ -29,3 +29,4 @@ %include "model_optimizer.i" %include "model_loss.i" %include "model_metric.i" +%include "io_snapshot.i"
