This is an automated email from the ASF dual-hosted git repository. haibin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 8b172d7 [MXNET-294] Fix element wise multiply for csr ndarrays (#10452) 8b172d7 is described below commit 8b172d7fff080f8d3aa3596779a1e09e602fbe08 Author: Rahul Huilgol <rahulhuil...@gmail.com> AuthorDate: Tue Apr 10 11:59:41 2018 -0700 [MXNET-294] Fix element wise multiply for csr ndarrays (#10452) * initialize rhs * add test for elemwise op on same array * dont declare memory for rhs if same array * dont declare memory for rhs if same array * assign lhs to rhs if same arr * whitespace changes * lint fix * trigger ci --- src/operator/tensor/elemwise_binary_op-inl.h | 32 +++++++++++++++++----------- tests/python/unittest/test_sparse_ndarray.py | 1 + 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h index 5cd3314..15b1f0e 100644 --- a/src/operator/tensor/elemwise_binary_op-inl.h +++ b/src/operator/tensor/elemwise_binary_op-inl.h @@ -288,12 +288,16 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s, mshadow::Tensor<cpu, 1, DType> lhs_row(reinterpret_cast<DType *>(workspace.dptr_ + nr_cols * sizeof(IType)), Shape1(nr_cols)); - mshadow::Tensor<cpu, 1, DType> rhs_row(lhs_row.dptr_ + nr_cols, Shape1(nr_cols)); + mshadow::Tensor<cpu, 1, DType> rhs_row; OpBase::FillDense<IType>(s, next.shape_.Size(), IType(-1), req, next.dptr_); OpBase::FillDense<DType>(s, lhs_row.shape_.Size(), DType(0), req, lhs_row.dptr_); + if (!same_lhs_rhs) { + rhs_row = Tensor<cpu, 1, DType>(lhs_row.dptr_ + nr_cols, Shape1(nr_cols)); OpBase::FillDense<DType>(s, rhs_row.shape_.Size(), DType(0), req, rhs_row.dptr_); + } else { + rhs_row = lhs_row; } // Column indices @@ -331,17 +335,19 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s, } } - // add a row of B to rhs_row - const IType i_start_r = row_ptr_r[i]; - const IType i_end_r = row_ptr_r[i + 1]; - for (IType jj = i_start_r; jj < i_end_r; jj++) { - const IType col = col_indices_r[jj]; - rhs_row[col] += data_r[jj]; - - if (next[col] == -1) { - next[col] = head; - head = col; - ++length; + if (!same_lhs_rhs) { + // add a row of B to rhs_row + const IType i_start_r = row_ptr_r[i]; + const IType i_end_r = row_ptr_r[i + 1]; + for (IType jj = i_start_r; jj < i_end_r; jj++) { + const IType col = col_indices_r[jj]; + rhs_row[col] += data_r[jj]; + + if (next[col] == -1) { + next[col] = head; + head = col; + ++length; + } } } @@ -361,7 +367,7 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s, next[temp] = -1; lhs_row[temp] = 0; - rhs_row[temp] = 0; + if (!same_lhs_rhs) rhs_row[temp] = 0; } row_ptr_out[i + 1] = nnz; diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index ae3260a..a710038 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -261,6 +261,7 @@ def test_sparse_nd_binary(): lhs_nd = mx.nd.array(lhs).tostype(stype) rhs_nd = mx.nd.array(rhs).tostype(stype) assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4) + assert_allclose(fn(lhs, lhs), fn(lhs_nd, lhs_nd).asnumpy(), rtol=1e-4, atol=1e-4) stypes = ['row_sparse', 'csr'] for stype in stypes: -- To stop receiving notification emails like this one, please contact hai...@apache.org.