This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.6.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.6.x by this push:
new a3b0aa4 [Numpy] Fix imperative basic indexing in numpy (#16902) (#16919)
a3b0aa4 is described below
commit a3b0aa486cf914f74b603dd64c0adcbc83fbc864
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Wed Nov 27 08:21:47 2019 -0800
[Numpy] Fix imperative basic indexing in numpy (#16902) (#16919)
* fix bug
add test case
fix
Update test_numpy_ndarray.py
* revise function name
---
python/mxnet/ndarray/ndarray.py | 81 ++++++++++++++++++-----------
src/ndarray/ndarray.cc | 5 +-
src/operator/nn/mkldnn/mkldnn_base-inl.h | 4 ++
tests/python/unittest/test_numpy_ndarray.py | 20 ++++---
4 files changed, 71 insertions(+), 39 deletions(-)
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7d8cc52..a7ad8e6 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -847,26 +847,32 @@ fixed-size items.
"""Whether indexing with the given key results in a contiguous array.
The rule is: From right to left, if in an axis, a slice produces a
- proper subset, no later axis can produce a proper subset or use
- a step different from 1.
+ proper subset, the later slice must have <=1 elements.
The ``slc_key`` sequence must have the same length as ``shape`` and
only contain `slice` objects.
"""
assert len(slc_key) == len(shape)
- subset = False
+ is_subset = False
+ total_sliced_elements = np.prod([_get_slice_len(slc, n)
+ for slc, n in zip(slc_key, shape)])
+ if total_sliced_elements in (0, 1):
+ return True
for idx, n in zip(reversed(slc_key), reversed(shape)):
- start, stop, step = idx.indices(n)
- if step > 0:
- num = int(np.ceil(max(stop - start, 0) / step))
- else:
- num = int(np.ceil(min(stop - start, 0) / step))
-
- if num != 1 and (subset or step != 1):
+ _, _, step = idx.indices(n)
+ num_elements = _get_slice_len(idx, n)
+ if num_elements == 0:
+ return True
+ elif num_elements > 1 and (step > 1 or step < 0):
+ # We do not support the case of reverse slicing of multiple elements and
+ # forward slicing of #elements > 1 and step > 1
return False
- if num != n:
- subset = True
-
+ elif is_subset:
+ if num_elements > 1:
+ return False
+ else:
+ if num_elements < n:
+ is_subset = True
return True
# pylint: enable=invalid-name
@@ -875,14 +881,9 @@ fixed-size items.
"""Return the shape after slicing with the given key."""
assert len(slc_key) == len(shape)
sliced_shape = []
- for idx, n in zip(slc_key, shape):
- start, stop, step = idx.indices(n)
- if step > 0:
- num = int(np.ceil(max(stop - start, 0) / step))
- else:
- num = int(np.ceil(min(stop - start, 0) / step))
- sliced_shape.append(num)
-
+ for slc, n in zip(slc_key, shape):
+ num_elements = _get_slice_len(slc, n)
+ sliced_shape.append(num_elements)
return tuple(sliced_shape)
# pylint: disable=invalid-name
@@ -890,15 +891,17 @@ fixed-size items.
def _basic_indexing_contiguous_flat_begin_end(slc_key, shape):
"""Return the flat indices of begin and end for contiguous slicing."""
assert len(slc_key) == len(shape)
- begin, end, _ = slc_key[0].indices(shape[0])
- flat_begin, flat_end = begin, end - 1
- for idx, n in zip(slc_key[1:], shape[1:]):
+ flat_begin, flat_end = 0, 0
+ for slc, n in zip(slc_key, shape):
flat_begin *= n
flat_end *= n
- begin, end, _ = idx.indices(n)
- flat_begin += begin
- flat_end += end - 1
-
+ begin, _, _ = slc.indices(n)
+ num_elements = _get_slice_len(slc, n)
+ if num_elements == 0:
+ return 0, 0
+ else:
+ flat_begin += begin
+ flat_end += begin + num_elements - 1
return flat_begin, flat_end + 1
# pylint: enable=invalid-name
@@ -1062,7 +1065,7 @@ fixed-size items.
for ax in new_axes: # pylint: disable=invalid-name
final_shape.insert(ax, 1)
- if final_shape == []:
+ if len(final_shape) == 0:
# Override for single element indexing
final_shape = [1]
return sliced.reshape(final_shape)
@@ -3108,6 +3111,26 @@ def _get_dim_size(start, stop, step):
return dim_size
+def _get_slice_len(slc, seq_length):
+ """Given a python slice object and the length of the sequence, calculate the number of elements
+ in the slice.
+
+ Parameters
+ ----------
+ slc : py_slice
+ The slice object
+ seq_length : int
+ The length of the object you are going to apply the slice on
+
+ Returns
+ -------
+ ret : int
+ Total number of elements in the slice
+ """
+ start, stop, step = slc.indices(seq_length)
+ return max(0, (stop - start + (step - (1 if step > 0 else -1))) // step)
+
+
def _get_broadcast_shape(shape1, shape2):
"""Given two shapes that are not identical, find the shape
that both input shapes can broadcast to."""
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 6dc6baf..9375bed 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -293,10 +293,9 @@ NDArray NDArray::Slice(index_t begin, index_t end) const {
NDArray NDArray::SliceWithRecord(index_t begin, index_t end) {
NDArray ret = this->Slice(begin, end);
if (!Imperative::Get()->is_recording()) return ret;
- // fake a slice_axis op
+ // fake a slice op
nnvm::NodeAttrs attrs;
- attrs.op = nnvm::Op::Get("slice_axis");
- attrs.dict.insert({"axis", "0"});
+ attrs.op = nnvm::Op::Get("slice");
attrs.dict.insert({"begin", std::to_string(begin)});
attrs.dict.insert({"end", std::to_string(end)});
attrs.op->attr_parser(&attrs);
diff --git a/src/operator/nn/mkldnn/mkldnn_base-inl.h b/src/operator/nn/mkldnn/mkldnn_base-inl.h
index 0f371d1..9bfc20c 100644
--- a/src/operator/nn/mkldnn/mkldnn_base-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -125,6 +125,10 @@ static inline bool SupportStorageMKLDNN(int stype) {
static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
int ndim = shape.ndim();
+ if (ndim == 0 || shape.Size() == 0) {
+ // MKLDNN currently does not support 0-dim Tensor and 0-size Tensor
+ return false;
+ }
return dtype == mshadow::kFloat32 && (ndim == 1 || ndim == 2 || ndim == 4);
}
diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py
index 8e46f03..9f4e62c 100644
--- a/tests/python/unittest/test_numpy_ndarray.py
+++ b/tests/python/unittest/test_numpy_ndarray.py
@@ -642,13 +642,18 @@ def test_np_ndarray_indexing():
)
np_indexed_array = np_array[np_index]
mx_np_array = np.array(np_array, dtype=np_array.dtype)
- try:
- mx_indexed_array = mx_np_array[index]
- except Exception as e:
- print('Failed with index = {}'.format(index))
- raise e
- mx_indexed_array = mx_indexed_array.asnumpy()
- assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
+ for autograd in [True, False]:
+ try:
+ if autograd:
+ with mx.autograd.record():
+ mx_indexed_array = mx_np_array[index]
+ else:
+ mx_indexed_array = mx_np_array[index]
+ except Exception as e:
+ print('Failed with index = {}'.format(index))
+ raise e
+ mx_indexed_array = mx_indexed_array.asnumpy()
+ assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
def test_setitem(np_array, index):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
@@ -768,6 +773,7 @@ def test_np_ndarray_indexing():
np_int(slice(1, 5), np.int32),
np_int(slice(1, 5), np.int64),
slice(1, 5, 2),
+ slice(1, 2, 2),
np_int(slice(1, 5, 2), np.int32),
np_int(slice(1, 5, 2), np.int64),
slice(7, 0, -1),