SINGA-362 Add functions to support the einsum function: 1. Fix a bug in device.cc (the source and destination offsets were swapped). 2. Use add(t, 0) to reset the stride, so that reshape can be used after transpose.
Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8d9eb297 Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8d9eb297 Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8d9eb297 Branch: refs/heads/master Commit: 8d9eb297dd7c2263b5face4bfcaf80a1a6680be8 Parents: 5e8f6a4 Author: sheyujian <[email protected]> Authored: Fri May 25 10:36:36 2018 +0800 Committer: sheyujian <[email protected]> Committed: Fri May 25 14:51:59 2018 +0800 ---------------------------------------------------------------------- python/singa/tensor.py | 84 ++++++++++++----------------------------- src/core/device/device.cc | 4 +- test/python/test_tensor.py | 2 - 3 files changed, 27 insertions(+), 63 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8d9eb297/python/singa/tensor.py ---------------------------------------------------------------------- diff --git a/python/singa/tensor.py b/python/singa/tensor.py index 5f38ef2..21a362a 100644 --- a/python/singa/tensor.py +++ b/python/singa/tensor.py @@ -271,24 +271,18 @@ class Tensor(object): return _call_singa_func(self.data.Clone) def repeat(self, repeats, axis): - # ret = CTensor() - # if isinstance(repeats, int): - # if axis == 9999: - # Repeats = [repeats,] - # ret = self.data.Repeat(Repeats, axis) - # else: - # Repeats = [repeats,] - # ret = self.data.Repeat(Repeats, axis) - - - # elif isinstance(repeats, tuple) or isinstance(repeats, list): - # if axis == 9999: - # ret = self.data.Repeat(list(repeats), axis) - - # elif axis >= 0: - # ret = self.data.Repeat(list(repeats), axis) - # return ret + '''Repeat data of a tensor + Args: + repeats(int or a sequence): the number that the tensor need to repeat for + axis (int):the axis to do repeat + If it is None, then the repeated tensor will be flattened.If it isn't None, + the repeats could 
be sequence, but it's size should match the axis's shape + + Return: + the tensor which has been repeated + + ''' t_ndim = self.ndim() if isinstance(repeats, int) or isinstance(repeats, long): if repeats < 0: @@ -1144,26 +1138,6 @@ def einsum(ops, *args): reshape_A = list(A.shape) + broadcast_a reshape_B = list(B.shape) + broadcast_b - # A_ = to_numpy(A) - # B_ = to_numpy(B) - - # mult_A = np.repeat(A_, np.product(broadcast_a)).reshape( - # reshape_A).transpose(transpose_A) - # mult_B = np.repeat(B_, np.product(broadcast_b)).reshape( - # reshape_B).transpose(transpose_B) - - # if mult_A.shape != mult_B.shape: - # raise ValueError("Error: matrix dimension mismatch") - # res_ = np.multiply(mult_A, mult_B) - - # reduce the axis and find the final transpose for the output - # sum_R = sorted(sums, reverse=True) - # for i in sum_R: - # res_ = res_.sum(axis=i) - # transpose_res = [sorted(list(outputops)).index(x) for x in list(outputops)] - # res_ = res_.transpose(transpose_res) - # res = from_numpy(res_) - # return res if len(broadcast_a) == 0: broadcast_a = [1] if len(broadcast_b) == 0: @@ -1352,24 +1326,19 @@ def tensordot (A,B,axes=2): newshape_b = (N2, N1) oldb = [b_shape[axis] for axis in notin] # do transpose and reshape to get the 2D matrix to do multiplication - A_ = to_numpy(A) - B_ = to_numpy(B) - at_ = np.transpose(A_,newaxes_a).reshape(newshape_a) - bt_ = np.transpose(B_,newaxes_b).reshape(newshape_b) - # print(at_) - # print(bt_) - at = from_numpy(at_) - bt = from_numpy(bt_) - - # A = transpose(A, newaxes_a) - # B = transpose(B, newaxes_b) - # A = - # at = Reshape(A, newshape_a) - # bt = Reshape(B, newshape_b) - # _at = to_numpy(at) - # _bt = to_numpy(bt) - # print(_at) - # print(_bt) + # A_ = to_numpy(A) + # B_ = to_numpy(B) + # at_ = np.transpose(A_,newaxes_a).reshape(newshape_a) + # bt_ = np.transpose(B_,newaxes_b).reshape(newshape_b) + # at = from_numpy(at_) + # bt = from_numpy(bt_) + + A = transpose(A, newaxes_a) + B = transpose(B, newaxes_b) + A = 
add(A, 0) + B = add(B, 0) + at = Reshape(A, newshape_a) + bt = Reshape(B, newshape_b) res = mult(at,bt) if len(olda + oldb) == 0: @@ -1378,10 +1347,7 @@ def tensordot (A,B,axes=2): res.reshape(tuple(olda + oldb)) else: res.reshape(tuple(olda + oldb)) - # print(res.shape) - # res_ = np.dot(at_, bt_) - # res = from_numpy(res_.reshape(olda + oldb)) - #reshape the result + return res def div(lhs, rhs, ret=None): http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8d9eb297/src/core/device/device.cc ---------------------------------------------------------------------- diff --git a/src/core/device/device.cc b/src/core/device/device.cc index 135ae3a..0c9c6a2 100644 --- a/src/core/device/device.cc +++ b/src/core/device/device.cc @@ -68,8 +68,8 @@ void Device::RepeatDataToFrom(Block* dst, Block* src, size_t nBytes, CopyDirection direct, bool broadcast_flag, int axis_shape, int shape_outer, int chunk, vector<size_t> repeats, int dst_offset, int src_offset) { - const char *src_data = reinterpret_cast<const char*>(src->data()) + dst_offset; - char *dst_data = reinterpret_cast<char*>(dst->mutable_data()) + src_offset; + const char *src_data = reinterpret_cast<const char*>(src->data()) + src_offset; + char *dst_data = reinterpret_cast<char*>(dst->mutable_data()) + dst_offset; for (int i = 0; i < shape_outer; i++) { for (int j = 0; j < axis_shape; j++) { http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8d9eb297/test/python/test_tensor.py ---------------------------------------------------------------------- diff --git a/test/python/test_tensor.py b/test/python/test_tensor.py index 7d83677..098994b 100644 --- a/test/python/test_tensor.py +++ b/test/python/test_tensor.py @@ -208,8 +208,6 @@ class TestTensorMethods(unittest.TestCase): ta_repeat2 = tensor.repeat(ta, 4, axis = 1) a_repeat2 = np.repeat(a, 4, axis = 1) Ta_repeat2 = tensor.to_numpy(ta_repeat2) - # print(Ta_repeat2) - # print(a_repeat2) self.assertAlmostEqual(np.sum(Ta_repeat1 - a_repeat1), 0., 
places=3) self.assertAlmostEqual(np.sum(Ta_repeat2 - a_repeat2), 0., places=3)
