Repository: incubator-singa Updated Branches: refs/heads/master 9090160ed -> 08675e3a5
# SINGA-347: tensordot is provided in support of einsum.
def tensordot(A, B, axes=2):
    """Returns the tensor multiplication of two tensors along specified axes.

    This is equivalent to computing the dot product along the specified axes,
    which are treated as one axis by reshaping.

    Args:
        A (Tensor): The first argument.
        B (Tensor): The second argument.
        axes:
            - If it is an integer, the last ``axes`` axes of ``A`` are paired
              with the first ``axes`` axes of ``B``.
            - If it is a pair of sequences of integers (or a pair of single
              integers), the two entries list the axes of ``A`` and of ``B``;
              corresponding axes are paired for the sum-product.

    Returns:
        Tensor: The tensor product of ``A`` and ``B`` along the axes
        specified by ``axes``.

    Raises:
        ValueError: If the two axis lists differ in length, or if a paired
            axis has a different size in ``A`` than in ``B``.

    Examples (shapes only):
        - A is (3, 2, 4), B is (4, 2, 5), axes=1 -> result is (3, 2, 2, 5).
        - A is (3, 2, 4), B is (2, 4, 5), axes=2 -> result is (3, 5).
        - A is (3, 2, 4), B is (4, 2, 5), axes=([1, 2], [1, 0]) -> result
          is (3, 5).
    """
    try:
        iter(axes)
    except TypeError:
        # Integer form: last `axes` axes of A against first `axes` axes of B.
        axes_a = list(range(-axes, 0))
        axes_b = list(range(0, axes))
    else:
        axes_a, axes_b = axes
        # Each entry may be a single axis or a sequence of axes.
        try:
            axes_a = list(axes_a)
        except TypeError:
            axes_a = [axes_a]
        try:
            axes_b = list(axes_b)
        except TypeError:
            axes_b = [axes_b]

    a_shape = A.shape
    b_shape = B.shape
    # Rank is taken from the shape so this works whether `ndim` is exposed
    # as an attribute or as a method on the tensor type.
    nda = len(a_shape)
    ndb = len(b_shape)

    # The contracted axes must pair up one-to-one and agree in size.
    if len(axes_a) != len(axes_b):
        raise ValueError("shape-mismatch for sum")
    for ax_a, ax_b in zip(axes_a, axes_b):
        if a_shape[ax_a] != b_shape[ax_b]:
            raise ValueError("shape-mismatch for sum")

    # Convert negative axes to their non-negative equivalents so they can be
    # compared against range(nda) / range(ndb) below.
    axes_a = [ax + nda if ax < 0 else ax for ax in axes_a]
    axes_b = [ax + ndb if ax < 0 else ax for ax in axes_b]

    # Axes that are not contracted survive into the result.
    free_a = [k for k in range(nda) if k not in axes_a]
    free_b = [k for k in range(ndb) if k not in axes_b]

    def _volume(shape, dims):
        # Product of the sizes of `dims`; 1 when `dims` is empty.
        total = 1
        for dim in dims:
            total *= shape[dim]
        return total

    # A becomes (free, contracted) and B becomes (contracted, free), so a
    # plain 2-D matrix product performs the contraction. E.g. A (3, 2, 4)
    # with axes ([1, 2], [1, 0]) is viewed as (3, 8).
    newaxes_a = free_a + axes_a
    newaxes_b = axes_b + free_b
    newshape_a = (_volume(a_shape, free_a), _volume(a_shape, axes_a))
    newshape_b = (_volume(b_shape, axes_b), _volume(b_shape, free_b))

    # Shapes of the surviving axes, used to unflatten the 2-D product.
    olda = [a_shape[ax] for ax in free_a]
    oldb = [b_shape[ax] for ax in free_b]

    # NOTE(review): assumes transpose() accepts an axis permutation and that
    # transpose()/reshape() return the resulting tensor -- confirm against
    # the Tensor class.
    at = A.transpose(newaxes_a).reshape(newshape_a)
    bt = B.transpose(newaxes_b).reshape(newshape_b)
    res = mult(at, bt)
    return res.reshape(olda + oldb)
