SINGA-347 Create a function that supports einsum. 1. Provide functions for einsum calculation; one more function is still needed to diagonalize a tensor along one axis. 2. Credit numpy for tensordot.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/f595f10c Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/f595f10c Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/f595f10c Branch: refs/heads/master Commit: f595f10c2556ef1dad27dfc7a532a09ee85665be Parents: 9a3ce58 Author: sheyujian <[email protected]> Authored: Mon Apr 16 06:27:38 2018 +0800 Committer: sheyujian <[email protected]> Committed: Wed Apr 18 13:56:27 2018 +0800 ---------------------------------------------------------------------- python/singa/tensor.py | 79 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 77 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/f595f10c/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/python/singa/tensor.py b/python/singa/tensor.py index 49fa052..fde1667 100644 --- a/python/singa/tensor.py +++ b/python/singa/tensor.py @@ -959,15 +959,17 @@ def tensordot(A, B, axes=2): Return: singa.tensor: The tensor product of ''a'' and ''b'' along the axes specified by ''axes''. + + Thanks to numpy.tensordot. """ # when axes is an integer, axes_A and axes_B represent axes at the last of ''a`'' and # the first of ''b''. 
For example, when axes is 1, we do the normal multiplication : # if A is in shape(3,2,4), B is in shape(4,2,5), it will return a matrix in shape(3,2,2,5) - # when axes is 2 and A,B are in the same shape, it will return a matrix in shape(3,5) + # when axes is 2 and A,B are in the shape (3,2,4) and (2,4,5), it will return a matrix in shape(3,5) if type(axes) == int: axes_A = list(range(-axes, 0)) axes_B = list(range(0, axes)) - axes_B = axes_B.reverse() + axes_B = axes_B else: axes_A, axes_B = axes # when axes is a pair of sequences of integers.For example, A is in shape(3,2,4), @@ -1038,6 +1040,79 @@ def tensordot(A, B, axes=2): # reshape the result return res.reshape(olda + oldb) + +def einsum(A,B,ops): + ''' + Thanks to nils-werner/sparse.einsum() + ''' + + if len(ops) == 0: + raise ValueError("No input operands") + + nputops, outputops = ops.split('->') + inputops = inputops.split(',') + + # All indices that are in input AND in output are multiplies + multiplies = sorted(list(set(inputops[0]) & set(inputops[1]) & set(outputops))) + # All indices that are in input BUT NOT in output are sum contractions + sums = sorted(list((set(inputops[0]) & set(inputops[1])) - set(outputops))) + + # Map sums and indices to axis integers + multiplies = [[inop.find(x) for x in multiplies] for inop in inputops] + + sums = [[inop.find(x) for x in sums] for inop in inputops] + # Find output axes in input axes for final transpose + # Values very likely lie outside of output tensor shape, so + # just map them values to their rank (index in ordered list) + transpose = [''.join(inputops).find(x) for x in outputops] + transpose = sorted(range(len(transpose)), key = transpose.__getitem__) + + return tensormult(A,B, sum=sums, multiply=multiplies).transpose(transpose) + +def tensordotmult(A, B, sum=None, multiply=None): + if sum is None: + sum = [[], []] + else: + sum = list(sum) + + if multiply is None: + multiply = [[], []] + else: + multiply = list(multiply) + + # For each multiply[0] we 
are adding one axis, thus we need to increment + # all following items by one: (0, 1, 2) -> (0, 2, 4) + idx = multipliessort(multiply[0]) + post_multiply = multiply[0] + for i, v in enumerate(idx): + post_multiply[v] += i + + for i in post_multiply: + A = diag(A,i) + + sum[0] += post_multiply + sum[1] += multiply[1] + + return tensordot(A, B, axes=sum) + +def multipliessort(multiplies): + + if multiplies is None: + multiplies = [] + else: + multiplies = list(multiplies) + + sort_multiplies = sorted(enumerate(multiplies), key=lambda x: x[1]) + idx = [x[0] for x in sort_multiplies] + return idx + +def diag(A,axis = -1): + A_diag = clone(A) + '''sheyujian todo: to make a tensor's axis to be diagonalization''' + return A_diag + + + def div(lhs, rhs, ret=None): '''Elementi-wise division.
