http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/examples/cifar10/predict.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/predict.py b/examples/cifar10/predict.py index 123818a..7cab4b9 100644 --- a/examples/cifar10/predict.py +++ b/examples/cifar10/predict.py @@ -52,7 +52,7 @@ def predict(net, images, dev, topk=5): def load_dataset(filepath): print('Loading data file %s' % filepath) with open(filepath, 'rb') as fd: - cifar10 = pickle.load(fd) + cifar10 = pickle.load(fd, encoding='latin1') image = cifar10['data'].astype(dtype=np.uint8) image = image.reshape((-1, 3, 32, 32)) label = np.asarray(cifar10['labels'], dtype=np.uint8)
http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/examples/cifar10/train.py ---------------------------------------------------------------------- diff --git a/examples/cifar10/train.py b/examples/cifar10/train.py index 8204055..9f90e58 100644 --- a/examples/cifar10/train.py +++ b/examples/cifar10/train.py @@ -48,7 +48,7 @@ import resnet def load_dataset(filepath): print('Loading data file %s' % filepath) with open(filepath, 'rb') as fd: - cifar10 = pickle.load(fd) + cifar10 = pickle.load(fd, encoding='latin1') image = cifar10['data'].astype(dtype=np.uint8) image = image.reshape((-1, 3, 32, 32)) label = np.asarray(cifar10['labels'], dtype=np.uint8) http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/python/singa/net.py ---------------------------------------------------------------------- diff --git a/python/singa/net.py b/python/singa/net.py index c49b9fa..a53fc68 100644 --- a/python/singa/net.py +++ b/python/singa/net.py @@ -487,6 +487,7 @@ class FeedForwardNet(object): f = f[0:-4] sp = snapshot.Snapshot(f, False, buffer_size) params = sp.read() + version = __version__ if 'SINGA_VERSION' in params: version = params['SINGA_VERSION'] for name, val in zip(self.param_names(), self.param_values()): http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/bd5a8f8d/python/singa/snapshot.py ---------------------------------------------------------------------- diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py index 392ab3d..a4ac988 100644 --- a/python/singa/snapshot.py +++ b/python/singa/snapshot.py @@ -36,7 +36,6 @@ from builtins import object from . import singa_wrap as singa from . import tensor - class Snapshot(object): ''' Class and member functions for singa::Snapshot. @@ -58,7 +57,7 @@ class Snapshot(object): param_name (string): name of the parameter param_val (Tensor): value tensor of the parameter ''' - self.snapshot.Write(str(param_name).encode(), param_val.singa_tensor) + self.snapshot.Write(param_name.encode(), param_val.singa_tensor) def read(self): '''Call read method to load all (param_name, param_val)
