This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fb5241  [MXNET-155] Add support for np.int32 and np.int64 as basic index types (#10434)
4fb5241 is described below

commit 4fb5241b47c8147690fd6408b55cb694d544656e
Author: reminisce <wujun....@gmail.com>
AuthorDate: Sat Apr 7 19:08:15 2018 -0700

    [MXNET-155] Add support for np.int32 and np.int64 as basic index types (#10434)
    
    * Add support for np.int32 and np.int64 as basic index types
    
    * Address cr
    
    * Remove unnecessary import
---
 python/mxnet/base.py                         |  4 +-
 tests/python/unittest/test_ndarray.py        | 80 +++++++++++++++++++++++++---
 tests/python/unittest/test_sparse_ndarray.py | 11 ++++
 3 files changed, 87 insertions(+), 8 deletions(-)

diff --git a/python/mxnet/base.py b/python/mxnet/base.py
index 156ef35..9790e09 100644
--- a/python/mxnet/base.py
+++ b/python/mxnet/base.py
@@ -37,14 +37,14 @@ __all__ = ['MXNetError']
 if sys.version_info[0] == 3:
     string_types = str,
     numeric_types = (float, int, np.generic)
-    integer_types = int
+    integer_types = (int, np.int32, np.int64)
     # this function is needed for python3
     # to convert ctypes.char_p .value back to python str
     py_str = lambda x: x.decode('utf-8')
 else:
     string_types = basestring,
     numeric_types = (float, int, long, np.generic)
-    integer_types = (int, long)
+    integer_types = (int, long, np.int32, np.int64)
     py_str = lambda x: x
 
 class _NullType(object):
diff --git a/tests/python/unittest/test_ndarray.py 
b/tests/python/unittest/test_ndarray.py
index 4265bf8..6f2bd45 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -1061,35 +1061,103 @@ def test_ndarray_indexing():
         x_grad[index] = value
         assert same(x_grad.asnumpy(), x.grad.asnumpy())
 
+    def np_int(index, int_type=np.int32):
+        def convert(num):
+            if num is None:
+                return num
+            else:
+                return int_type(num)
+
+        if isinstance(index, slice):
+            return slice(convert(index.start), convert(index.stop), 
convert(index.step))
+        elif isinstance(index, tuple):  # tuple of slices and integers
+            ret = []
+            for elem in index:
+                if isinstance(elem, slice):
+                    ret.append(slice(convert(elem.start), convert(elem.stop), 
convert(elem.step)))
+                else:
+                    ret.append(convert(elem))
+            return tuple(ret)
+        else:
+            assert False
+
     shape = (8, 16, 9, 9)
     np_array = np.arange(np.prod(shape), dtype='int32').reshape(shape)
     # index_list is a list of tuples. The tuple's first element is the index, 
the second one is a boolean value
     # indicating whether we should expect the result as a scalar compared to 
numpy.
-    index_list = [(0, False), (5, False), (-1, False),
-                  (slice(5), False), (slice(1, 5), False), (slice(1, 5, 2), 
False),
-                  (slice(7, 0, -1), False), (slice(None, 6), False), 
(slice(None, 6, 3), False),
-                  (slice(1, None), False), (slice(1, None, 3), False), 
(slice(None, None, 2), False),
-                  (slice(None, None, -1), False), (slice(None, None, -2), 
False),
+    index_list = [(0, False), (np.int32(0), False), (np.int64(0), False),
+                  (5, False), (np.int32(5), False), (np.int64(5), False),
+                  (-1, False), (np.int32(-1), False), (np.int64(-1), False),
+                  (slice(5), False), (np_int(slice(5), np.int32), False), 
(np_int(slice(5), np.int64), False),
+                  (slice(1, 5), False), (np_int(slice(1, 5), np.int32), 
False), (np_int(slice(1, 5), np.int64), False),
+                  (slice(1, 5, 2), False), (np_int(slice(1, 5, 2), np.int32), 
False),
+                  (np_int(slice(1, 5, 2), np.int64), False),
+                  (slice(7, 0, -1), False), (np_int(slice(7, 0, -1)), False),
+                  (np_int(slice(7, 0, -1), np.int64), False),
+                  (slice(None, 6), False), (np_int(slice(None, 6)), False),
+                  (np_int(slice(None, 6), np.int64), False),
+                  (slice(None, 6, 3), False), (np_int(slice(None, 6, 3)), 
False),
+                  (np_int(slice(None, 6, 3), np.int64), False),
+                  (slice(1, None), False), (np_int(slice(1, None)), False),
+                  (np_int(slice(1, None), np.int64), False),
+                  (slice(1, None, 3), False), (np_int(slice(1, None, 3)), 
False),
+                  (np_int(slice(1, None, 3), np.int64), False),
+                  (slice(None, None, 2), False), (np_int(slice(None, None, 
2)), False),
+                  (np_int(slice(None, None, 2), np.int64), False),
+                  (slice(None, None, -1), False),
+                  (np_int(slice(None, None, -1)), False), (np_int(slice(None, 
None, -1), np.int64), False),
+                  (slice(None, None, -2), False),
+                  (np_int(slice(None, None, -2), np.int32), False), 
(np_int(slice(None, None, -2), np.int64), False),
+                  ((slice(None), slice(None), 1, 8), False),
+                  (np_int((slice(None), slice(None), 1, 8)), False),
+                  (np_int((slice(None), slice(None), 1, 8), np.int64), False),
                   ((slice(None), slice(None), 1, 8), False),
+                  (np_int((slice(None), slice(None), 1, 8)), False),
+                  (np_int((slice(None), slice(None), 1, 8), np.int64), False),
                   ((slice(None), 2, slice(1, 5), 1), False),
-                  ((1, 2, 3), False), ((1, 2, 3, 4), True),
+                  (np_int((slice(None), 2, slice(1, 5), 1)), False),
+                  (np_int((slice(None), 2, slice(1, 5), 1), np.int64), False),
+                  ((1, 2, 3), False),
+                  (np_int((1, 2, 3)), False),
+                  (np_int((1, 2, 3), np.int64), False),
+                  ((1, 2, 3, 4), True),
+                  (np_int((1, 2, 3, 4)), True),
+                  (np_int((1, 2, 3, 4), np.int64), True),
                   ((slice(None, None, -1), 2, slice(1, 5), 1), False),
+                  (np_int((slice(None, None, -1), 2, slice(1, 5), 1)), False),
+                  (np_int((slice(None, None, -1), 2, slice(1, 5), 1), 
np.int64), False),
                   ((slice(None, None, -1), 2, slice(1, 7, 2), 1), False),
+                  (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1)), 
False),
+                  (np_int((slice(None, None, -1), 2, slice(1, 7, 2), 1), 
np.int64), False),
                   ((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), slice(0, 7, 
3)), False),
+                  (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), 
slice(0, 7, 3))), False),
+                  (np_int((slice(1, 8, 2), slice(14, 2, -2), slice(3, 8), 
slice(0, 7, 3)), np.int64), False),
                   ((slice(1, 8, 2), 1, slice(3, 8), 2), False),
