This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 13030b6  Change the way NDArrayIter handle the last batch (#12545)
13030b6 is described below

commit 13030b6e35ee9bbdaf674e6e5828069edf9b85f5
Author: Jake Lee <[email protected]>
AuthorDate: Tue Oct 16 10:00:47 2018 -0700

    Change the way NDArrayIter handle the last batch (#12545)
    
    * 1. move the shuffle into reset(); 2. modify the roll_over behavior accordingly
    
    * refactor the concat part
    
    * refactor the code
    
    * implement unit test for last_batch_handle
    
    * refactor the getdata part
    
    * add docstring and refine the code according to linter
    
    * 1. add test case for NDArrayIter_h5py 2. refactor the implementation
    
    * update contributions doc
    
    * fix wording
    
    * update doc for roll_over
    
    * 1. add test for second iteration of roll_over 2. add shuffle test case
    
    * fix some wording and refine the variable naming
    
    * move utility function to new file
    
    * move utility function to io_utils.py
    
    * change shuffle function name to avoid redefining name
    
    * turn io into a package (io/__init__.py, io/io.py, io/utils.py)
    
    * rename the utility functions
    
    * disable wildcard-import
    
    * fix the algorithm
    
    * refactor the code
    
    * test the NDArrayIter with different combinations of shuffle=True, data_source type and labels
    
    * add edge case of label data for csr NDArrayIter
    
    * trigger Travis CI
    
    * handle the 'list' of data source
    
    * check the list of data source
    
    * fix the extra blank
    
    * Trigger CI
    
    * add _ to the utility functions
    
    * Trigger CI
    
    * update several test cases
    
    * add test case for airbnb
    
    * fix the typo
    
    * fix wrong labels data shape
    
    * switch the order of the conditions so the check reads more clearly
---
 python/mxnet/io/__init__.py      |  29 ++++
 python/mxnet/{ => io}/io.py      | 284 +++++++++++++++++++++------------------
 python/mxnet/io/utils.py         |  86 ++++++++++++
 tests/python/unittest/test_io.py | 188 ++++++++++++++++++--------
 4 files changed, 398 insertions(+), 189 deletions(-)

diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
new file mode 100644
index 0000000..5c5e2e6
--- /dev/null
+++ b/python/mxnet/io/__init__.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+""" Data iterators for common data formats and utility functions."""
+from __future__ import absolute_import
+
+from . import io
+from .io import *
+
+from . import utils
+from .utils import *
diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py
similarity index 81%
rename from python/mxnet/io.py
rename to python/mxnet/io/io.py
index 884e929..20da2ea 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io/io.py
@@ -17,30 +17,26 @@
 
 """Data iterators for common data formats."""
 from __future__ import absolute_import
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
 
 import sys
 import ctypes
 import logging
 import threading
-try:
-    import h5py
-except ImportError:
-    h5py = None
 import numpy as np
-from .base import _LIB
-from .base import c_str_array, mx_uint, py_str
-from .base import DataIterHandle, NDArrayHandle
-from .base import mx_real_t
-from .base import check_call, build_param_doc as _build_param_doc
-from .ndarray import NDArray
-from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
-from .ndarray import _ndarray_cls
-from .ndarray import array
-from .ndarray import concatenate
-from .ndarray import arange
-from .ndarray.random import shuffle as random_shuffle
+
+from ..base import _LIB
+from ..base import c_str_array, mx_uint, py_str
+from ..base import DataIterHandle, NDArrayHandle
+from ..base import mx_real_t
+from ..base import check_call, build_param_doc as _build_param_doc
+from ..ndarray import NDArray
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray import _ndarray_cls
+from ..ndarray import array
+from ..ndarray import concat
+
+from .utils import _init_data, _has_instance, _getdata_by_idx
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -489,59 +485,6 @@ class PrefetchingIter(DataIter):
     def getpad(self):
         return self.current_batch.pad
 
-def _init_data(data, allow_empty, default_name):
-    """Convert data into canonical form."""
-    assert (data is not None) or allow_empty
-    if data is None:
-        data = []
-
-    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
-                  if h5py else (np.ndarray, NDArray)):
-        data = [data]
-    if isinstance(data, list):
-        if not allow_empty:
-            assert(len(data) > 0)
-        if len(data) == 1:
-            data = OrderedDict([(default_name, data[0])]) # pylint: 
disable=redefined-variable-type
-        else:
-            data = OrderedDict( # pylint: disable=redefined-variable-type
-                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
-    if not isinstance(data, dict):
-        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " 
+ \
-                "a list of them or dict with them as values")
-    for k, v in data.items():
-        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
-            try:
-                data[k] = array(v)
-            except:
-                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) 
+ \
-                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
-
-    return list(sorted(data.items()))
-
-def _has_instance(data, dtype):
-    """Return True if ``data`` has instance of ``dtype``.
-    This function is called after _init_data.
-    ``data`` is a list of (str, NDArray)"""
-    for item in data:
-        _, arr = item
-        if isinstance(arr, dtype):
-            return True
-    return False
-
-def _shuffle(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        if (isinstance(v, h5py.Dataset) if h5py else False):
-            shuffle_data.append((k, v))
-        elif isinstance(v, CSRNDArray):
-            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
-        else:
-            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, 
``h5py.Dataset``
@@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
     ...
     >>> batchidx # Remaining examples are discarded. So, 10/3 batches are 
created.
     3
+    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, 
last_batch_handle='roll_over')
+    >>> batchidx = 0
+    >>> for batch in dataiter:
+    ...     batchidx += 1
+    ...
+    >>> batchidx # Remaining examples are rolled over to the next iteration.
+    3
+    >>> dataiter.reset()
+    >>> dataiter.next().data[0].asnumpy()
+    [[[ 36.  37.]
+      [ 38.  39.]]
+     [[ 0.  1.]
+      [ 2.  3.]]
+     [[ 4.  5.]
+      [ 6.  7.]]]
+    (3L, 2L, 2L)
 
     `NDArrayIter` also supports multiple input and labels.
 
