SINGA-252 Use the snapshot methods to dump and load models for pysinga. A use_pickle argument was added to save() and load(); it decides whether pickle or Snapshot is used. By default Snapshot is used. Pickle can be turned on (use_pickle=True) to load previous checkpoint files.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/deb187bb Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/deb187bb Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/deb187bb Branch: refs/heads/master Commit: deb187bb8f332370733cb04a7328f705c1dead60 Parents: 7c12f40 Author: Wei Wang <[email protected]> Authored: Fri Sep 30 19:27:59 2016 +0800 Committer: Wei Wang <[email protected]> Committed: Fri Sep 30 19:27:59 2016 +0800 ---------------------------------------------------------------------- python/singa/net.py | 57 +++++++++++++++++++++++++++++++++++++----------- 1 file changed, 44 insertions(+), 13 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/deb187bb/python/singa/net.py ---------------------------------------------------------------------- diff --git a/python/singa/net.py b/python/singa/net.py index 4a9d650..caf5732 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -210,16 +210,47 @@ class FeedForwardNet(object): ret.extend(pgrad) return ret - def save(self, f, buffer_size = 10): - """Save model parameters using io/snapshot""" - sp = snapshot.Snapshot(f, True, buffer_size) - for (specs, val) in zip(self.param_specs(), self.param_values()): - val.to_host() - sp.write(specs.name, val) - - def load(self, f, buffer_size = 10): - """Load model parameters using io/snapshot""" - sp = snapshot.Snapshot(f, False, buffer_size) - params = sp.read() - for (specs, val) in zip(self.param_specs(), self.param_values()): - val.copy_data(params[specs.name]) + def save(self, f, buffer_size = 10, use_pickle=False): + '''Save model parameters using io/snapshot. + + Args: + f: file name + buffer_size: size (MB) of the IO, default setting is 10MB; Please + make sure it is larger than any single parameter object. 
+ use_pickle(Boolean): if true, it would use pickle for dumping; + otherwise, it would use protobuf for serialization, which uses + less space. + ''' + if use_pickle: + params = {} + for (specs, val) in zip(self.param_specs(), self.param_values()): + val.to_host() + params[specs.name] = tensor.to_numpy(val) + with open(f, 'wb') as fd: + pickle.dump(params, fd) + else: + sp = snapshot.Snapshot(f, True, buffer_size) + for (specs, val) in zip(self.param_specs(), self.param_values()): + val.to_host() + sp.write(specs.name, val) + + def load(self, f, buffer_size = 10, use_pickle=False): + '''Load model parameters using io/snapshot. + + Please refer to the argument description in save(). + ''' + if use_pickle: + print 'NOTE: If your model was saved using Snapshot, '\ + 'then set use_pickle=False for loading it' + with open(f, 'rb') as fd: + params = pickle.load(fd) + for (specs, val) in zip(self.param_specs(), + self.param_values()): + val.copy_from_numpy(params[specs.name]) + else: + print 'NOTE: If your model was saved using pickle, '\ + 'then set use_pickle=True for loading it' + sp = snapshot.Snapshot(f, False, buffer_size) + params = sp.read() + for (specs, val) in zip(self.param_specs(), self.param_values()): + val.copy_data(params[specs.name])
