sandeep-krishnamurthy closed pull request #12545: Change the way NDArrayIter 
handle the last batch
URL: https://github.com/apache/incubator-mxnet/pull/12545
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below, since GitHub would not otherwise show it after the merge:

diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
new file mode 100644
index 00000000000..5c5e2e68d84
--- /dev/null
+++ b/python/mxnet/io/__init__.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+""" Data iterators for common data formats and utility functions."""
+from __future__ import absolute_import
+
+from . import io
+from .io import *
+
+from . import utils
+from .utils import *
diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py
similarity index 81%
rename from python/mxnet/io.py
rename to python/mxnet/io/io.py
index 884e9294741..20da2ea7c67 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io/io.py
@@ -17,30 +17,26 @@
 
 """Data iterators for common data formats."""
 from __future__ import absolute_import
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 
 import sys
 import ctypes
 import logging
 import threading
-try:
-    import h5py
-except ImportError:
-    h5py = None
 import numpy as np
-from .base import _LIB
-from .base import c_str_array, mx_uint, py_str
-from .base import DataIterHandle, NDArrayHandle
-from .base import mx_real_t
-from .base import check_call, build_param_doc as _build_param_doc
-from .ndarray import NDArray
-from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
-from .ndarray import _ndarray_cls
-from .ndarray import array
-from .ndarray import concatenate
-from .ndarray import arange
-from .ndarray.random import shuffle as random_shuffle
+
+from ..base import _LIB
+from ..base import c_str_array, mx_uint, py_str
+from ..base import DataIterHandle, NDArrayHandle
+from ..base import mx_real_t
+from ..base import check_call, build_param_doc as _build_param_doc
+from ..ndarray import NDArray
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray import _ndarray_cls
+from ..ndarray import array
+from ..ndarray import concat
+
+from .utils import _init_data, _has_instance, _getdata_by_idx
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -489,59 +485,6 @@ def getindex(self):
     def getpad(self):
         return self.current_batch.pad
 
-def _init_data(data, allow_empty, default_name):
-    """Convert data into canonical form."""
-    assert (data is not None) or allow_empty
-    if data is None:
-        data = []
-
-    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
-                  if h5py else (np.ndarray, NDArray)):
-        data = [data]
-    if isinstance(data, list):
-        if not allow_empty:
-            assert(len(data) > 0)
-        if len(data) == 1:
-            data = OrderedDict([(default_name, data[0])]) # pylint: 
disable=redefined-variable-type
-        else:
-            data = OrderedDict( # pylint: disable=redefined-variable-type
-                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
-    if not isinstance(data, dict):
-        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " 
+ \
-                "a list of them or dict with them as values")
-    for k, v in data.items():
-        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
-            try:
-                data[k] = array(v)
-            except:
-                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) 
+ \
-                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
-
-    return list(sorted(data.items()))
-
-def _has_instance(data, dtype):
-    """Return True if ``data`` has instance of ``dtype``.
-    This function is called after _init_data.
-    ``data`` is a list of (str, NDArray)"""
-    for item in data:
-        _, arr = item
-        if isinstance(arr, dtype):
-            return True
-    return False
-
-def _shuffle(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        if (isinstance(v, h5py.Dataset) if h5py else False):
-            shuffle_data.append((k, v))
-        elif isinstance(v, CSRNDArray):
-            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
-        else:
-            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, 
``h5py.Dataset``
@@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
     ...
     >>> batchidx # Remaining examples are discarded. So, 10/3 batches are 
created.
     3
+    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, 
last_batch_handle='roll_over')
+    >>> batchidx = 0
+    >>> for batch in dataiter:
+    ...     batchidx += 1
+    ...
+    >>> batchidx # Remaining examples are rolled over to the next iteration.
+    3
+    >>> dataiter.reset()
+    >>> dataiter.next().data[0].asnumpy()
+    [[[ 36.  37.]
+      [ 38.  39.]]
+     [[ 0.  1.]
+      [ 2.  3.]]
+     [[ 4.  5.]
+      [ 6.  7.]]]
+    (3L, 2L, 2L)
 
     `NDArrayIter` also supports multiple input and labels.
 