+                  (np_int((slice(1, 8, 2), 1, slice(3, 8), 2)), False),
+                  (np_int((slice(1, 8, 2), 1, slice(3, 8), 2), np.int64), 
False),
                   ([1], False), ([1, 2], False), ([2, 1, 3], False), ([7, 5, 
0, 3, 6, 2, 1], False),
                   (np.array([6, 3], dtype=np.int32), False),
                   (np.array([[3, 4], [0, 6]], dtype=np.int32), False),
                   (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int32), 
False),
+                  (np.array([[7, 3], [2, 6], [0, 5], [4, 1]], dtype=np.int64), 
False),
                   (np.array([[2], [0], [1]], dtype=np.int32), False),
+                  (np.array([[2], [0], [1]], dtype=np.int64), False),
                   (mx.nd.array([4, 7], dtype=np.int32), False),
+                  (mx.nd.array([4, 7], dtype=np.int64), False),
                   (mx.nd.array([[3, 6], [2, 1]], dtype=np.int32), False),
+                  (mx.nd.array([[3, 6], [2, 1]], dtype=np.int64), False),
                   (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], 
dtype=np.int32), False),
+                  (mx.nd.array([[7, 3], [2, 6], [0, 5], [4, 1]], 
dtype=np.int64), False),
                   ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], 
dtype=np.int32)), False),
+                  ((1, [2, 3]), False), ((1, [2, 3], np.array([[3], [0]], 
dtype=np.int64)), False),
                   ((1, [2], np.array([[5], [3]], dtype=np.int32), 
slice(None)), False),
+                  ((1, [2], np.array([[5], [3]], dtype=np.int64), 
slice(None)), False),
                   ((1, [2, 3], np.array([[6], [0]], dtype=np.int32), slice(2, 
5)), False),
+                  ((1, [2, 3], np.array([[6], [0]], dtype=np.int64), slice(2, 
5)), False),
                   ((1, [2, 3], np.array([[4], [7]], dtype=np.int32), slice(2, 
5, 2)), False),
+                  ((1, [2, 3], np.array([[4], [7]], dtype=np.int64), slice(2, 
5, 2)), False),
                   ((1, [2], np.array([[3]], dtype=np.int32), slice(None, None, 
-1)), False),
+                  ((1, [2], np.array([[3]], dtype=np.int64), slice(None, None, 
-1)), False),
                   ((1, [2], np.array([[3]], dtype=np.int32), np.array([[5, 7], 
[2, 4]], dtype=np.int64)), False),
                   ((1, [2], mx.nd.array([[4]], dtype=np.int32), 
mx.nd.array([[1, 3], [5, 7]], dtype='int64')),
                    False),
diff --git a/tests/python/unittest/test_sparse_ndarray.py 
b/tests/python/unittest/test_sparse_ndarray.py
index 25eaf42..ae3260a 100644
--- a/tests/python/unittest/test_sparse_ndarray.py
+++ b/tests/python/unittest/test_sparse_ndarray.py
@@ -638,6 +638,17 @@ def test_create_row_sparse():
         rsp_copy = mx.nd.array(rsp_created)
         assert(same(rsp_copy.asnumpy(), rsp_created.asnumpy()))
 
+        # add this test since we added np.int32 and np.int64 to integer_types
+        if len(shape) == 2:
+            for np_int_type in (np.int32, np.int64):
+                shape = list(shape)
+                shape = [np_int_type(x) for x in shape]
+                arg1 = tuple(shape)
+                mx.nd.sparse.row_sparse_array(arg1, tuple(shape))
+                shape[0] += 1
+                assert_exception(mx.nd.sparse.row_sparse_array, ValueError, 
arg1, tuple(shape))
+
+
 
 @with_seed()
 def test_create_sparse_nd_infer_shape():

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to