Repository: incubator-singa
Updated Branches:
  refs/heads/master 29b87b991 -> c5b454d1e


SINGA-369 Fix errors in the examples found during testing

Fix the error raised when loading checkpoint files caused by a version mismatch.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/01ed8777
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/01ed8777
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/01ed8777

Branch: refs/heads/master
Commit: 01ed8777e1c53054bfe46c72b42a9f9d69d27cc2
Parents: 29b87b9
Author: Wang Wei <[email protected]>
Authored: Sun May 20 13:24:56 2018 +0800
Committer: Wang Wei <[email protected]>
Committed: Sun May 20 13:30:12 2018 +0800

----------------------------------------------------------------------
 python/singa/net.py      | 11 +++++++++--
 python/singa/snapshot.py |  1 -
 python/singa/tensor.py   | 22 +++++++++++-----------
 3 files changed, 20 insertions(+), 14 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/01ed8777/python/singa/net.py
----------------------------------------------------------------------
diff --git a/python/singa/net.py b/python/singa/net.py
index 42570f4..8d6bc85 100644
--- a/python/singa/net.py
+++ b/python/singa/net.py
@@ -60,6 +60,7 @@ from __future__ import absolute_import
 from builtins import zip
 from builtins import str
 from builtins import object
+import numpy as np
 import os
 
 from .proto.model_pb2 import kTrain, kEval
@@ -464,6 +465,8 @@ class FeedForwardNet(object):
             if f.endswith('.bin'):
                 f = f[0:-4]
             sp = snapshot.Snapshot(f, True, buffer_size)
+            v = tensor.from_numpy(np.array([__version__]))
+            params['SINGA_VERSION'] = v
             for (name, val) in zip(self.param_names(), self.param_values()):
                 val.to_host()
                 sp.write(name, val)
@@ -493,7 +496,7 @@ class FeedForwardNet(object):
                     f = f + '.pickle'
             assert os.path.exists(f), 'file not exists %s w/o .pickle' % f
             with open(f, 'rb') as fd:
-                params = pickle.load(fd,encoding='iso-8859-1')
+                params = pickle.load(fd, encoding='iso-8859-1')
         else:
             print('NOTE: If your model was saved using pickle, '
                   'then set use_pickle=True for loading it')
@@ -501,9 +504,13 @@ class FeedForwardNet(object):
                 f = f[0:-4]
             sp = snapshot.Snapshot(f, False, buffer_size)
             params = sp.read()
-        version = __version__
+
         if 'SINGA_VERSION' in params:
             version = params['SINGA_VERSION']
+            if isinstance(version, tensor.Tensor):
+                version = tensor.to_numpy(version)[0]
+        else:
+            version = 1100
         for name, val in zip(self.param_names(), self.param_values()):
             name = get_name(name)
             if name not in params:

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/01ed8777/python/singa/snapshot.py
----------------------------------------------------------------------
diff --git a/python/singa/snapshot.py b/python/singa/snapshot.py
index 7c97f0f..cae151d 100644
--- a/python/singa/snapshot.py
+++ b/python/singa/snapshot.py
@@ -31,7 +31,6 @@ Example usages::
 '''
 from __future__ import absolute_import
 
-from builtins import str
 from builtins import object
 from . import singa_wrap as singa
 from . import tensor

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/01ed8777/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index d5bd2e0..8f36775 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -1011,7 +1011,6 @@ def einsum(ops, *args):
           [179 221]]]
     '''
 
-
     if len(ops) == 0:
         raise ValueError("No input operands")
 
@@ -1037,7 +1036,7 @@ def einsum(ops, *args):
     # to get all the indices in input
     outputall = sorted(list(set(inputops[0]) | set(inputops[1])))
 
-    ## Map indices to axis integers
+    # Map indices to axis integers
     sums = [outputall.index(x) for x in sums]
     broadcast_idA = [inputops[1].find(x) for x in broadcast_A]
     broadcast_idB = [inputops[0].find(x) for x in broadcast_B]
@@ -1045,9 +1044,12 @@ def einsum(ops, *args):
     broadcast_a = [B.shape[x] for x in broadcast_idA]
     broadcast_b = [A.shape[x] for x in broadcast_idB]
 
-    # get the the transpose and reshape parameter used in the elementwise calculation
-    transpose_A = [(list(inputops[0]) + broadcast_A).index(x) for x in outputall]
-    transpose_B = [(list(inputops[1]) + broadcast_B).index(x) for x in outputall]
+    # get the the transpose and reshape parameter used in the elementwise
+    # calculation
+    transpose_A = [(list(inputops[0]) + broadcast_A).index(x)
+                   for x in outputall]
+    transpose_B = [(list(inputops[1]) + broadcast_B).index(x)
+                   for x in outputall]
 
     reshape_A = list(A.shape) + broadcast_a
     reshape_B = list(B.shape) + broadcast_b
@@ -1055,8 +1057,10 @@ def einsum(ops, *args):
     A_ = to_numpy(A)
     B_ = to_numpy(B)
 
-    mult_A = np.repeat(A_, np.product(broadcast_a)).reshape(reshape_A).transpose(transpose_A)
-    mult_B = np.repeat(B_, np.product(broadcast_b)).reshape(reshape_B).transpose(transpose_B)
+    mult_A = np.repeat(A_, np.product(broadcast_a)).reshape(
+        reshape_A).transpose(transpose_A)
+    mult_B = np.repeat(B_, np.product(broadcast_b)).reshape(
+        reshape_B).transpose(transpose_B)
 
     if mult_A.shape != mult_B.shape:
         raise ValueError("Error: matrix dimension mismatch")
@@ -1073,9 +1077,6 @@ def einsum(ops, *args):
     return res
 
 
-
-
-
 def div(lhs, rhs, ret=None):
     '''Elementi-wise division.
 
@@ -1255,4 +1256,3 @@ def copy_from_numpy(data, np_array):
         data.CopyIntDataFromHostPtr(np_array)
     else:
         print('Not implemented yet for ', dt)
-

Reply via email to