@@ -633,8 +592,11 @@ class NDArrayIter(DataIter):
         Only supported if no h5py.Dataset inputs are used.
     last_batch_handle : str, optional
         How to handle the last batch. This parameter can be 'pad', 'discard' or
-        'roll_over'. 'roll_over' is intended for training and can cause 
problems
-        if used for prediction.
+        'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the 
begining
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next 
iteration and
+        note that it is intended for training and can cause problems if used 
for prediction.
     data_name : str, optional
         The data name.
     label_name : str, optional
@@ -648,33 +610,26 @@ def __init__(self, data, label=None, batch_size=1, 
shuffle=False,
         self.data = _init_data(data, allow_empty=False, default_name=data_name)
         self.label = _init_data(label, allow_empty=True, 
default_name=label_name)
 
-        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, 
CSRNDArray)) and
+        if ((_has_instance(self.data, CSRNDArray) or
+             _has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports 
``CSRNDArray``" \
                                       " with `last_batch_handle` set to 
`discard`.")
 
-        # shuffle data
-        if shuffle:
-            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
-            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
-            self.data = _shuffle(self.data, self.idx)
-            self.label = _shuffle(self.label, self.idx)
-        else:
-            self.idx = np.arange(self.data[0][1].shape[0])
-
-        # batching
-        if last_batch_handle == 'discard':
-            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % 
batch_size
-            self.idx = self.idx[:new_n]
+        self.idx = np.arange(self.data[0][1].shape[0])
+        self.shuffle = shuffle
+        self.last_batch_handle = last_batch_handle
+        self.batch_size = batch_size
+        self.cursor = -self.batch_size
+        self.num_data = self.idx.shape[0]
+        # shuffle
+        self.reset()
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
-        self.num_data = self.idx.shape[0]
-        assert self.num_data >= batch_size, \
-            "batch_size needs to be smaller than data size."
-        self.cursor = -batch_size
-        self.batch_size = batch_size
-        self.last_batch_handle = last_batch_handle
+        # used for 'roll_over'
+        self._cache_data = None
+        self._cache_label = None
 
     @property
     def provide_data(self):
@@ -694,74 +649,141 @@ def provide_label(self):
 
     def hard_reset(self):
         """Ignore roll over data and set to start."""
+        if self.shuffle:
+            self._shuffle_data()
         self.cursor = -self.batch_size
+        self._cache_data = None
+        self._cache_label = None
 
     def reset(self):
-        if self.last_batch_handle == 'roll_over' and self.cursor > 
self.num_data:
-            self.cursor = -self.batch_size + 
(self.cursor%self.num_data)%self.batch_size
+        """Resets the iterator to the beginning of the data."""
+        if self.shuffle:
+            self._shuffle_data()
+        # the range below indicate the last batch
+        if self.last_batch_handle == 'roll_over' and \
+            self.num_data - self.batch_size < self.cursor < self.num_data:
+            # (self.cursor - self.num_data) represents the data we have for 
the last batch
+            self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
 
     def iter_next(self):
+        """Increments the coursor by batch_size for next batch
+        and check current cursor if it exceed the number of data points."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
     def next(self):
-        if self.iter_next():
-            return DataBatch(data=self.getdata(), label=self.getlabel(), \
-                    pad=self.getpad(), index=None)
-        else:
+        """Returns the next batch of data."""
+        if not self.iter_next():
             raise StopIteration
+        data = self.getdata()
+        label = self.getlabel()
+        # iter should stop when last batch is not complete
+        if data[0].shape[0] != self.batch_size:
+        # in this case, cache it for next epoch
+            self._cache_data = data
+            self._cache_label = label
+            raise StopIteration
+        return DataBatch(data=data, label=label, \
+            pad=self.getpad(), index=None)
+
+    def _getdata(self, data_source, start=None, end=None):
+        """Load data from underlying arrays."""
+        assert start is not None or end is not None, 'should at least specify 
start or end'
+        start = start if start is not None else 0
+        if end is None:
+            end = data_source[0][1].shape[0] if data_source else 0
+        s = slice(start, end)
+        return [
+            x[1][s]
+            if isinstance(x[1], (np.ndarray, NDArray)) else
+            # h5py (only supports indices in increasing order)
+            array(x[1][sorted(self.idx[s])][[
+                list(self.idx[s]).index(i)
+                for i in sorted(self.idx[s])
+            ]]) for x in data_source
+        ]
 
-    def _getdata(self, data_source):
-        """Load data from underlying arrays, internal use only."""
-        assert(self.cursor < self.num_data), "DataIter needs reset."
-        if self.cursor + self.batch_size <= self.num_data:
+    def _concat(self, first_data, second_data):
+        """Helper function to concat two NDArrays."""
+        assert len(first_data) == len(
+            second_data), 'data source should contain the same size'
+        if first_data and second_data:
             return [
-                # np.ndarray or NDArray case
-                x[1][self.cursor:self.cursor + self.batch_size]
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                array(x[1][sorted(self.idx[
-                    self.cursor:self.cursor + self.batch_size])][[
-                        list(self.idx[self.cursor:
-                                      self.cursor + self.batch_size]).index(i)
-                        for i in sorted(self.idx[
-                            self.cursor:self.cursor + self.batch_size])
-                    ]]) for x in data_source
+                concat(
+                    first_data[x],
+                    second_data[x],
+                    dim=0
+                ) for x in range(len(first_data))
             ]
+        elif (not first_data) and (not second_data):
+            return []
         else:
-            pad = self.batch_size - self.num_data + self.cursor
             return [
-                # np.ndarray or NDArray case
-                concatenate([x[1][self.cursor:], x[1][:pad]])
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                concatenate([
-                    array(x[1][sorted(self.idx[self.cursor:])][[
-                        list(self.idx[self.cursor:]).index(i)
-                        for i in sorted(self.idx[self.cursor:])
-                    ]]),
-                    array(x[1][sorted(self.idx[:pad])][[
-                        list(self.idx[:pad]).index(i)
-                        for i in sorted(self.idx[:pad])
-                    ]])
-                ]) for x in data_source
+                first_data[0] if first_data else second_data[0]
+                for x in range(len(first_data))
             ]
 
+    def _batchify(self, data_source):
+        """Load data from underlying arrays, internal use only."""
+        assert self.cursor < self.num_data, 'DataIter needs reset.'
+        # first batch of next epoch with 'roll_over'
+        if self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            assert self._cache_data is not None or self._cache_label is not 
None, \
+                'next epoch should have cached data'
+            cache_data = self._cache_data if self._cache_data is not None else 
self._cache_label
+            second_data = self._getdata(
+                data_source, end=self.cursor + self.batch_size)
+            if self._cache_data is not None:
+                self._cache_data = None
+            else:
+                self._cache_label = None
+            return self._concat(cache_data, second_data)
+        # last batch with 'pad'
+        elif self.last_batch_handle == 'pad' and \
+            self.cursor + self.batch_size > self.num_data:
+            pad = self.batch_size - self.num_data + self.cursor
+            first_data = self._getdata(data_source, start=self.cursor)
+            second_data = self._getdata(data_source, end=pad)
+            return self._concat(first_data, second_data)
+        # normal case
+        else:
+            if self.cursor + self.batch_size < self.num_data:
+                end_idx = self.cursor + self.batch_size
+            # get incomplete last batch
+            else:
+                end_idx = self.num_data
+            return self._getdata(data_source, self.cursor, end_idx)
+
     def getdata(self):
-        return self._getdata(self.data)
+        """Get data."""
+        return self._batchify(self.data)
 
     def getlabel(self):
-        return self._getdata(self.label)
+        """Get label."""
+        return self._batchify(self.label)
 
     def getpad(self):
+        """Get pad value of DataBatch."""
         if self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
+        # check the first batch
+        elif self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            return -self.cursor
         else:
             return 0
 
+    def _shuffle_data(self):
+        """Shuffle the data."""
+        # shuffle index
+        np.random.shuffle(self.idx)
+        # get the data by corresponding index
+        self.data = _getdata_by_idx(self.data, self.idx)
+        self.label = _getdata_by_idx(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
new file mode 100644
index 00000000000..55ba34aea42
--- /dev/null
+++ b/python/mxnet/io/utils.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""utility functions for io.py"""
+from collections import OrderedDict
+
+import numpy as np
+try:
+    import h5py
+except ImportError:
+    h5py = None
+
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray.sparse import array as sparse_array
+from ..ndarray import NDArray
+from ..ndarray import array
+
+def _init_data(data, allow_empty, default_name):
+    """Convert data into canonical form."""
+    assert (data is not None) or allow_empty
+    if data is None:
+        data = []
+
+    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+                  if h5py else (np.ndarray, NDArray)):
+        data = [data]
+    if isinstance(data, list):
+        if not allow_empty:
+            assert(len(data) > 0)
+        if len(data) == 1:
+            data = OrderedDict([(default_name, data[0])])  # pylint: 
disable=redefined-variable-type
+        else:
+            data = OrderedDict(  # pylint: disable=redefined-variable-type
+                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
+    if not isinstance(data, dict):
+        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
+                        "a list of them or dict with them as values")
+    for k, v in data.items():
+        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+            try:
+                data[k] = array(v)
+            except:
+                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
+                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
+
+    return list(sorted(data.items()))
+
+
+def _has_instance(data, dtype):
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
+    return False
+
+
+def _getdata_by_idx(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        if (isinstance(v, h5py.Dataset) if h5py else False):
+            shuffle_data.append((k, v))
+        elif isinstance(v, CSRNDArray):
+            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+        else:
+            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 872763f5a78..0641f235aa7 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -17,6 +17,7 @@
 
 # pylint: skip-file
 import mxnet as mx
+import mxnet.ndarray as nd
 from mxnet.test_utils import *
 from mxnet.base import MXNetError
 import numpy as np
@@ -106,79 +107,139 @@ def check_cifar10_exception():
             pass
     assertRaises(MXNetError, check_cifar10_exception)
 
-def test_NDArrayIter():
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
+def _init_NDArrayIter_data(data_type, is_image=False):
+    if is_image:
+        data = nd.random.uniform(0, 255, shape=(5000, 1, 28, 28))
+        labels = nd.ones((5000, 1))
+        return data, labels
+    if data_type == 'NDArray':
+        data = nd.ones((1000, 2, 2))
+        labels = nd.ones((1000, 1))
+    else:
+        data = np.ones((1000, 2, 2))
+        labels = np.ones((1000, 1))
     for i in range(1000):
         data[i] = i / 100
-        label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, True, last_batch_handle='pad')
-    batchidx = 0
+        labels[i] = i / 100
+    return data, labels
+
+
+def _test_last_batch_handle(data, labels=None, is_image=False):
+    # Test the three parameters 'pad', 'discard', 'roll_over'
+    last_batch_handle_list = ['pad', 'discard', 'roll_over']
+    if labels is not None and not is_image and len(labels) != 0:
+        labelcount_list = [(124, 100), (100, 96), (100, 96)]
+    if is_image:
+        batch_count_list = [40, 39, 39]
+    else:
+        batch_count_list = [8, 7, 7]
+    
+    for idx in range(len(last_batch_handle_list)):
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, False, 
last_batch_handle=last_batch_handle_list[idx])
+        batch_count = 0
+        if labels is not None and len(labels) != 0 and not is_image:
+            labelcount = [0 for i in range(10)]
+        for batch in dataiter:
+            if len(data) == 2:
+                assert len(batch.data) == 2
+            if labels is not None and len(labels) != 0:
+                if not is_image:
+                    label = batch.label[0].asnumpy().flatten()
+                    # check data if it matches corresponding labels
+                    assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+                    for i in range(label.shape[0]):
+                       labelcount[int(label[i])] += 1
+            else:
+                assert not batch.label, 'label is not empty list'
+            # keep the last batch of 'pad' to be used later 
+            # to test first batch of roll_over in second iteration
+            batch_count += 1
+            if last_batch_handle_list[idx] == 'pad' and \
+                batch_count == batch_count_list[0]:
+                cache = batch.data[0].asnumpy()
+        # check if batchifying functionality work properly
+        if labels is not None and len(labels) != 0 and not is_image:
+            assert labelcount[0] == labelcount_list[idx][0], 
last_batch_handle_list[idx]
+            assert labelcount[8] == labelcount_list[idx][1], 
last_batch_handle_list[idx]
+        assert batch_count == batch_count_list[idx]
+    # roll_over option
+    dataiter.reset()
+    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels=None):
+    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+    batch_list = []
     for batch in dataiter:
