This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8b172d7  [MXNET-294] Fix element wise multiply for csr ndarrays (#10452)
8b172d7 is described below

commit 8b172d7fff080f8d3aa3596779a1e09e602fbe08
Author: Rahul Huilgol <rahulhuil...@gmail.com>
AuthorDate: Tue Apr 10 11:59:41 2018 -0700

    [MXNET-294] Fix element wise multiply for csr ndarrays (#10452)
    
    * initialize rhs
    
    * add test for elemwise op on same array
    
    * dont declare memory for rhs if same array
    
    * dont declare memory for rhs if same array
    
    * assign lhs to rhs if same arr
    
    * whitespace changes
    
    * lint fix
    
    * trigger ci
---
 src/operator/tensor/elemwise_binary_op-inl.h | 32 +++++++++++++++++-----------
 tests/python/unittest/test_sparse_ndarray.py |  1 +
 2 files changed, 20 insertions(+), 13 deletions(-)
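
For context, the case this commit fixes is an elementwise binary op whose lhs and rhs are the same csr ndarray. A minimal reproduction sketch, assuming a standard MXNet install (elemwise_mul stands in here for any csr-csr binary op):

    import mxnet as mx
    import numpy as np

    dense = np.array([[0., 2.], [3., 0.]])
    csr = mx.nd.array(dense).tostype('csr')

    # lhs and rhs are the same NDArray: the aliased case this commit handles
    out = mx.nd.elemwise_mul(csr, csr)
    np.testing.assert_allclose(out.asnumpy(), dense * dense)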

diff --git a/src/operator/tensor/elemwise_binary_op-inl.h b/src/operator/tensor/elemwise_binary_op-inl.h
index 5cd3314..15b1f0e 100644
--- a/src/operator/tensor/elemwise_binary_op-inl.h
+++ b/src/operator/tensor/elemwise_binary_op-inl.h
@@ -288,12 +288,16 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
   mshadow::Tensor<cpu, 1, DType> lhs_row(reinterpret_cast<DType *>(workspace.dptr_
                                                                    + nr_cols * sizeof(IType)),
                                          Shape1(nr_cols));
-  mshadow::Tensor<cpu, 1, DType> rhs_row(lhs_row.dptr_ + nr_cols, Shape1(nr_cols));
+  mshadow::Tensor<cpu, 1, DType> rhs_row;
 
   OpBase::FillDense<IType>(s, next.shape_.Size(), IType(-1), req, next.dptr_);
   OpBase::FillDense<DType>(s, lhs_row.shape_.Size(), DType(0),  req, lhs_row.dptr_);
+
   if (!same_lhs_rhs) {
+    rhs_row = Tensor<cpu, 1, DType>(lhs_row.dptr_ + nr_cols, Shape1(nr_cols));
     OpBase::FillDense<DType>(s, rhs_row.shape_.Size(), DType(0), req, rhs_row.dptr_);
+  } else {
+    rhs_row = lhs_row;
   }
 
   // Column indices
@@ -331,17 +335,19 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
       }
     }
 
-    // add a row of B to rhs_row
-    const IType i_start_r = row_ptr_r[i];
-    const IType i_end_r = row_ptr_r[i + 1];
-    for (IType jj = i_start_r; jj < i_end_r; jj++) {
-      const IType col = col_indices_r[jj];
-      rhs_row[col] += data_r[jj];
-
-      if (next[col] == -1) {
-        next[col] = head;
-        head = col;
-        ++length;
+    if (!same_lhs_rhs) {
+      // add a row of B to rhs_row
+      const IType i_start_r = row_ptr_r[i];
+      const IType i_end_r = row_ptr_r[i + 1];
+      for (IType jj = i_start_r; jj < i_end_r; jj++) {
+        const IType col = col_indices_r[jj];
+        rhs_row[col] += data_r[jj];
+
+        if (next[col] == -1) {
+          next[col] = head;
+          head = col;
+          ++length;
+        }
       }
     }
 
@@ -361,7 +367,7 @@ void ElemwiseBinaryOp::CsrCsrOp(mshadow::Stream<cpu> *s,
 
       next[temp] = -1;
       lhs_row[temp] = 0;
-      rhs_row[temp] = 0;
+      if (!same_lhs_rhs) rhs_row[temp] = 0;
     }
 
     row_ptr_out[i + 1] = nnz;
diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py
index ae3260a..a710038 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -261,6 +261,7 @@ def test_sparse_nd_binary():
             lhs_nd = mx.nd.array(lhs).tostype(stype)
             rhs_nd = mx.nd.array(rhs).tostype(stype)
             assert_allclose(fn(lhs, rhs), fn(lhs_nd, rhs_nd).asnumpy(), rtol=1e-4, atol=1e-4)
+            assert_allclose(fn(lhs, lhs), fn(lhs_nd, lhs_nd).asnumpy(), rtol=1e-4, atol=1e-4)
 
     stypes = ['row_sparse', 'csr']
     for stype in stypes:
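
As I read the hunks above, CsrCsrOp scatters each row into dense buffers (lhs_row, rhs_row) and tracks the occupied columns in a linked list threaded through next, so they can be visited and reset without scanning every column. Before this fix, rhs_row always pointed just past lhs_row in the workspace, memory that the commit message ("initialize rhs", "dont declare memory for rhs if same array") suggests was not set up when the two operands alias. A pure-Python sketch of that row-accumulation pattern in the aliased case (names are mine and details such as output column ordering are simplified; this is not the MXNet implementation):

    def csr_self_elemwise(indptr, indices, data, op, n_cols):
        """Apply op(x, x) row by row to a CSR matrix, mirroring the
        same_lhs_rhs path of CsrCsrOp."""
        next_col = [-1] * n_cols  # -1 means "column not in the list"
        row_buf = [0.0] * n_cols  # dense accumulator; lhs and rhs share it
        out = []                  # (row, col, value) triples
        for i in range(len(indptr) - 1):
            head, length = -2, 0  # -2 marks the end of the linked list
            for jj in range(indptr[i], indptr[i + 1]):
                col = indices[jj]
                row_buf[col] += data[jj]
                if next_col[col] == -1:  # first touch: push col onto the list
                    next_col[col] = head
                    head = col
                    length += 1
            for _ in range(length):  # walk the list, emit, and reset state
                out.append((i, head, op(row_buf[head], row_buf[head])))
                temp = head
                head = next_col[temp]
                next_col[temp] = -1   # restore the sentinel for the next row
                row_buf[temp] = 0.0
        return out

With the 2x2 matrix from the earlier sketch (indptr=[0, 1, 2], indices=[1, 0], data=[2., 3.]) and op=lambda a, b: a * b, this yields (0, 1, 4.0) and (1, 0, 9.0), matching the dense elementwise square.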

-- 
To stop receiving notification emails like this one, please contact
hai...@apache.org.
