SINGA-347 Create a function that supports einsum
1. Provide functions that perform the einsum calculation; one more function
is still needed to diagonalize a tensor along one axis
2. Give credit to numpy for tensordot


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f595f10c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f595f10c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f595f10c

Branch: refs/heads/master
Commit: f595f10c2556ef1dad27dfc7a532a09ee85665be
Parents: 9a3ce58
Author: sheyujian <[email protected]>
Authored: Mon Apr 16 06:27:38 2018 +0800
Committer: sheyujian <[email protected]>
Committed: Wed Apr 18 13:56:27 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 79 +++++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 77 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f595f10c/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index 49fa052..fde1667 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -959,15 +959,17 @@ def tensordot(A, B, axes=2):
     Return:
         singa.tensor: The tensor  product of ''a'' and ''b'' along the
         axes specified by ''axes''.
+
+    Thanks to numpy.tensordot.
     """
     # when axes is an integer, axes_A and axes_B represent axes at the last of 
''a`'' and
     # the first of ''b''. For example, when axes is 1, we do the normal 
multiplication :
     # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix 
in shape(3,2,2,5)
-    # when axes is 2 and A,B are in the same shape, it will return a matrix in 
shape(3,5)
+    # when axes is 2 and A,B are in the shape (3,2,4) and (2,4,5), it will 
return a matrix in shape(3,5)
     if type(axes) == int:
         axes_A = list(range(-axes, 0))
         axes_B = list(range(0, axes))
-        axes_B = axes_B.reverse()
+        axes_B = axes_B
     else:
         axes_A, axes_B = axes
     # when axes is a pair of sequences of integers.For example, A is in 
shape(3,2,4),
@@ -1038,6 +1040,79 @@ def tensordot(A, B, axes=2):
     # reshape the result
     return res.reshape(olda + oldb)
 
+
+def einsum(A, B, ops):
+    '''Evaluate a two-operand Einstein-summation expression.
+
+    Thanks to nils-werner/sparse.einsum()
+
+    Args:
+        A: first operand, passed through to tensordotmult.
+        B: second operand, passed through to tensordotmult.
+        ops (str): subscripts of the form 'in0,in1->out' with one character
+            per axis, e.g. 'ij,jk->ik'.
+
+    Return:
+        The result of tensordotmult(A, B, ...) after transpose() has been
+        called on it with the axis order derived from the output subscripts.
+
+    Raises:
+        ValueError: if ops is empty, does not contain exactly one '->', or
+            does not name exactly two comma-separated input operands.
+    '''
+    if not ops:
+        raise ValueError("No input operands")
+
+    parts = ops.split('->')
+    if len(parts) != 2:
+        raise ValueError("ops must contain exactly one '->': %r" % (ops,))
+    inputops, outputops = parts
+    inputops = inputops.split(',')
+    if len(inputops) != 2:
+        raise ValueError(
+            "ops must name exactly two input operands: %r" % (ops,))
+
+    shared = set(inputops[0]) & set(inputops[1])
+    out_labels = set(outputops)
+    # Labels present in both inputs AND in the output are multiplies
+    multiplies = sorted(shared & out_labels)
+    # Labels present in both inputs BUT NOT in the output are sum contractions
+    sums = sorted(shared - out_labels)
+
+    # Map each label to its axis position inside each input subscript
+    multiplies = [[inop.find(x) for x in multiplies] for inop in inputops]
+    sums = [[inop.find(x) for x in sums] for inop in inputops]
+
+    # Locate every output label in the concatenated input subscripts, then
+    # reduce those positions to a permutation of range(len(outputops)).
+    # NOTE(review): this is an argsort of the positions; the original comment
+    # spoke of ranks (the inverse permutation) - confirm which one
+    # transpose() expects once diag() is implemented.
+    joined = ''.join(inputops)
+    positions = [joined.find(x) for x in outputops]
+    transpose = sorted(range(len(positions)), key=positions.__getitem__)
+
+    res = tensordotmult(A, B, sum=sums, multiply=multiplies)
+    return res.transpose(transpose)
+
+def tensordotmult(A, B, sum=None, multiply=None):
+    '''Combine A and B via tensordot over the sum axes and multiply axes.
+
+    Each axis listed in multiply[0] is first passed to diag() on A; the
+    (shifted) multiply axes are then appended to the sum axes and the pair
+    is handed to tensordot().
+
+    Args:
+        A: first operand.
+        B: second operand.
+        sum: pair (axes of A, axes of B) to contract; defaults to no axes.
+        multiply: pair (axes of A, axes of B) to multiply; defaults to no
+            axes.
+
+    Return:
+        The value returned by tensordot(A, B, axes=[axes_A, axes_B]).
+    '''
+    # Copy the inner sequences so the caller's arguments are never mutated
+    # (the += and item updates below would otherwise write through to them).
+    if sum is None:
+        sum_A, sum_B = [], []
+    else:
+        sum = list(sum)
+        sum_A, sum_B = list(sum[0]), list(sum[1])
+
+    if multiply is None:
+        post_multiply, multiply_B = [], []
+    else:
+        multiply = list(multiply)
+        post_multiply, multiply_B = list(multiply[0]), list(multiply[1])
+
+    # For each multiply[0] we are adding one axis, thus we need to increment
+    # all following items by one: (0, 1, 2) -> (0, 2, 4)
+    for shift, pos in enumerate(multipliessort(post_multiply)):
+        post_multiply[pos] += shift
+
+    for axis in post_multiply:
+        A = diag(A, axis)
+
+    axes = [sum_A + post_multiply, sum_B + multiply_B]
+    return tensordot(A, B, axes=axes)
+
+def multipliessort(multiplies):
+    '''Return the indices that would sort multiplies in ascending order.
+
+    Args:
+        multiplies: an iterable of comparable values, or None.
+
+    Return:
+        list of int: positions of the elements of multiplies, ordered by
+        element value (ties keep their original relative order); an empty
+        list when multiplies is None.
+    '''
+    if multiplies is None:
+        return []
+
+    values = list(multiplies)
+    # sorted() is stable, so equal values stay in positional order
+    return sorted(range(len(values)), key=values.__getitem__)
+
+def diag(A, axis=-1):
+    '''Placeholder for diagonalizing a tensor along one axis.
+
+    Currently this only returns clone(A); the axis argument is not used.
+    '''
+    # TODO(sheyujian): make the tensor diagonal along the given axis
+    return clone(A)
+
+
+
 def div(lhs, rhs, ret=None):
     '''Elementi-wise division.
 

Reply via email to