-        batchidx += 1
-    assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, False, last_batch_handle='pad')
-    batchidx = 0
-    labelcount = [0 for i in range(10)]
+        # cache the original data
+        batch_list.append(batch.data[0].asnumpy())
+    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+    idx_list = dataiter.idx
+    i = 0
     for batch in dataiter:
-        label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-        for i in range(label.shape[0]):
-            labelcount[int(label[i])] += 1
+        # check if each data point have been shuffled to corresponding 
positions
+        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
+        i += 1
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def test_NDArrayIter():
+    dtype_list = ['NDArray', 'ndarray']
+    tested_data_type = [False, True]
+    for dtype in dtype_list:
+        for is_image in tested_data_type:
+            data, labels = _init_NDArrayIter_data(dtype, is_image)
+            _test_last_batch_handle(data, labels, is_image)
+            _test_last_batch_handle([data, data], labels, is_image)
+            _test_last_batch_handle(data=[data, data], is_image=is_image)
+            _test_last_batch_handle(
+                {'data1': data, 'data2': data}, labels, is_image)
+            _test_last_batch_handle(data={'data1': data, 'data2': data}, 
is_image=is_image)
+            _test_last_batch_handle(data, [], is_image)
+            _test_last_batch_handle(data=data, is_image=is_image)
+            _test_shuffle(data, labels)
+            _test_shuffle([data, data], labels)
+            _test_shuffle([data, data])
+            _test_shuffle({'data1': data, 'data2': data}, labels)
+            _test_shuffle({'data1': data, 'data2': data})
+            _test_shuffle(data, [])
+            _test_shuffle(data)
 
 
 def test_NDArrayIter_h5py():
     if not h5py:
         return
 
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
-    for i in range(1000):
-        data[i] = i / 100
-        label[i] = i / 100
+    data, labels = _init_NDArrayIter_data('ndarray')
 
     try:
