This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 71a7b5d [numpy] Fix ffi split (#18136) 71a7b5d is described below commit 71a7b5d5c918b07a6afde858ffc74e90f65173bd Author: Yiyan66 <57363390+yiya...@users.noreply.github.com> AuthorDate: Fri Apr 24 02:50:49 2020 +0800 [numpy] Fix ffi split (#18136) * fix ffi split * add test * fix ffi split Co-authored-by: Ubuntu <ubu...@ip-172-31-18-97.us-east-2.compute.internal> --- src/api/operator/numpy/np_matrix_op.cc | 5 ++++- tests/python/unittest/test_numpy_interoperability.py | 1 + tests/python/unittest/test_numpy_op.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/api/operator/numpy/np_matrix_op.cc b/src/api/operator/numpy/np_matrix_op.cc index 58ee563..929a6a6 100644 --- a/src/api/operator/numpy/np_matrix_op.cc +++ b/src/api/operator/numpy/np_matrix_op.cc @@ -144,9 +144,12 @@ MXNET_REGISTER_API("_npi.split") if (args[1].type_code() == kDLInt) { param.indices = TShape(0, 0); param.sections = args[1].operator int(); + int index = param.axis >= 0 ? 
param.axis : + param.axis + inputs[0]->shape().ndim(); + CHECK_GE(index, 0) << "IndexError: tuple index out of range"; CHECK_GT(param.sections, 0) << "ValueError: number sections must be larger than 0"; - CHECK_EQ(inputs[0]->shape()[param.axis] % param.sections, 0) + CHECK_EQ(inputs[0]->shape()[index] % param.sections, 0) << "ValueError: array split does not result in an equal division"; } else { TShape t = TShape(args[1].operator ObjectRef()); diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index 824fa1e..c004d0c 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -345,6 +345,7 @@ def _add_workload_expand_dims(): def _add_workload_split(): OpArgMngr.add_workload('split', np.random.uniform(size=(4, 1)), 2) OpArgMngr.add_workload('split', np.arange(10), 2) + OpArgMngr.add_workload('split', np.random.uniform(size=(10, 10, 3)), 3, -1) assertRaises(ValueError, np.split, np.arange(10), 3) diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 3b11b35..20a940f 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -2938,7 +2938,7 @@ def test_np_split(): dim = random.randint(0, 3) shape = [0] + [random.randint(2, 4) for i in range(dim)] for hybridize in [True, False]: - for axis in range(len(shape)): + for axis in range(-len(shape)+1, len(shape)): indices = get_indices(shape[axis]) sections = 7 if shape[axis] is 0 else shape[axis] for indices_or_sections in [indices, sections]: