This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 16280ad  [1.x] Backport of #19078 (#19095)
16280ad is described below

commit 16280ad72aaa35f089ca4d4db680a5c5c7ff1676
Author: Nikolay Ulmasov <[email protected]>
AuthorDate: Tue Sep 29 22:58:51 2020 +0100

    [1.x] Backport of #19078 (#19095)
    
    * Assure NDArray.reshape does not change the array size
    
    * Truncate wikitext-2 to match target array size on reshape
    
    Co-authored-by: r3stl355 <[email protected]>
---
 python/mxnet/gluon/contrib/data/text.py | 7 +++++--
 python/mxnet/ndarray/ndarray.py         | 7 ++++++-
 tests/python/unittest/test_ndarray.py   | 3 ++-
 3 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/contrib/data/text.py b/python/mxnet/gluon/contrib/data/text.py
index 0536ac5..cc5da7c 100644
--- a/python/mxnet/gluon/contrib/data/text.py
+++ b/python/mxnet/gluon/contrib/data/text.py
@@ -91,8 +91,11 @@ class _WikiText(_LanguageModelDataset):
 
         data, label = self._read_batch(path)
 
-        self._data = nd.array(data, dtype=data.dtype).reshape((-1, self._seq_len))
-        self._label = nd.array(label, dtype=label.dtype).reshape((-1, self._seq_len))
+        # https://github.com/apache/incubator-mxnet/issues/18886 breaks this unless array size is
+        # multiple of self._seq_len. Truncating the source is consistent with pre #18886 outcome
+        seq_len_mult = len(data) // self._seq_len * self._seq_len
+        self._data = nd.array(data, dtype=data.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
+        self._label = nd.array(label, dtype=label.dtype)[:seq_len_mult].reshape((-1, self._seq_len))
 
     def __getitem__(self, idx):
         return self._data[idx], self._label[idx]
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 7ac666e..d3a7bc2 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -1517,7 +1517,12 @@ fixed-size items.
                                            c_array(ctypes.c_int64, shape),
                                            reverse,
                                            ctypes.byref(handle)))
-        return self.__class__(handle=handle, writable=self.writable)
+        res = self.__class__(handle=handle, writable=self.writable)
+
+        # Array size should not change
+        if np.prod(res.shape) != np.prod(self.shape):
+            raise ValueError('Cannot reshape array of size {} into shape {}'.format(np.prod(self.shape), shape))
+        return res
 
     def reshape_like(self, *args, **kwargs):
         """Convenience fluent method for :py:func:`reshape_like`.
diff --git a/tests/python/unittest/test_ndarray.py b/tests/python/unittest/test_ndarray.py
index 167d26e..c8fbf35 100644
--- a/tests/python/unittest/test_ndarray.py
+++ b/tests/python/unittest/test_ndarray.py
@@ -237,7 +237,8 @@ def test_ndarray_reshape():
     assert same(tensor.reshape(-1, 15).reshape(0, -4, 3, -1).asnumpy(), true_res.reshape(2, 3, 5).asnumpy())
     assert same(tensor.reshape(-1, 0).asnumpy(), true_res.reshape(10, 3).asnumpy())
     assert same(tensor.reshape(-1, 0, reverse=True).asnumpy(), true_res.reshape(6, 5).asnumpy())
-
+    # https://github.com/apache/incubator-mxnet/issues/18886
+    assertRaises(ValueError, tensor.reshape, (2, 3))
 
 @with_seed()
 def test_ndarray_flatten():

Reply via email to