This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.6.x by this push:
     new a3b0aa4  [Numpy] Fix imperative basic indexing in numpy (#16902) (#16919)
a3b0aa4 is described below

commit a3b0aa486cf914f74b603dd64c0adcbc83fbc864
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Wed Nov 27 08:21:47 2019 -0800

    [Numpy] Fix imperative basic indexing in numpy (#16902) (#16919)
    
    * fix bug
    
    add test case
    
    fix
    
    Update test_numpy_ndarray.py
    
    * revise function name
---
 python/mxnet/ndarray/ndarray.py             | 81 ++++++++++++++++++-----------
 src/ndarray/ndarray.cc                      |  5 +-
 src/operator/nn/mkldnn/mkldnn_base-inl.h    |  4 ++
 tests/python/unittest/test_numpy_ndarray.py | 20 ++++---
 4 files changed, 71 insertions(+), 39 deletions(-)

diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7d8cc52..a7ad8e6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -847,26 +847,32 @@ fixed-size items.
         """Whether indexing with the given key results in a contiguous array.
 
         The rule is: From right to left, if in an axis, a slice produces a
-        proper subset, no later axis can produce a proper subset or use
-        a step different from 1.
+        proper subset, the later slice must have <=1 elements.
 
         The ``slc_key`` sequence must have the same length as ``shape`` and
         only contain `slice` objects.
         """
         assert len(slc_key) == len(shape)
-        subset = False
+        is_subset = False
+        total_sliced_elements = np.prod([_get_slice_len(slc, n)
+                                         for slc, n in zip(slc_key, shape)])
+        if total_sliced_elements in (0, 1):
+            return True
         for idx, n in zip(reversed(slc_key), reversed(shape)):
-            start, stop, step = idx.indices(n)
-            if step > 0:
-                num = int(np.ceil(max(stop - start, 0) / step))
-            else:
-                num = int(np.ceil(min(stop - start, 0) / step))
-
-            if num != 1 and (subset or step != 1):
+            _, _, step = idx.indices(n)
+            num_elements = _get_slice_len(idx, n)
+            if num_elements == 0:
+                return True
+            elif num_elements > 1 and (step > 1 or step < 0):
+                # We do not support the case of reverse slicing of multiple elements and
+                # forward slicing of #elements > 1 and step > 1
                 return False
-            if num != n:
-                subset = True
-
+            elif is_subset:
+                if num_elements > 1:
+                    return False
+            else:
+                if num_elements < n:
+                    is_subset = True
         return True
     # pylint: enable=invalid-name
 
@@ -875,14 +881,9 @@ fixed-size items.
         """Return the shape after slicing with the given key."""
         assert len(slc_key) == len(shape)
         sliced_shape = []
-        for idx, n in zip(slc_key, shape):
-            start, stop, step = idx.indices(n)
-            if step > 0:
-                num = int(np.ceil(max(stop - start, 0) / step))
-            else:
-                num = int(np.ceil(min(stop - start, 0) / step))
-            sliced_shape.append(num)
-
+        for slc, n in zip(slc_key, shape):
+            num_elements = _get_slice_len(slc, n)
+            sliced_shape.append(num_elements)
         return tuple(sliced_shape)
 
     # pylint: disable=invalid-name
@@ -890,15 +891,17 @@ fixed-size items.
     def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
         """Return the flat indices of begin and end for contiguous slicing."""
         assert len(slc_key) == len(shape)
-        begin, end, _ = slc_key[0].indices(shape[0])
-        flat_begin, flat_end = begin, end - 1
-        for idx, n in zip(slc_key[1:], shape[1:]):
+        flat_begin, flat_end = 0, 0
+        for slc, n in zip(slc_key, shape):
             flat_begin *= n
             flat_end *= n
-            begin, end, _ = idx.indices(n)
-            flat_begin += begin
-            flat_end += end - 1
-
+            begin, _, _ = slc.indices(n)
+            num_elements = _get_slice_len(slc, n)
+            if num_elements == 0:
+                return 0, 0
+            else:
+                flat_begin += begin
+                flat_end += begin + num_elements - 1
         return flat_begin, flat_end + 1
     # pylint: enable=invalid-name
 
@@ -1062,7 +1065,7 @@ fixed-size items.
         for ax in new_axes:  # pylint: disable=invalid-name
             final_shape.insert(ax, 1)
 
-        if final_shape == []:
+        if len(final_shape) == 0:
             # Override for single element indexing
             final_shape = [1]
         return sliced.reshape(final_shape)
@@ -3108,6 +3111,26 @@ def _get_dim_size(start, stop, step):
     return dim_size
 
 
+def _get_slice_len(slc, seq_length):
+    """Given a python slice object and the length of the sequence, calculate the number of elements
+     in the slice.
+
+    Parameters
+    ----------
+    slc : py_slice
+        The slice object
+    seq_length : int
+        The length of the object you are going to apply the slice on
+
+    Returns
+    -------
+    ret : int
+        Total number of elements in the slice
+    """
+    start, stop, step = slc.indices(seq_length)
+    return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)
+
+
 def _get_broadcast_shape(shape1, shape2):
     """Given two shapes that are not identical, find the shape
     that both input shapes can broadcast to."""
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 6dc6baf..9375bed 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
 NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
   NDArray ret = this->Slice(begin, end);
   if (!Imperative::Get()->is_recording()) return ret;
-  // fake a slice_axis op
+  // fake a slice op
   nnvm::NodeAttrs attrs;
-  attrs.op = nnvm::Op::Get("slice_axis");
-  attrs.dict.insert({"axis", "0"});
+  attrs.op = nnvm::Op::Get("slice");
   attrs.dict.insert({"begin", std::to_string(begin)});
   attrs.dict.insert({"end", std::to_string(end)});
   attrs.op->attr_parser(&attrs);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 0f371d1..9bfc20c 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) {
 
 static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
   int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) {
+    // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
+    return false;
+  }
   return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
 }
 
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 8e46f03..9f4e62c 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -642,13 +642,18 @@ def test_np_ndarray_indexing():
             )
         np_indexed_array = np_array[np_index]
         mx_np_array = np.array(np_array, dtype=np_array.dtype)
-        try:
-            mx_indexed_array = mx_np_array[index]
-        except Exception as e:
-            print('Failed with index = {}'.format(index))
-            raise e
-        mx_indexed_array = mx_indexed_array.asnumpy()
-        assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
+        for autograd in [True, False]:
+            try:
+                if autograd:
+                    with mx.autograd.record():
+                        mx_indexed_array = mx_np_array[index]
+                else:
+                    mx_indexed_array = mx_np_array[index]
+            except Exception as e:
+                print('Failed with index = {}'.format(index))
+                raise e
+            mx_indexed_array = mx_indexed_array.asnumpy()
+            assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
 
     def test_setitem(np_array, index):
         def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
@@ -768,6 +773,7 @@ def test_np_ndarray_indexing():
         np_int(slice(1, 5), np.int32),
         np_int(slice(1, 5), np.int64),
         slice(1, 5, 2),
+        slice(1, 2, 2),
         np_int(slice(1, 5, 2), np.int32),
         np_int(slice(1, 5, 2), np.int64),
         slice(7, 0, -1),

Reply via email to