This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 71a7b5d  [numpy] Fix ffi split  (#18136)
71a7b5d is described below

commit 71a7b5d5c918b07a6afde858ffc74e90f65173bd
Author: Yiyan66 <57363390+yiya...@users.noreply.github.com>
AuthorDate: Fri Apr 24 02:50:49 2020 +0800

    [numpy] Fix ffi split  (#18136)
    
    * fix ffi split
    
    * add test
    
    * fix ffi split
    
    Co-authored-by: Ubuntu <ubu...@ip-172-31-18-97.us-east-2.compute.internal>
---
 src/api/operator/numpy/np_matrix_op.cc               | 5 ++++-
 tests/python/unittest/test_numpy_interoperability.py | 1 +
 tests/python/unittest/test_numpy_op.py               | 2 +-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc
index 58ee563..929a6a6 100644
--- a/src/api/operator/numpy/np_matrix_op.cc
+++ b/src/api/operator/numpy/np_matrix_op.cc
@@ -144,9 +144,12 @@ MXNET_REGISTER_API("_npi.split")
   if (args[1].type_code() == kDLInt) {
     param.indices = TShape(0, 0);
     param.sections = args[1].operator int();
+    int index = param.axis >= 0 ? param.axis :
+                                  param.axis + inputs[0]->shape().ndim();
+    CHECK_GE(index, 0) << "IndexError: tuple index out of range";
     CHECK_GT(param.sections, 0)
       << "ValueError: number sections must be larger than 0";
-    CHECK_EQ(inputs[0]->shape()[param.axis] % param.sections, 0)
+    CHECK_EQ(inputs[0]->shape()[index] % param.sections, 0)
       << "ValueError: array split does not result in an equal division";
   } else {
     TShape t = TShape(args[1].operator ObjectRef());
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 824fa1e..c004d0c 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -345,6 +345,7 @@ def _add_workload_expand_dims():
 def _add_workload_split():
     OpArgMngr.add_workload('split', np.random.uniform(size=(4, 1)), 2)
     OpArgMngr.add_workload('split', np.arange(10), 2)
+    OpArgMngr.add_workload('split', np.random.uniform(size=(10, 10, 3)), 3, -1)
     assertRaises(ValueError, np.split, np.arange(10), 3)
 
 
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 3b11b35..20a940f 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -2938,7 +2938,7 @@ def test_np_split():
     dim = random.randint(0, 3)
     shape = [0] + [random.randint(2, 4) for i in range(dim)]
     for hybridize in [True, False]:
-        for axis in range(len(shape)):
+        for axis in range(-len(shape)+1, len(shape)):
             indices = get_indices(shape[axis])
             sections = 7 if shape[axis] is 0 else shape[axis]
             for indices_or_sections in [indices, sections]:

Reply via email to