@@ -633,8 +592,11 @@ class NDArrayIter(DataIter):
         Only supported if no h5py.Dataset inputs are used.
     last_batch_handle : str, optional
         How to handle the last batch. This parameter can be 'pad', 'discard' or
-        'roll_over'. 'roll_over' is intended for training and can cause 
problems
-        if used for prediction.
+        'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the 
begining
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next 
iteration and
+        note that it is intended for training and can cause problems if used 
for prediction.
     data_name : str, optional
         The data name.
     label_name : str, optional
@@ -648,33 +610,26 @@ class NDArrayIter(DataIter):
         self.data = _init_data(data, allow_empty=False, default_name=data_name)
         self.label = _init_data(label, allow_empty=True, 
default_name=label_name)
 
-        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, 
CSRNDArray)) and
+        if ((_has_instance(self.data, CSRNDArray) or
+             _has_instance(self.label, CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports 
``CSRNDArray``" \
                                       " with `last_batch_handle` set to 
`discard`.")
 
-        # shuffle data
-        if shuffle:
-            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
-            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
-            self.data = _shuffle(self.data, self.idx)
-            self.label = _shuffle(self.label, self.idx)
-        else:
-            self.idx = np.arange(self.data[0][1].shape[0])
-
-        # batching
-        if last_batch_handle == 'discard':
-            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % 
batch_size
-            self.idx = self.idx[:new_n]
+        self.idx = np.arange(self.data[0][1].shape[0])
+        self.shuffle = shuffle
+        self.last_batch_handle = last_batch_handle
+        self.batch_size = batch_size
+        self.cursor = -self.batch_size
+        self.num_data = self.idx.shape[0]
+        # shuffle
+        self.reset()
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
-        self.num_data = self.idx.shape[0]
-        assert self.num_data >= batch_size, \
-            "batch_size needs to be smaller than data size."
-        self.cursor = -batch_size
-        self.batch_size = batch_size
-        self.last_batch_handle = last_batch_handle
+        # used for 'roll_over'
+        self._cache_data = None
+        self._cache_label = None
 
     @property
     def provide_data(self):
@@ -694,74 +649,141 @@ class NDArrayIter(DataIter):
 
     def hard_reset(self):
         """Ignore roll over data and set to start."""
+        if self.shuffle:
+            self._shuffle_data()
         self.cursor = -self.batch_size
+        self._cache_data = None
+        self._cache_label = None
 
     def reset(self):
-        if self.last_batch_handle == 'roll_over' and self.cursor > 
self.num_data:
-            self.cursor = -self.batch_size + 
(self.cursor%self.num_data)%self.batch_size
+        """Resets the iterator to the beginning of the data."""
+        if self.shuffle:
+            self._shuffle_data()
+        # the range below indicate the last batch
+        if self.last_batch_handle == 'roll_over' and \
+            self.num_data - self.batch_size < self.cursor < self.num_data:
+            # (self.cursor - self.num_data) represents the data we have for 
the last batch
+            self.cursor = self.cursor - self.num_data - self.batch_size
         else:
             self.cursor = -self.batch_size
 
     def iter_next(self):
+        """Increments the coursor by batch_size for next batch
+        and check current cursor if it exceed the number of data points."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
     def next(self):
-        if self.iter_next():
-            return DataBatch(data=self.getdata(), label=self.getlabel(), \
-                    pad=self.getpad(), index=None)
-        else:
+        """Returns the next batch of data."""
+        if not self.iter_next():
             raise StopIteration
+        data = self.getdata()
+        label = self.getlabel()
+        # iter should stop when last batch is not complete
+        if data[0].shape[0] != self.batch_size:
+        # in this case, cache it for next epoch
+            self._cache_data = data
+            self._cache_label = label
+            raise StopIteration
+        return DataBatch(data=data, label=label, \
+            pad=self.getpad(), index=None)
+
+    def _getdata(self, data_source, start=None, end=None):
+        """Load data from underlying arrays."""
+        assert start is not None or end is not None, 'should at least specify 
start or end'
+        start = start if start is not None else 0
+        if end is None:
+            end = data_source[0][1].shape[0] if data_source else 0
+        s = slice(start, end)
+        return [
+            x[1][s]
+            if isinstance(x[1], (np.ndarray, NDArray)) else
+            # h5py (only supports indices in increasing order)
+            array(x[1][sorted(self.idx[s])][[
+                list(self.idx[s]).index(i)
+                for i in sorted(self.idx[s])
+            ]]) for x in data_source
+        ]
 
-    def _getdata(self, data_source):
-        """Load data from underlying arrays, internal use only."""
-        assert(self.cursor < self.num_data), "DataIter needs reset."
-        if self.cursor + self.batch_size <= self.num_data:
+    def _concat(self, first_data, second_data):
+        """Helper function to concat two NDArrays."""
+        assert len(first_data) == len(
+            second_data), 'data source should contain the same size'
+        if first_data and second_data:
             return [
-                # np.ndarray or NDArray case
-                x[1][self.cursor:self.cursor + self.batch_size]
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                array(x[1][sorted(self.idx[
-                    self.cursor:self.cursor + self.batch_size])][[
-                        list(self.idx[self.cursor:
-                                      self.cursor + self.batch_size]).index(i)
-                        for i in sorted(self.idx[
-                            self.cursor:self.cursor + self.batch_size])
-                    ]]) for x in data_source
+                concat(
+                    first_data[x],
+                    second_data[x],
+                    dim=0
+                ) for x in range(len(first_data))
             ]
+        elif (not first_data) and (not second_data):
+            return []
         else:
-            pad = self.batch_size - self.num_data + self.cursor
             return [
-                # np.ndarray or NDArray case
-                concatenate([x[1][self.cursor:], x[1][:pad]])
-                if isinstance(x[1], (np.ndarray, NDArray)) else
-                # h5py (only supports indices in increasing order)
-                concatenate([
-                    array(x[1][sorted(self.idx[self.cursor:])][[
-                        list(self.idx[self.cursor:]).index(i)
-                        for i in sorted(self.idx[self.cursor:])
-                    ]]),
-                    array(x[1][sorted(self.idx[:pad])][[
-                        list(self.idx[:pad]).index(i)
-                        for i in sorted(self.idx[:pad])
-                    ]])
-                ]) for x in data_source
+                first_data[0] if first_data else second_data[0]
+                for x in range(len(first_data))
             ]
 
+    def _batchify(self, data_source):
+        """Load data from underlying arrays, internal use only."""
+        assert self.cursor < self.num_data, 'DataIter needs reset.'
+        # first batch of next epoch with 'roll_over'
+        if self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            assert self._cache_data is not None or self._cache_label is not 
None, \
+                'next epoch should have cached data'
+            cache_data = self._cache_data if self._cache_data is not None else 
self._cache_label
+            second_data = self._getdata(
+                data_source, end=self.cursor + self.batch_size)
+            if self._cache_data is not None:
+                self._cache_data = None
+            else:
+                self._cache_label = None
+            return self._concat(cache_data, second_data)
+        # last batch with 'pad'
+        elif self.last_batch_handle == 'pad' and \
+            self.cursor + self.batch_size > self.num_data:
+            pad = self.batch_size - self.num_data + self.cursor
+            first_data = self._getdata(data_source, start=self.cursor)
+            second_data = self._getdata(data_source, end=pad)
+            return self._concat(first_data, second_data)
+        # normal case
+        else:
+            if self.cursor + self.batch_size < self.num_data:
+                end_idx = self.cursor + self.batch_size
+            # get incomplete last batch
+            else:
+                end_idx = self.num_data
+            return self._getdata(data_source, self.cursor, end_idx)
+
     def getdata(self):
-        return self._getdata(self.data)
+        """Get data."""
+        return self._batchify(self.data)
 
     def getlabel(self):
-        return self._getdata(self.label)
+        """Get label."""
+        return self._batchify(self.label)
 
     def getpad(self):
+        """Get pad value of DataBatch."""
         if self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
+        # check the first batch
+        elif self.last_batch_handle == 'roll_over' and \
+            -self.batch_size < self.cursor < 0:
+            return -self.cursor
         else:
             return 0
 
+    def _shuffle_data(self):
+        """Shuffle the data."""
+        # shuffle index
+        np.random.shuffle(self.idx)
+        # get the data by corresponding index
+        self.data = _getdata_by_idx(self.data, self.idx)
+        self.label = _getdata_by_idx(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
new file mode 100644
index 0000000..55ba34a
--- /dev/null
+++ b/python/mxnet/io/utils.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""utility functions for io.py"""
+from collections import OrderedDict
+
+import numpy as np
+try:
+    import h5py
+except ImportError:
+    h5py = None
+
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray.sparse import array as sparse_array
+from ..ndarray import NDArray
+from ..ndarray import array
+
+def _init_data(data, allow_empty, default_name):
+    """Convert data into canonical form."""
+    assert (data is not None) or allow_empty
+    if data is None:
+        data = []
+
+    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+                  if h5py else (np.ndarray, NDArray)):
+        data = [data]
+    if isinstance(data, list):
+        if not allow_empty:
+            assert(len(data) > 0)
+        if len(data) == 1:
+            data = OrderedDict([(default_name, data[0])])  # pylint: 
disable=redefined-variable-type
+        else:
+            data = OrderedDict(  # pylint: disable=redefined-variable-type
+                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
+    if not isinstance(data, dict):
+        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
+                        "a list of them or dict with them as values")
+    for k, v in data.items():
+        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+            try:
+                data[k] = array(v)
+            except:
+                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
+                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
+
+    return list(sorted(data.items()))
+
+
+def _has_instance(data, dtype):
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
+    return False
+
+
+def _getdata_by_idx(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        if (isinstance(v, h5py.Dataset) if h5py else False):
+            shuffle_data.append((k, v))
+        elif isinstance(v, CSRNDArray):
+            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+        else:
+            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 872763f..0641f23 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -17,6 +17,7 @@
 
 # pylint: skip-file
 import mxnet as mx
+import mxnet.ndarray as nd
 from mxnet.test_utils import *
 from mxnet.base import MXNetError
 import numpy as np
@@ -106,79 +107,139 @@ def test_image_iter_exception():
             pass
     assertRaises(MXNetError, check_cifar10_exception)
 
-def test_NDArrayIter():
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
+def _init_NDArrayIter_data(data_type, is_image=False):
+    if is_image:
+        data = nd.random.uniform(0, 255, shape=(5000, 1, 28, 28))
+        labels = nd.ones((5000, 1))
+        return data, labels
+    if data_type == 'NDArray':
+        data = nd.ones((1000, 2, 2))
+        labels = nd.ones((1000, 1))
+    else:
+        data = np.ones((1000, 2, 2))
+        labels = np.ones((1000, 1))
     for i in range(1000):
         data[i] = i / 100
-        label[i] = i / 100
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, True, last_batch_handle='pad')
-    batchidx = 0
+        labels[i] = i / 100
+    return data, labels
+
+
+def _test_last_batch_handle(data, labels=None, is_image=False):
+    # Test the three parameters 'pad', 'discard', 'roll_over'
+    last_batch_handle_list = ['pad', 'discard', 'roll_over']
+    if labels is not None and not is_image and len(labels) != 0:
+        labelcount_list = [(124, 100), (100, 96), (100, 96)]
+    if is_image:
+        batch_count_list = [40, 39, 39]
+    else:
+        batch_count_list = [8, 7, 7]
+    
+    for idx in range(len(last_batch_handle_list)):
+        dataiter = mx.io.NDArrayIter(
+            data, labels, 128, False, 
last_batch_handle=last_batch_handle_list[idx])
+        batch_count = 0
+        if labels is not None and len(labels) != 0 and not is_image:
+            labelcount = [0 for i in range(10)]
+        for batch in dataiter:
+            if len(data) == 2:
+                assert len(batch.data) == 2
+            if labels is not None and len(labels) != 0:
+                if not is_image:
+                    label = batch.label[0].asnumpy().flatten()
+                    # check data if it matches corresponding labels
+                    assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+                    for i in range(label.shape[0]):
+                       labelcount[int(label[i])] += 1
+            else:
+                assert not batch.label, 'label is not empty list'
+            # keep the last batch of 'pad' to be used later 
+            # to test first batch of roll_over in second iteration
+            batch_count += 1
+            if last_batch_handle_list[idx] == 'pad' and \
+                batch_count == batch_count_list[0]:
+                cache = batch.data[0].asnumpy()
+        # check if batchifying functionality work properly
+        if labels is not None and len(labels) != 0 and not is_image:
+            assert labelcount[0] == labelcount_list[idx][0], 
last_batch_handle_list[idx]
+            assert labelcount[8] == labelcount_list[idx][1], 
last_batch_handle_list[idx]
+        assert batch_count == batch_count_list[idx]
+    # roll_over option
+    dataiter.reset()
+    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels=None):
+    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+    batch_list = []
     for batch in dataiter:
-        batchidx += 1
-    assert(batchidx == 8)
-    dataiter = mx.io.NDArrayIter(
-        data, label, 128, False, last_batch_handle='pad')
-    batchidx = 0
-    labelcount = [0 for i in range(10)]
+        # cache the original data
+        batch_list.append(batch.data[0].asnumpy())
+    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+    idx_list = dataiter.idx
+    i = 0
     for batch in dataiter:
-        label = batch.label[0].asnumpy().flatten()
-        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-        for i in range(label.shape[0]):
-            labelcount[int(label[i])] += 1
+        # check if each data point have been shuffled to corresponding 
positions
+        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
+        i += 1
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def test_NDArrayIter():
+    dtype_list = ['NDArray', 'ndarray']
+    tested_data_type = [False, True]
+    for dtype in dtype_list:
+        for is_image in tested_data_type:
+            data, labels = _init_NDArrayIter_data(dtype, is_image)
+            _test_last_batch_handle(data, labels, is_image)
+            _test_last_batch_handle([data, data], labels, is_image)
+            _test_last_batch_handle(data=[data, data], is_image=is_image)
+            _test_last_batch_handle(
+                {'data1': data, 'data2': data}, labels, is_image)
+            _test_last_batch_handle(data={'data1': data, 'data2': data}, 
is_image=is_image)
+            _test_last_batch_handle(data, [], is_image)
+            _test_last_batch_handle(data=data, is_image=is_image)
+            _test_shuffle(data, labels)
+            _test_shuffle([data, data], labels)
+            _test_shuffle([data, data])
+            _test_shuffle({'data1': data, 'data2': data}, labels)
+            _test_shuffle({'data1': data, 'data2': data})
+            _test_shuffle(data, [])
+            _test_shuffle(data)
 
 
 def test_NDArrayIter_h5py():
     if not h5py:
         return
 
-    data = np.ones([1000, 2, 2])
-    label = np.ones([1000, 1])
-    for i in range(1000):
-        data[i] = i / 100
-        label[i] = i / 100
+    data, labels = _init_NDArrayIter_data('ndarray')
 
     try:
-        os.remove("ndarraytest.h5")
+        os.remove('ndarraytest.h5')
     except OSError:
         pass
-    with h5py.File("ndarraytest.h5") as f:
-        f.create_dataset("data", data=data)
-        f.create_dataset("label", data=label)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, True, last_batch_handle='pad')
-        batchidx = 0
-        for batch in dataiter:
-            batchidx += 1
-        assert(batchidx == 8)
-
-        dataiter = mx.io.NDArrayIter(
-            f["data"], f["label"], 128, False, last_batch_handle='pad')
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
-
+    with h5py.File('ndarraytest.h5') as f:
+        f.create_dataset('data', data=data)
+        f.create_dataset('label', data=labels)
+        
+        _test_last_batch_handle(f['data'], f['label'])
+        _test_last_batch_handle(f['data'], [])
+        _test_last_batch_handle(f['data'])
     try:
         os.remove("ndarraytest.h5")
     except OSError:
         pass
 