-        os.remove("ndarraytest.h5")
+        os.remove('ndarraytest.h5')
     except OSError:
         pass
-    with h5py.File("ndarraytest.h5") as f:
-        f.create_dataset("data", data=data)
-        f.create_dataset("label", data=label)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, True, last_batch_handle='pad')
-        batchidx = 0
-        for batch in dataiter:
-            batchidx += 1
-        assert(batchidx == 8)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, False, last_batch_handle='pad')
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
-
+    with h5py.File('ndarraytest.h5') as f:
+        f.create_dataset('data', data=data)
+        f.create_dataset('label', data=labels)
+        
+        _test_last_batch_handle(f['data'], f['label'])
+        _test_last_batch_handle(f['data'], [])
+        _test_last_batch_handle(f['data'])
     try:
         os.remove("ndarraytest.h5")
     except OSError:
         pass
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, csr_iter_None, 
num_rows, batch_size):
+    num_batch = 0
+    for _, batch_empty_list, batch_empty_None in zip(csr_iter, 
csr_iter_empty_list, csr_iter_None):
+        assert not batch_empty_list.label, 'label is not empty list'
+        assert not batch_empty_None.label, 'label is not empty list'
+        num_batch += 1
+
+    assert(num_batch == num_rows // batch_size)
+    assertRaises(StopIteration, csr_iter.next)
+    assertRaises(StopIteration, csr_iter_empty_list.next)
+    assertRaises(StopIteration, csr_iter_None.next)
 
 
 def test_NDArrayIter_csr():
@@ -200,15 +261,26 @@ def test_NDArrayIter_csr():
                      {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
+    
+    # scipy.sparse.csr_matrix with shuffle
+    csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_empty_list = iter(mx.io.NDArrayIter({'data': train_data}, [], 
batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_None = iter(mx.io.NDArrayIter({'data': train_data}, None, 
batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+                          csr_iter_None, num_rows, batch_size)
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, 
batch_size,
                                       shuffle=True, 
last_batch_handle='discard'))
-    num_batch = 0
-    for batch in csr_iter:
-        num_batch += 1
-
-    assert(num_batch == num_rows // batch_size)
+    csr_iter_empty_list = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': 
dns}, [], batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_None = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, 
None, batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+                          csr_iter_None, num_rows, batch_size)
 
     # make iterators
     csr_iter = iter(mx.io.NDArrayIter(


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to