This is an automated email from the ASF dual-hosted git repository.

apeforest pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 2527553  [Large Tensor] Fix cumsum op (#17677)
2527553 is described below

commit 2527553e8c8bf34d919e21bb4f37e2e13b6b6834
Author: Connor Goggins <cgoggi...@gmail.com>
AuthorDate: Sat Feb 29 00:37:33 2020 -0800

    [Large Tensor] Fix cumsum op (#17677)
    
    * Implemented fix and nightly test for cumsum
    
    * Changed IType to index_t
    
    * Also changed in backward
    
    * Reverting to IType
    
    * Added type assertion on first element to force evaluation of output NDArray
    
    * Reverted to IType in relevant places
    
    * Last reversion
    
    * Changed type assertion to value check
---
 src/operator/numpy/np_cumsum-inl.h | 24 ++++++++++++------------
 tests/nightly/test_large_array.py  | 11 +++++++++++
 2 files changed, 23 insertions(+), 12 deletions(-)
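
The crux of the patch is the index arithmetic in the two kernels below: each kernel
index i addresses one cumsum lane, and the derived offset (left * middle * trailing
+ right) can exceed 2^31 - 1 once the input holds more than about two billion
elements, which overflows a 32-bit int. A rough standalone sketch of that arithmetic
in plain Python (the 100000000 x 50 shape is an assumption for illustration, roughly
matching the nightly test further down, not something stated in this commit):

    # Rough sketch of the offset arithmetic in cumsum_forward / cumsum_backward,
    # written with plain Python integers to show the magnitudes involved.
    rows, cols = 100_000_000, 50      # assumed input shape, cumsum over axis 1
    middle, trailing = cols, 1        # length of the cumsum axis; dims after it
    i = rows * trailing - 1           # last lane index handled by the kernel
    left, right = i // trailing, i % trailing
    offset = left * middle * trailing + right   # start of this lane in memory
    print(offset > 2**31 - 1)         # True: a 32-bit int offset would overflow

Widening i, middle, trailing and the derived values to index_t (a 64-bit integer
when MXNet is built with large-tensor support) keeps every intermediate in range,
while the element type IType is left alone, per the "Reverted to IType in relevant
places" note in the commit message above.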

diff --git a/src/operator/numpy/np_cumsum-inl.h b/src/operator/numpy/np_cumsum-inl.h
index 375d83b..65e6581 100644
--- a/src/operator/numpy/np_cumsum-inl.h
+++ b/src/operator/numpy/np_cumsum-inl.h
@@ -60,17 +60,17 @@ struct CumsumParam : public dmlc::Parameter<CumsumParam> {
 
 struct cumsum_forward {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i,
+  MSHADOW_XINLINE static void Map(index_t i,
                                   OType *out,
                                   const IType *in,
-                                  const int middle,
-                                  const int trailing) {
-    int left = i / trailing, right = i % trailing;
-    int offset = left * middle * trailing + right;
+                                  const index_t middle,
+                                  const index_t trailing) {
+    index_t left = i / trailing, right = i % trailing;
+    index_t offset = left * middle * trailing + right;
     const IType *lane_in = in + offset;
     OType *lane_out = out + offset;
     lane_out[0] = OType(lane_in[0]);
-    for (int j = 1; j < middle; ++j) {
+    for (index_t j = 1; j < middle; ++j) {
      lane_out[j * trailing] = lane_out[(j - 1) * trailing] + OType(lane_in[j * trailing]);
     }
   }
@@ -125,17 +125,17 @@ void CumsumForward(const nnvm::NodeAttrs& attrs,
 
 struct cumsum_backward {
   template<typename IType, typename OType>
-  MSHADOW_XINLINE static void Map(int i,
+  MSHADOW_XINLINE static void Map(index_t i,
                                   IType *igrad,
                                   const OType *ograd,
-                                  const int middle,
-                                  const int trailing) {
-    int left = i / trailing, right = i % trailing;
-    int offset = left * middle * trailing + right;
+                                  const index_t middle,
+                                  const index_t trailing) {
+    index_t left = i / trailing, right = i % trailing;
+    index_t offset = left * middle * trailing + right;
     const OType *lane_ograd = ograd + offset;
     IType *lane_igrad = igrad + offset;
    lane_igrad[(middle - 1) * trailing] = IType(lane_ograd[(middle - 1) * trailing]);
-    for (int j = middle - 2; j >= 0; --j) {
+    for (index_t j = middle - 2; j >= 0; --j) {
      lane_igrad[j * trailing] = lane_igrad[(j + 1) * trailing] + IType(lane_ograd[j * trailing]);
     }
   }
diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py
index ee57f17..222c452 100644
--- a/tests/nightly/test_large_array.py
+++ b/tests/nightly/test_large_array.py
@@ -504,6 +504,16 @@ def test_nn():
 
         assert out.shape[0] == LARGE_TENSOR_SHAPE
 
+    def check_cumsum():
+        a = nd.ones((LARGE_X, SMALL_Y))
+        axis = 1
+
+        res = nd.cumsum(a=a, axis=axis)
+
+        assert res.shape[0] == LARGE_X
+        assert res.shape[1] == SMALL_Y
+        assert res[0][SMALL_Y - 1] == 50.
+
     check_gluon_embedding()
     check_fully_connected()
     check_dense()
@@ -527,6 +537,7 @@ def test_nn():
     check_embedding()
     check_spatial_transformer()
     check_ravel()
+    check_cumsum()
 
 
 def test_tensor():
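
The new check also doubles as a synchronization point: MXNet executes operators
asynchronously, and reading a value out of res is what the commit message means by
forcing evaluation of the output NDArray. A standalone version of the same check,
with the module constants written out (LARGE_X = 100000000 is an assumption;
SMALL_Y has to be 50 for the comparison against 50. to hold), assuming an MXNet
build with int64 tensor support and enough memory for the array:

    # Hypothetical standalone reproduction of check_cumsum() from the diff above.
    from mxnet import nd

    LARGE_X, SMALL_Y = 100_000_000, 50    # assumed values from test_large_array.py

    a = nd.ones((LARGE_X, SMALL_Y))
    res = nd.cumsum(a=a, axis=1)          # running sum of ones along each row

    assert res.shape == (LARGE_X, SMALL_Y)
    # Reading one element blocks until the whole cumsum has actually run.
    assert res[0][SMALL_Y - 1].asscalar() == float(SMALL_Y)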
