Repository: incubator-singa
Updated Branches:
  refs/heads/master 9090160ed -> 08675e3a5


SINGA-347 Create a function that supports einsum
Provide the tensordot function, which computes the tensor multiplication of
two tensors along the specified axes.


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/e27498df
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/e27498df
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/e27498df

Branch: refs/heads/master
Commit: e27498df74fecf30cb7e0408c36465791802320d
Parents: b4ea650
Author: sheyujian <[email protected]>
Authored: Thu Apr 12 22:58:32 2018 +0800
Committer: sheyujian <[email protected]>
Committed: Wed Apr 18 13:55:31 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 99 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 98 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/e27498df/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 2fcadb4..b5f5d99 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -938,7 +938,104 @@ def mult(A, B, C=None, alpha=1.0, beta=0.0):
         singa.MultWithScale(alpha, A.singa_tensor, B.singa_tensor,
                             beta, C.singa_tensor)
         return C
-
+def tensordot (A,B,axes=2):
+    
+    """Returns the tensor multiplication of two tensors along specified axes.
+        
+    This is equivalent to compute dot product along the specified axes which
+    are treated as one axis by reshaping.
+        
+    Args:
+        a (Singa.Tensor): The first argument.
+        b (Singa.Tensor): The second argument.
+        axes:
+        - If it is an integer, then ''axes'' represent axes at the last of 
''a`'' and
+        the first of ''b'' are used.
+        - If it is a pair of sequences of integers, then these two
+        sequences specify the list of axes for ''a'' and ''b''. The
+        corresponding axes are paired for sum-product.
+        
+    Return:
+        singa.tensor: The tensor  product of ''a'' and ''b'' along the
+        axes specified by ''axes''.
+        """
+    # when axes is an integer, axes_A and axes_B represent axes at the last of 
''a`'' and
+    # the first of ''b''. For example, when axes is 1, we do the normal 
multiplication :
+    # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix 
in shape(3,2,2,5)
+    #when axes is 2 and A,B are in the same shape, it will return a matrix in 
shape(3,5)
+    try:
+        iter(axes)
+    except Exception:
+        axes_A = list(range(-axes, 0))
+        axes_B = list(range(0, axes))
+    else:
+        axes_A, axes_B = axes
+    # when axes is a pair of sequences of integers.For example, A is in 
shape(3,2,4),
+    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), it will return a 
matrix in shape(3,5)
+    try:
+        na = len(axes_A)
+        axes_A = list(axes_B)
+    except TypeError:
+        axes_A = [axes_A]
+        na = 1
+    try:
+        nb = len(axes_A)
+        axes_B = list(axes_B)
+    except TypeError:
+        axes_B = [axes_B]
+        nb = 1
+    # a_shape and b_shape are the shape of tensor A and B, while nda and ndb 
are the dim of A and B
+    a_shape = A.shape
+    nda = A.ndim
+    b_shape = B.shape
+    ndb = B.ndim
+    equal = True
+    # to check if the length of axe_A is equal to axes_B
+    if na != nb:
+        equal = False
+    else:
+    # to make the shape match
+        for k in range(na):
+            if a_shape[axes_a[k]] != b_shape[axes_b[k]]:
+                equal = False
+                break
+            if axes_a[k] < 0:
+                axes_a[k] += nda
+            if axes_b[k] < 0:
+                axes_b[k] += ndb
+    if not equal:
+        raise ValueError("shape-mismatch for sum")
+    '''start to do the calculation according to the axes'''
+    notin = [k for k in range(nda) if k not in axes_a]
+    # nda is the dim of A, and axes_a is the axis for A, notin is the axis 
which is not in axes_A
+    newaxes_a = notin + axes_a
+    N2 = 1
+    for axis in axes_a:
+        N2 *= a_shape[axis]
+    N1 = 1
+    for ax in notin:
+        N1 *=a_shape[ax]
+    # newshape_a is the shape to do multiplication.For example, A is in 
shape(3,2,4),
+    #B is in shape(4,2,5), we set axes as ([1,2],[1,0]), then newshape_a 
should be (3,8)
+    #olda is the shape that will be shown in the result.
+    newshape_a = (N1,N2)
+    olda = [a_shape[axis] for axis in notin]
+    notin = [k for k in range(ndb) if k not in axes_b]
+    newaxes_b = axes_b + notin
+    N2 = 1
+    for axis in axes_b:
+        N2 *= b_shape[axis]
+    N1 = 1
+    for ax in notin:
+        N1 *=a_shape[ax]
+    newshape_b = (N2, N1)
+    oldb = [b_shape[axis] for axis in notin]
+    # do transpose and reshape to get the 2D matrix to do multiplication
+    at = a.transpose(newaxes_a).reshape(newshape_a)
+    bt = b.transpose(newaxes_b).reshape(newshape_b)
+    res = mult(a, b)
+    #reshape the result
+    return res.reshape(olda + oldb)
 
 def div(lhs, rhs, ret=None):
     '''Elementi-wise division.

Reply via email to