SINGA-347 Create a function that supports einsum

1. Test the tensordot function and fix some errors in the function.
2. Tweak the code to be more readable and fix some errors in the comments.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/9a3ce585 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/9a3ce585 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/9a3ce585 Branch: refs/heads/master Commit: 9a3ce585e468f4f5f17b8ea82658baf9d6ccd2aa Parents: e27498d Author: sheyujian <[email protected]> Authored: Fri Apr 13 14:18:08 2018 +0800 Committer: sheyujian <[email protected]> Committed: Wed Apr 18 13:56:27 2018 +0800 ---------------------------------------------------------------------- python/singa/tensor.py | 85 +++++++++++++++++++++++---------------------- 1 file changed, 43 insertions(+), 42 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/9a3ce585/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/python/singa/tensor.py b/python/singa/tensor.py index b5f5d99..49fa052 100644 --- a/python/singa/tensor.py +++ b/python/singa/tensor.py @@ -938,52 +938,53 @@ def mult(A, B, C=None, alpha=1.0, beta=0.0): singa.MultWithScale(alpha, A.singa_tensor, B.singa_tensor, beta, C.singa_tensor) return C -def tensordot (A,B,axes=2): - + + +def tensordot(A, B, axes=2): """Returns the tensor multiplication of two tensors along specified axes. - + This is equivalent to compute dot product along the specified axes which are treated as one axis by reshaping. - + Args: a (Singa.Tensor): The first argument. b (Singa.Tensor): The second argument. axes: - - If it is an integer, then ''axes'' represent axes at the last of ''a`'' and - the first of ''b'' are used. - - If it is a pair of sequences of integers, then these two - sequences specify the list of axes for ''a'' and ''b''. The - corresponding axes are paired for sum-product. 
- + - If it is an integer, then ''axes'' represent axes at the last of ''a`'' and + the first of ''b'' are used. + - If it is a pair of sequences of integers, then these two + sequences specify the list of axes for ''a'' and ''b''. The + corresponding axes are paired for sum-product. + Return: singa.tensor: The tensor product of ''a'' and ''b'' along the axes specified by ''axes''. - """ + """ # when axes is an integer, axes_A and axes_B represent axes at the last of ''a`'' and # the first of ''b''. For example, when axes is 1, we do the normal multiplication : # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix in shape(3,2,2,5) - #when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5) - try: - iter(axes) - except Exception: + # when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5) + if type(axes) == int: axes_A = list(range(-axes, 0)) axes_B = list(range(0, axes)) + axes_B = axes_B.reverse() else: axes_A, axes_B = axes # when axes is a pair of sequences of integers.For example, A is in shape(3,2,4), - #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5) - try: + # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a matrix in shape(3,5) + if isinstance(axes_A, list): na = len(axes_A) - axes_A = list(axes_B) - except TypeError: + axes_A = list(axes_A) + else: axes_A = [axes_A] na = 1 - try: - nb = len(axes_A) + if isinstance(axes_B, list): + nb = len(axes_B) axes_B = list(axes_B) - except TypeError: + else: axes_B = [axes_B] nb = 1 + # a_shape and b_shape are the shape of tensor A and B, while nda and ndb are the dim of A and B a_shape = A.shape nda = A.ndim @@ -994,47 +995,47 @@ def tensordot (A,B,axes=2): if na != nb: equal = False else: - # to make the shape match + # to make the shape match for k in range(na): - if a_shape[axes_a[k]] != b_shape[axes_b[k]]: + if a_shape[axes_A[k]] != b_shape[axes_B[k]]: equal = False break - if 
axes_a[k] < 0: - axes_a[k] += nda - if axes_b[k] < 0: - axes_b[k] += ndb + if axes_A[k] < 0: + axes_A[k] += nda + if axes_B[k] < 0: + axes_B[k] += ndb if not equal: raise ValueError("shape-mismatch for sum") '''start to do the calculation according to the axes''' - notin = [k for k in range(nda) if k not in axes_a] + + notin = [k for k in range(nda) if k not in axes_A] # nda is the dim of A, and axes_a is the axis for A, notin is the axis which is not in axes_A - newaxes_a = notin + axes_a + newaxes_a = notin + axes_A N2 = 1 - for axis in axes_a: + for axis in axes_A: N2 *= a_shape[axis] N1 = 1 for ax in notin: - N1 *=a_shape[ax] + N1 *= a_shape[ax] # newshape_a is the shape to do multiplication.For example, A is in shape(3,2,4), - #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,8) - #olda is the shape that will be shown in the result. - newshape_a = (N1,N2) + # B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a should be (3,5) + # olda is the shape that will be shown in the result. + newshape_a = (N1, N2) olda = [a_shape[axis] for axis in notin] - notin = [k for k in range(ndb) if k not in axes_b] - newaxes_b = axes_b + notin + newaxes_b = axes_B + notin N2 = 1 - for axis in axes_b: + for axis in axes_B: N2 *= b_shape[axis] N1 = 1 - for ax in notin: - N1 *=a_shape[ax] + for bx in notin: + N1 *= b_shape[bx] newshape_b = (N2, N1) oldb = [b_shape[axis] for axis in notin] # do transpose and reshape to get the 2D matrix to do multiplication at = a.transpose(newaxes_a).reshape(newshape_a) bt = b.transpose(newaxes_b).reshape(newshape_b) - res = mult(a, b) - #reshape the result + res = mult(at, bt) + # reshape the result return res.reshape(olda + oldb) def div(lhs, rhs, ret=None):
