SINGA-347 Create a function that supports einsum, assuming numpy.repeat is available; finish the einsum function (using eltwise_mult, no need to use tensordot)
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/2ec13649
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/2ec13649
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/2ec13649

Branch: refs/heads/master
Commit: 2ec1364930a7a5b17640151b7cd16ae8fe75e131
Parents: 2ec06ed
Author: sheyujian <[email protected]>
Authored: Thu Apr 19 16:35:38 2018 +0800
Committer: sheyujian <[email protected]>
Committed: Thu Apr 19 16:35:38 2018 +0800

----------------------------------------------------------------------
 python/singa/tensor.py | 60 ++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 59 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/2ec13649/python/singa/tensor.py
----------------------------------------------------------------------
diff --git a/python/singa/tensor.py b/python/singa/tensor.py
index e6502fc..38acf6d 100644
--- a/python/singa/tensor.py
+++ b/python/singa/tensor.py
@@ -1042,7 +1042,6 @@ def tensordot(A, B, axes=2):
     #reshape the result
     return res.reshape(olda + oldb)
 
-
 def einsum_(A,B,ops):
     '''Do the matrix to matrix einsum calculation according to the operands
 
@@ -1183,6 +1182,65 @@ def diag(A,axis=-1):
     A_diag = from_numpy(npA_diag)
     return A_diag
 
+def einsum2(A,B,ops):
+    '''
+    Do the matrix to matrix einsum calculation according to the operands
+
+    Args:
+        A (Singa.Tensor): The first argument.
+        B (Singa.Tensor): The second argument.
+        ops(string):
+            the string specifies the subscripts for summation such as 'ki,kj->kij';
+            only the explicit form with exactly two operands and '->' is supported
+
+    Returns: Singa.Tensor
+        the output matrix of the einsum calculation
+    '''
+    if len(ops) == 0:
+        raise ValueError("No input operands")
+    if '->' not in ops:
+        raise ValueError("ops must contain '->' followed by the output subscripts")
+    inputops, outputops = ops.split('->')
+    inputops = inputops.split(',')
+    if len(inputops) != 2:
+        raise ValueError("exactly two input operands are required")
+    # len(x.shape) works whether ndim is an attribute (numpy) or a method (SINGA)
+    if len(A.shape) != len(inputops[0]) or len(B.shape) != len(inputops[1]):
+        raise ValueError("input dim doesn't match operands")
+
+    # every subscript, sorted: the axis layout of the broadcast product
+    outputall = sorted(set(inputops[0]) | set(inputops[1]))
+    if not set(outputops) <= set(outputall):
+        raise ValueError("output subscripts must appear in the inputs")
+    # axes (in outputall order) that are summed away
+    sums = [i for i, x in enumerate(outputall) if x not in outputops]
+
+    # subscripts missing from one operand; it is repeated along those axes
+    broadcast_A = sorted(set(inputops[1]) - set(inputops[0]))
+    broadcast_B = sorted(set(inputops[0]) - set(inputops[1]))
+    broadcast_a = [B.shape[inputops[1].index(x)] for x in broadcast_A]
+    broadcast_b = [A.shape[inputops[0].index(x)] for x in broadcast_B]
+
+    # permutations taking (own axes + repeated axes) into outputall order
+    transpose_A = [(list(inputops[0]) + broadcast_A).index(x) for x in outputall]
+    transpose_B = [(list(inputops[1]) + broadcast_B).index(x) for x in outputall]
+    reshape_A = list(A.shape) + broadcast_a
+    reshape_B = list(B.shape) + broadcast_b
+
+    # np.repeat needs an int count, but np.prod([]) is the float 1.0
+    # (np.product is deprecated and removed in NumPy 2.0), so convert explicitly
+    mult_A = np.repeat(A, int(np.prod(broadcast_a))).reshape(reshape_A).transpose(transpose_A)
+    mult_B = np.repeat(B, int(np.prod(broadcast_b))).reshape(reshape_B).transpose(transpose_B)
+    if mult_A.shape != mult_B.shape:
+        raise ValueError("error: matrix dimension mismatch")
+    res = eltwise_mult(mult_A, mult_B)
+
+    # sum the highest axis first so the remaining axis indices stay valid
+    for i in sorted(sums, reverse=True):
+        res = res.sum(axis=i)
+    # surviving axes are in sorted-subscript order; permute to outputops order
+    return res.transpose([sorted(outputops).index(x) for x in outputops])
+
 
 
 def div(lhs, rhs, ret=None):
