This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 4fb5241 [MXNET-155] Add support for np.int32 and np.int64 as basic index types (#10434) 4fb5241 is described below commit 4fb5241b47c8147690fd6408b55cb694d544656e Author: reminisce <wujun....@gmail.com> AuthorDate: Sat Apr 7 19:08:15 2018 -0700 [MXNET-155] Add support for np.int32 and np.int64 as basic index types (#10434) * Add support for np.int32 and np.int64 as basic index types * Address cr * Remove unnecessary import --- python/mxnet/base.py | 4 +- tests/python/unittest/test_ndarray.py | 80 +++++++++++++++++++++++++--- tests/python/unittest/test_sparse_ndarray.py | 11 ++++ 3 files changed, 87 insertions(+), 8 deletions(-) diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 156ef35..9790e09 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -37,14 +37,14 @@ __all__ = ['MXNetError'] if sys.version_info[0] == 3: string_types = str, numeric_types = (float, int, np.generic) - integer_types = int + integer_types = (int, np.int32, np.int64) # this function is needed for python3 # to convert ctypes.char_p .value back to python str py_str = lambda x: x.decode('utf-8') else: string_types = basestring, numeric_types = (float, int, long, np.generic) - integer_types = (int, long) + integer_types = (int, long, np.int32, np.int64) py_str = lambda x: x class _NullType(object): diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py index 4265bf8..6f2bd45 100644 --- a/tests/python/unittest/test_ndarray.py +++ b/tests/python/unittest/test_ndarray.py @@ -1061,35 +1061,103 @@ def test_ndarray_indexing(): x_grad[index] = value assert same(x_grad.asnumpy(), x.grad.asnumpy()) + def np_int(index, int_type=np.int32): + def convert(num): + if num is None: + return num + else: + return int_type(num) + + if isinstance(index, slice): + return slice(convert(index.start), convert(index.stop), convert(index.step)) + elif isinstance(index, tuple): # tuple of slices and integers + ret = [] + for elem in index: + if isinstance(elem, slice): + ret.append(slice(convert(elem.start), convert(elem.stop), convert(elem.step))) + else: + ret.append(convert(elem)) + return tuple(ret) + else: + assert False + shape = (8, 16, 9, 9) np_array = np.arange(np.prod(shape), dtype='int32').reshape(shape) # index_list is a list of tuples. The tuple's first element is the index, the second one is a boolean value # indicating whether we should expect the result as a scalar compared to numpy. - index_list = [(0, False), (5, False), (-1, False), - (slice(5), False), (slice(1, 5), False), (slice(1, 5, 2), False), - (slice(7, 0, -1), False), (slice(None, 6), False), (slice(None, 6, 3), False), - (slice(1, None), False), (slice(1, None, 3), False), (slice(None, None, 2), False), - (slice(None, None, -1), False), (slice(None, None, -2), False), + index_list = [(0, False), (np.int32(0), False), (np.int64(0), False), + (5, False), (np.int32(5), False), (np.int64(5), False), + (-1, False), (np.int32(-1), False), (np.int64(-1), False), + (slice(5), False), (np_int(slice(5), np.int32), False), (np_int(slice(5), np.int64), False), + (slice(1, 5), False), (np_int(slice(1, 5), np.int32), False), (np_int(slice(1, 5), np.int64), False), + (slice(1, 5, 2), False), (np_int(slice(1, 5, 2), np.int32), False), + (np_int(slice(1, 5, 2), np.int64), False), + (slice(7, 0, -1), False), (np_int(slice(7, 0, -1)), False), + (np_int(slice(7, 0, -1), np.int64), False), + (slice(None, 6), False), (np_int(slice(None, 6)), False), + (np_int(slice(None, 6), np.int64), False), + (slice(None, 6, 3), False), (np_int(slice(None, 6, 3)), False), + (np_int(slice(None, 6, 3), np.int64), False), + (slice(1, None), False), (np_int(slice(1, None)), False), + (np_int(slice(1, None), np.int64), False), + (slice(1, None, 3), False), (np_int(slice(1, None, 3)), False), + (np_int(slice(1, None, 3), np.int64), False), + (slice(None, None, 2), False), (np_int(slice(None, None, 2)), False), + (np_int(slice(None, None, 2), np.int64), False), + (slice(None, None, -1), False), + (np_int(slice(None, None, -1)), False), (np_int(slice(None, None, -1), np.int64), False), + (slice(None, None, -2), False), + (np_int(slice(None, None, -2), np.int32), False), (np_int(slice(None, None, -2), np.int64), False), + ((slice(None), slice(None), 1, 8), False), + (np_int((slice(None), slice(None), 1, 8)), False), + (np_int((slice(None), slice(None), 1, 8), np.int64), False), ((slice(None), slice(None), 1, 8), False), + (np_int((slice(None), slice(None), 1, 8)), False), + (np_int((slice(None), slice(None), 1, 8), np.int64), False), ((slice(None), 2, slice(1, 5), 1), False), - ((1, 2, 3), False), ((1, 2, 3, 4), True), + (np_int((slice(None), 2, slice(1, 5), 1)), False), + (np_int((slice(None), 2, slice(1, 5), 1), np.int64), False), + ((1, 2, 3), False), + (np_int((1, 2, 3)), False), + (np_int((1, 2, 3), np.int64), False), + ((1, 2, 3, 4), True), + (np_int((1, 2, 3, 4)), True), + (np_int((1, 2, 3, 4), np.int64), True), ((slice(None, None, -1), 2, slice(1, 5), 1), False), + (np_int((slice(None, None, -1), 2, slice(1, 5), 1)), False), + (np_int((slice(None, None, -1), 2, slice(1, 5), 1), np.int64), False), ((slice(None, None, -1), 2, slice(1, 7, 2), 1), False), + (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), False), + (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), np.int64), False), ((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), False), + (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3))), False), + (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 3)), np.int64), False), ((slice(1, 8, 2), 1, slice(3, 8), 2), False), + (np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), False), + (np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), False), ([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 0, 3, 6, 2, 1], False), (np.array([6, 3], dtype=np.int32), False), (np.array([[3, 4], [0, 6]], dtype=np.int32), False), (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False), + (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False), (np.array([[2], [0], [1]], dtype=np.int32), False), + (np.array([[2], [0], [1]], dtype=np.int64), False), (mx.nd.array([4, 7], dtype=np.int32), False), + (mx.nd.array([4, 7], dtype=np.int64), False), (mx.nd.array([[3, 6], [2, 1]], dtype=np.int32), False), + (mx.nd.array([[3, 6], [2, 1]], dtype=np.int64), False), (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), False), + (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), False), ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int32)), False), + ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], dtype=np.int64)), False), ((1, [2], np.array([[5], [3]], dtype=np.int32), slice(None)), False), + ((1, [2], np.array([[5], [3]], dtype=np.int64), slice(None)), False), ((1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 5)), False), + ((1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 5)), False), ((1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 5, 2)), False), + ((1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 5, 2)), False), ((1, [2], np.array([[3]], dtype=np.int32), slice(None, None, -1)), False), + ((1, [2], np.array([[3]], dtype=np.int64), slice(None, None, -1)), False), ((1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], [2, 4]], dtype=np.int64)), False), ((1, [2], mx.nd.array([[4]], dtype=np.int32), mx.nd.array([[1, 3], [5, 7]], dtype='int64')), False), diff --git a/tests/python/unittest/test_sparse_ndarray.py b/tests/python/unittest/test_sparse_ndarray.py index 25eaf42..ae3260a 100644 --- a/tests/python/unittest/test_sparse_ndarray.py +++ b/tests/python/unittest/test_sparse_ndarray.py @@ -638,6 +638,17 @@ def test_create_row_sparse(): rsp_copy = mx.nd.array(rsp_created) assert(same(rsp_copy.asnumpy(), rsp_created.asnumpy())) + # add this test since we added np.int32 and np.int64 to integer_types + if len(shape) == 2: + for np_int_type in (np.int32, np.int64): + shape = list(shape) + shape = [np_int_type(x) for x in shape] + arg1 = tuple(shape) + mx.nd.sparse.row_sparse_array(arg1, tuple(shape)) + shape[0] += 1 + assert_exception(mx.nd.sparse.row_sparse_array, ValueError, arg1, tuple(shape)) + + @with_seed() def test_create_sparse_nd_infer_shape(): -- To stop receiving notification emails like this one, please contact j...@apache.org.