This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 16280ad [1.x] Backport of #19078 (#19095)
16280ad is described below
commit 16280ad72aaa35f089ca4d4db680a5c5c7ff1676
Author: Nikolay Ulmasov <[email protected]>
AuthorDate: Tue Sep 29 22:58:51 2020 +0100
[1.x] Backport of #19078 (#19095)
* Ensure NDArray.reshape does not change the array size
* Truncate wikitext-2 to match target array size on reshape
Co-authored-by: r3stl355 <[email protected]>
---
python/mxnet/gluon/contrib/data/text.py | 7 +++++--
python/mxnet/ndarray/ndarray.py | 7 ++++++-
tests/python/unittest/test_ndarray.py | 3 ++-
3 files changed, 13 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/gluon/contrib/data/text.py
b/python/mxnet/gluon/contrib/data/text.py
index 0536ac5..cc5da7c 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -91,8 +91,11 @@ class _WikiText(_LanguageModelDataset):
data, label = self._read_batch(path)
- self._data = nd.array(data, dtype=data.dtype).reshape((-1,
self._seq_len))
- self._label = nd.array(label, dtype=label.dtype).reshape((-1,
self._seq_len))
+ # https://github.com/apache/incubator-mxnet/issues/18886 breaks this
unless array size is
+ # multiple of self._seq_len. Truncating the source is consistent with
pre #18886 outcome
+ seq_len_mult = len(data) // self._seq_len * self._seq_len
+ self._data = nd.array(data,
dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
+ self._label = nd.array(label,
dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
def __getitem__(self, idx):
return self._data[idx], self._label[idx]
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7ac666e..d3a7bc2 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1517,7 +1517,12 @@ fixed-size items.
c_array(ctypes.c_int64, shape),
reverse,
ctypes.byref(handle)))
- return self.__class__(handle=handle, writable=self.writable)
+ res = self.__class__(handle=handle, writable=self.writable)
+
+ # Array size should not change
+ if np.prod(res.shape) != np.prod(self.shape):
+ raise ValueError('Cannot reshape array of size {} into shape
{}'.format(np.prod(self.shape), shape))
+ return res
def reshape_like(self, *args, **kwargs):
"""Convenience fluent method for :py:func:`reshape_like`.
diff --git a/tests/python/unittest/test_ndarray.py
b/tests/python/unittest/test_ndarray.py
index 167d26e..c8fbf35 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -237,7 +237,8 @@ def test_ndarray_reshape():
assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(),
true_res.reshape(2, 3, 5).asnumpy())
assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10,
3).asnumpy())
assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(),
true_res.reshape(6, 5).asnumpy())
-
+ # https://github.com/apache/incubator-mxnet/issues/18886
+ assertRaises(ValueError, tensor.reshape, (2, 3))
@with_seed()
def test_ndarray_flatten():