This is an automated email from the ASF dual-hosted git repository.

roywei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 251e6f6  Fix NDArrayIter can't pad when size is large (#17001)
251e6f6 is described below

commit 251e6f6cb7ec7741434c35a26892fa7450f751c0
Author: Jake Lee <gstu1...@gmail.com>
AuthorDate: Sun Dec 8 16:09:11 2019 -0800

    Fix NDArrayIter can't pad when size is large (#17001)
    
    * Fix NDArrayIter can't pad when size is large
    
    * ci
---
 python/mxnet/io/io.py            | 44 +++++++++++++++++++++-------------------
 tests/python/unittest/test_io.py |  6 +++---
 2 files changed, 26 insertions(+), 24 deletions(-)

diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py
index dcf964d..e36665e 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io/io.py
@@ -36,7 +36,7 @@ from ..ndarray import NDArray
 from ..ndarray.sparse import CSRNDArray
 from ..ndarray import _ndarray_cls
 from ..ndarray import array
-from ..ndarray import concat
+from ..ndarray import concat, tile
 
 from .utils import _init_data, _has_instance, _getdata_by_idx
 
@@ -709,23 +709,27 @@ class NDArrayIter(DataIter):
 
     def _concat(self, first_data, second_data):
         """Helper function to concat two NDArrays."""
+        if (not first_data) or (not second_data):
+            return first_data if first_data else second_data
         assert len(first_data) == len(
             second_data), 'data source should contain the same size'
-        if first_data and second_data:
-            return [
-                concat(
-                    first_data[x],
-                    second_data[x],
-                    dim=0
-                ) for x in range(len(first_data))
-            ]
-        elif (not first_data) and (not second_data):
+        return [
+            concat(
+                first_data[i],
+                second_data[i],
+                dim=0
+            ) for i in range(len(first_data))
+        ]
+
+    def _tile(self, data, repeats):
+        if not data:
             return []
-        else:
-            return [
-                first_data[0] if first_data else second_data[0]
-                for x in range(len(first_data))
-            ]
+        res = []
+        for datum in data:
+            reps = [1] * len(datum.shape)
+            reps[0] = repeats
+            res.append(tile(datum, reps))
+        return res
 
     def _batchify(self, data_source):
         """Load data from underlying arrays, internal use only."""
@@ -749,12 +753,10 @@ class NDArrayIter(DataIter):
             pad = self.batch_size - self.num_data + self.cursor
             first_data = self._getdata(data_source, start=self.cursor)
             if pad > self.num_data:
-                while True:
-                    if pad <= self.num_data:
-                        break
-                    second_data = self._getdata(data_source, end=self.num_data)
-                    pad -= self.num_data
-                second_data = self._concat(second_data, self._getdata(data_source, end=pad))
+                repeats = pad // self.num_data
+                second_data = self._tile(self._getdata(data_source, end=self.num_data), repeats)
+                if pad % self.num_data != 0:
+                    second_data = self._concat(second_data, self._getdata(data_source, end=pad % self.num_data))
             else:
                 second_data = self._getdata(data_source, end=pad)
             return self._concat(first_data, second_data)
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 2a806ef..a13addb 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -198,11 +198,11 @@ def _test_shuffle(data, labels=None):
         assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
         i += 1
 
-# fixes the issue https://github.com/apache/incubator-mxnet/issues/15535
+
 def _test_corner_case():
     data = np.arange(10)
-    data_iter = mx.io.NDArrayIter(data=data, batch_size=25, shuffle=False, last_batch_handle='pad')
-    expect = np.concatenate((np.tile(data, 2), np.arange(5)))
+    data_iter = mx.io.NDArrayIter(data=data, batch_size=205, shuffle=False, last_batch_handle='pad')
+    expect = np.concatenate((np.tile(data, 20), np.arange(5)))
     assert np.array_equal(data_iter.next().data[0].asnumpy(), expect)
 
 

Reply via email to