-    for i in range(10):
-        if i == 0:
-            assert(labelcount[i] == 124)
-        else:
-            assert(labelcount[i] == 100)
+
+def _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, csr_iter_None, 
num_rows, batch_size):
+    num_batch = 0
+    for _, batch_empty_list, batch_empty_None in zip(csr_iter, 
csr_iter_empty_list, csr_iter_None):
+        assert not batch_empty_list.label, 'label is not empty list'
+        assert not batch_empty_None.label, 'label is not empty list'
+        num_batch += 1
+
+    assert(num_batch == num_rows // batch_size)
+    assertRaises(StopIteration, csr_iter.next)
+    assertRaises(StopIteration, csr_iter_empty_list.next)
+    assertRaises(StopIteration, csr_iter_None.next)
 
 
 def test_NDArrayIter_csr():
@@ -200,15 +261,26 @@ def test_NDArrayIter_csr():
                      {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
+    
+    # scipy.sparse.csr_matrix with shuffle
+    csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_empty_list = iter(mx.io.NDArrayIter({'data': train_data}, [], 
batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_None = iter(mx.io.NDArrayIter({'data': train_data}, None, 
batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+                          csr_iter_None, num_rows, batch_size)
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, 
batch_size,
                                       shuffle=True, 
last_batch_handle='discard'))
-    num_batch = 0
-    for batch in csr_iter:
-        num_batch += 1
-
-    assert(num_batch == num_rows // batch_size)
+    csr_iter_empty_list = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': 
dns}, [], batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    csr_iter_None = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, 
None, batch_size,
+                                      shuffle=True, 
last_batch_handle='discard'))
+    _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+                          csr_iter_None, num_rows, batch_size)
 
     # make iterators
     csr_iter = iter(mx.io.NDArrayIter(

Reply via email to