This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 8ff50c9  Revert "Change the way NDArrayIter handle the last batch" (#12537)
8ff50c9 is described below

commit 8ff50c95201e02e849e0592de5fb7af87489be53
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Wed Sep 12 13:32:57 2018 -0700

    Revert "Change the way NDArrayIter handle the last batch" (#12537)
    
    * Revert "Removing the re-size for validation data, which breaking the validation accuracy of CIFAR training (#12362)"
    
    This reverts commit ceabcaac77543d99246415b2fb2d8c973a830453.
    
    * Revert "[MXNET-580] Add SN-GAN example (#12419)"
    
    This reverts commit 46a5cee2515a1ac0a1ae5afbe7e639debb998587.
    
    * Revert "Remove regression checks for website links (#12507)"
    
    This reverts commit 619bc3ea3c9093b72634d16e91596b3a65f3f1fc.
    
    * Revert "Revert "Fix flaky test: test_mkldnn.test_activation #12377 (#12418)" (#12516)"
    
    This reverts commit 7ea05333efc8ca868443b89233b101d068f6af9f.
    
    * Revert "further bump up tolerance for sparse dot (#12527)"
    
    This reverts commit 90599e1038a4ff6604e9ed0d55dc274c2df635f8.
    
    * Revert "Fix broken URLs (#12508)"
    
    This reverts commit 3d83c896fd8b237c53003888e35a4d792c1e5389.
    
    * Revert "Temporarily disable flaky tests (#12520)"
    
    This reverts commit 35ca13c3b5a0e57d904d1fead079152a15dfeac4.
    
    * Revert "Add support for more req patterns for bilinear sampler backward (#12386)"
    
    This reverts commit 4ee866fc75307b284cc0eae93d0cf4dad3b62533.
    
    * Revert "Change the way NDArrayIter handle the last batch (#12285)"
    
    This reverts commit 597a637fb1b8fa5b16331218cda8be61ce0ee202.
---
 CONTRIBUTORS.md                  |   1 -
 python/mxnet/{io => }/io.py      | 280 +++++++++++++++++++--------------------
 python/mxnet/io/__init__.py      |  29 ----
 python/mxnet/io/utils.py         |  86 ------------
 tests/python/unittest/test_io.py | 122 ++++++++---------
 5 files changed, 190 insertions(+), 328 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 1c005d5..8d8aeac 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,4 +178,3 @@ List of Contributors
 * [Aaron Markham](https://github.com/aaronmarkham)
 * [Sam Skalicky](https://github.com/samskalicky)
 * [Per Goncalves da Silva](https://github.com/perdasilva)
-* [Cheng-Che Lee](https://github.com/stu1130)
diff --git a/python/mxnet/io/io.py b/python/mxnet/io.py
similarity index 82%
rename from python/mxnet/io/io.py
rename to python/mxnet/io.py
index 2ae3e70..884e929 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io.py
@@ -17,26 +17,30 @@
 
 """Data iterators for common data formats."""
 from __future__ import absolute_import
-from collections import namedtuple
+from collections import OrderedDict, namedtuple
 
 import sys
 import ctypes
 import logging
 import threading
+try:
+    import h5py
+except ImportError:
+    h5py = None
 import numpy as np
-
-from ..base import _LIB
-from ..base import c_str_array, mx_uint, py_str
-from ..base import DataIterHandle, NDArrayHandle
-from ..base import mx_real_t
-from ..base import check_call, build_param_doc as _build_param_doc
-from ..ndarray import NDArray
-from ..ndarray.sparse import CSRNDArray
-from ..ndarray import _ndarray_cls
-from ..ndarray import array
-from ..ndarray import concat
-
-from .utils import init_data, has_instance, getdata_by_idx
+from .base import _LIB
+from .base import c_str_array, mx_uint, py_str
+from .base import DataIterHandle, NDArrayHandle
+from .base import mx_real_t
+from .base import check_call, build_param_doc as _build_param_doc
+from .ndarray import NDArray
+from .ndarray.sparse import CSRNDArray
+from .ndarray.sparse import array as sparse_array
+from .ndarray import _ndarray_cls
+from .ndarray import array
+from .ndarray import concatenate
+from .ndarray import arange
+from .ndarray.random import shuffle as random_shuffle
 
 class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
     """DataDesc is used to store name, shape, type and layout
@@ -485,6 +489,59 @@ class PrefetchingIter(DataIter):
     def getpad(self):
         return self.current_batch.pad
 
+def _init_data(data, allow_empty, default_name):
+    """Convert data into canonical form."""
+    assert (data is not None) or allow_empty
+    if data is None:
+        data = []
+
+    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+                  if h5py else (np.ndarray, NDArray)):
+        data = [data]
+    if isinstance(data, list):
+        if not allow_empty:
+            assert(len(data) > 0)
+        if len(data) == 1:
+            data = OrderedDict([(default_name, data[0])]) # pylint: 
disable=redefined-variable-type
+        else:
+            data = OrderedDict( # pylint: disable=redefined-variable-type
+                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
+    if not isinstance(data, dict):
+        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " 
+ \
+                "a list of them or dict with them as values")
+    for k, v in data.items():
+        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+            try:
+                data[k] = array(v)
+            except:
+                raise TypeError(("Invalid type '%s' for %s, "  % (type(v), k)) 
+ \
+                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
+
+    return list(sorted(data.items()))
+
+def _has_instance(data, dtype):
+    """Return True if ``data`` has instance of ``dtype``.
+    This function is called after _init_data.
+    ``data`` is a list of (str, NDArray)"""
+    for item in data:
+        _, arr = item
+        if isinstance(arr, dtype):
+            return True
+    return False
+
+def _shuffle(data, idx):
+    """Shuffle the data."""
+    shuffle_data = []
+
+    for k, v in data:
+        if (isinstance(v, h5py.Dataset) if h5py else False):
+            shuffle_data.append((k, v))
+        elif isinstance(v, CSRNDArray):
+            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+        else:
+            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+    return shuffle_data
 
 class NDArrayIter(DataIter):
     """Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``, 
``h5py.Dataset``
@@ -544,22 +601,6 @@ class NDArrayIter(DataIter):
     ...
     >>> batchidx # Remaining examples are discarded. So, 10/3 batches are 
created.
     3
-    >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, 
last_batch_handle='roll_over')
-    >>> batchidx = 0
-    >>> for batch in dataiter:
-    ...     batchidx += 1
-    ...
-    >>> batchidx # Remaining examples are rolled over to the next iteration.
-    3
-    >>> dataiter.reset()
-    >>> dataiter.next().data[0].asnumpy()
-    [[[ 36.  37.]
-      [ 38.  39.]]
-     [[ 0.  1.]
-      [ 2.  3.]]
-     [[ 4.  5.]
-      [ 6.  7.]]]
-    (3L, 2L, 2L)
 
     `NDArrayIter` also supports multiple input and labels.
 
@@ -592,11 +633,8 @@ class NDArrayIter(DataIter):
         Only supported if no h5py.Dataset inputs are used.
     last_batch_handle : str, optional
         How to handle the last batch. This parameter can be 'pad', 'discard' or
-        'roll_over'.
-        If 'pad', the last batch will be padded with data starting from the 
begining
-        If 'discard', the last batch will be discarded
-        If 'roll_over', the remaining elements will be rolled over to the next 
iteration and
-        note that it is intended for training and can cause problems if used 
for prediction.
+        'roll_over'. 'roll_over' is intended for training and can cause 
problems
+        if used for prediction.
     data_name : str, optional
         The data name.
     label_name : str, optional
@@ -607,28 +645,36 @@ class NDArrayIter(DataIter):
                  label_name='softmax_label'):
         super(NDArrayIter, self).__init__(batch_size)
 
-        self.data = init_data(data, allow_empty=False, default_name=data_name)
-        self.label = init_data(label, allow_empty=True, 
default_name=label_name)
+        self.data = _init_data(data, allow_empty=False, default_name=data_name)
+        self.label = _init_data(label, allow_empty=True, 
default_name=label_name)
 
-        if ((has_instance(self.data, CSRNDArray) or has_instance(self.label, 
CSRNDArray)) and
+        if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, 
CSRNDArray)) and
                 (last_batch_handle != 'discard')):
             raise NotImplementedError("`NDArrayIter` only supports 
``CSRNDArray``" \
                                       " with `last_batch_handle` set to 
`discard`.")
 
-        self.idx = np.arange(self.data[0][1].shape[0])
-        self.shuffle = shuffle
-        self.last_batch_handle = last_batch_handle
-        self.batch_size = batch_size
-        self.cursor = -self.batch_size
-        self.num_data = self.idx.shape[0]
-        # shuffle
-        self.reset()
+        # shuffle data
+        if shuffle:
+            tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
+            self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
+            self.data = _shuffle(self.data, self.idx)
+            self.label = _shuffle(self.label, self.idx)
+        else:
+            self.idx = np.arange(self.data[0][1].shape[0])
+
+        # batching
+        if last_batch_handle == 'discard':
+            new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % 
batch_size
+            self.idx = self.idx[:new_n]
 
         self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
         self.num_source = len(self.data_list)
-        # used for 'roll_over'
-        self._cache_data = None
-        self._cache_label = None
+        self.num_data = self.idx.shape[0]
+        assert self.num_data >= batch_size, \
+            "batch_size needs to be smaller than data size."
+        self.cursor = -batch_size
+        self.batch_size = batch_size
+        self.last_batch_handle = last_batch_handle
 
     @property
     def provide_data(self):
@@ -648,126 +694,74 @@ class NDArrayIter(DataIter):
 
     def hard_reset(self):
         """Ignore roll over data and set to start."""
-        if self.shuffle:
-            self._shuffle_data()
         self.cursor = -self.batch_size
-        self._cache_data = None
-        self._cache_label = None
 
     def reset(self):
-        """Resets the iterator to the beginning of the data."""
-        if self.shuffle:
-            self._shuffle_data()
-        # the range below indicate the last batch
-        if self.last_batch_handle == 'roll_over' and \
-            self.num_data - self.batch_size < self.cursor < self.num_data:
-            # (self.cursor - self.num_data) represents the data we have for 
the last batch
-            self.cursor = self.cursor - self.num_data - self.batch_size
+        if self.last_batch_handle == 'roll_over' and self.cursor > 
self.num_data:
+            self.cursor = -self.batch_size + 
(self.cursor%self.num_data)%self.batch_size
         else:
             self.cursor = -self.batch_size
 
     def iter_next(self):
-        """Increments the coursor by batch_size for next batch
-        and check current cursor if it exceed the number of data points."""
         self.cursor += self.batch_size
         return self.cursor < self.num_data
 
     def next(self):
-        """Returns the next batch of data."""
-        if not self.iter_next():
-            raise StopIteration
-        data = self.getdata()
-        label = self.getlabel()
-        # iter should stop when last batch is not complete
-        if data[0].shape[0] != self.batch_size:
-        # in this case, cache it for next epoch
-            self._cache_data = data
-            self._cache_label = label
+        if self.iter_next():
+            return DataBatch(data=self.getdata(), label=self.getlabel(), \
+                    pad=self.getpad(), index=None)
+        else:
             raise StopIteration
-        return DataBatch(data=data, label=label, \
-            pad=self.getpad(), index=None)
-
-    def _getdata(self, data_source, start=None, end=None):
-        """Load data from underlying arrays."""
-        assert start is not None or end is not None, 'should at least specify 
start or end'
-        start = start if start is not None else 0
-        end = end if end is not None else data_source[0][1].shape[0]
-        s = slice(start, end)
-        return [
-            x[1][s]
-            if isinstance(x[1], (np.ndarray, NDArray)) else
-            # h5py (only supports indices in increasing order)
-            array(x[1][sorted(self.idx[s])][[
-                list(self.idx[s]).index(i)
-                for i in sorted(self.idx[s])
-            ]]) for x in data_source
-        ]
 
-    def _concat(self, first_data, second_data):
-        """Helper function to concat two NDArrays."""
-        return [
-            concat(first_data[0], second_data[0], dim=0)
-        ]
-
-    def _batchify(self, data_source):
+    def _getdata(self, data_source):
         """Load data from underlying arrays, internal use only."""
-        assert self.cursor < self.num_data, 'DataIter needs reset.'
-        # first batch of next epoch with 'roll_over'
-        if self.last_batch_handle == 'roll_over' and \
-            -self.batch_size < self.cursor < 0:
-            assert self._cache_data is not None or self._cache_label is not 
None, \
-                'next epoch should have cached data'
-            cache_data = self._cache_data if self._cache_data is not None else 
self._cache_label
-            second_data = self._getdata(
-                data_source, end=self.cursor + self.batch_size)
-            if self._cache_data is not None:
-                self._cache_data = None
-            else:
-                self._cache_label = None
-            return self._concat(cache_data, second_data)
-        # last batch with 'pad'
-        elif self.last_batch_handle == 'pad' and \
-            self.cursor + self.batch_size > self.num_data:
-            pad = self.batch_size - self.num_data + self.cursor
-            first_data = self._getdata(data_source, start=self.cursor)
-            second_data = self._getdata(data_source, end=pad)
-            return self._concat(first_data, second_data)
-        # normal case
+        assert(self.cursor < self.num_data), "DataIter needs reset."
+        if self.cursor + self.batch_size <= self.num_data:
+            return [
+                # np.ndarray or NDArray case
+                x[1][self.cursor:self.cursor + self.batch_size]
+                if isinstance(x[1], (np.ndarray, NDArray)) else
+                # h5py (only supports indices in increasing order)
+                array(x[1][sorted(self.idx[
+                    self.cursor:self.cursor + self.batch_size])][[
+                        list(self.idx[self.cursor:
+                                      self.cursor + self.batch_size]).index(i)
+                        for i in sorted(self.idx[
+                            self.cursor:self.cursor + self.batch_size])
+                    ]]) for x in data_source
+            ]
         else:
-            if self.cursor + self.batch_size < self.num_data:
-                end_idx = self.cursor + self.batch_size
-            # get incomplete last batch
-            else:
-                end_idx = self.num_data
-            return self._getdata(data_source, self.cursor, end_idx)
+            pad = self.batch_size - self.num_data + self.cursor
+            return [
+                # np.ndarray or NDArray case
+                concatenate([x[1][self.cursor:], x[1][:pad]])
+                if isinstance(x[1], (np.ndarray, NDArray)) else
+                # h5py (only supports indices in increasing order)
+                concatenate([
+                    array(x[1][sorted(self.idx[self.cursor:])][[
+                        list(self.idx[self.cursor:]).index(i)
+                        for i in sorted(self.idx[self.cursor:])
+                    ]]),
+                    array(x[1][sorted(self.idx[:pad])][[
+                        list(self.idx[:pad]).index(i)
+                        for i in sorted(self.idx[:pad])
+                    ]])
+                ]) for x in data_source
+            ]
 
     def getdata(self):
-        """Get data."""
-        return self._batchify(self.data)
+        return self._getdata(self.data)
 
     def getlabel(self):
-        """Get label."""
-        return self._batchify(self.label)
+        return self._getdata(self.label)
 
     def getpad(self):
-        """Get pad value of DataBatch."""
         if self.last_batch_handle == 'pad' and \
            self.cursor + self.batch_size > self.num_data:
             return self.cursor + self.batch_size - self.num_data
-        # check the first batch
-        elif self.last_batch_handle == 'roll_over' and \
-            -self.batch_size < self.cursor < 0:
-            return -self.cursor
         else:
             return 0
 
-    def _shuffle_data(self):
-        """Shuffle the data."""
-        # shuffle index
-        np.random.shuffle(self.idx)
-        # get the data by corresponding index
-        self.data = getdata_by_idx(self.data, self.idx)
-        self.label = getdata_by_idx(self.label, self.idx)
 
 class MXDataIter(DataIter):
     """A python wrapper a C++ data iterator.
@@ -779,7 +773,7 @@ class MXDataIter(DataIter):
     underlying C++ data iterators.
 
     Usually you don't need to interact with `MXDataIter` directly unless you 
are
-    implementing your own data iterators in C+ +. To do that, please refer to
+    implementing your own data iterators in C++. To do that, please refer to
     examples under the `src/io` folder.
 
     Parameters
diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
deleted file mode 100644
index 5c5e2e6..0000000
--- a/python/mxnet/io/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding: utf-8
-# pylint: disable=wildcard-import
-""" Data iterators for common data formats and utility functions."""
-from __future__ import absolute_import
-
-from . import io
-from .io import *
-
-from . import utils
-from .utils import *
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
deleted file mode 100644
index 872e641..0000000
--- a/python/mxnet/io/utils.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""utility functions for io.py"""
-from collections import OrderedDict
-
-import numpy as np
-try:
-    import h5py
-except ImportError:
-    h5py = None
-
-from ..ndarray.sparse import CSRNDArray
-from ..ndarray.sparse import array as sparse_array
-from ..ndarray import NDArray
-from ..ndarray import array
-
-def init_data(data, allow_empty, default_name):
-    """Convert data into canonical form."""
-    assert (data is not None) or allow_empty
-    if data is None:
-        data = []
-
-    if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
-                  if h5py else (np.ndarray, NDArray)):
-        data = [data]
-    if isinstance(data, list):
-        if not allow_empty:
-            assert(len(data) > 0)
-        if len(data) == 1:
-            data = OrderedDict([(default_name, data[0])])  # pylint: 
disable=redefined-variable-type
-        else:
-            data = OrderedDict(  # pylint: disable=redefined-variable-type
-                [('_%d_%s' % (i, default_name), d) for i, d in 
enumerate(data)])
-    if not isinstance(data, dict):
-        raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
-                        "a list of them or dict with them as values")
-    for k, v in data.items():
-        if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
-            try:
-                data[k] = array(v)
-            except:
-                raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
-                                "should be NDArray, numpy.ndarray or 
h5py.Dataset")
-
-    return list(sorted(data.items()))
-
-
-def has_instance(data, dtype):
-    """Return True if ``data`` has instance of ``dtype``.
-    This function is called after _init_data.
-    ``data`` is a list of (str, NDArray)"""
-    for item in data:
-        _, arr = item
-        if isinstance(arr, dtype):
-            return True
-    return False
-
-
-def getdata_by_idx(data, idx):
-    """Shuffle the data."""
-    shuffle_data = []
-
-    for k, v in data:
-        if (isinstance(v, h5py.Dataset) if h5py else False):
-            shuffle_data.append((k, v))
-        elif isinstance(v, CSRNDArray):
-            shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
-        else:
-            shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
-    return shuffle_data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index ae68626..4dfa69c 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -88,88 +88,80 @@ def test_Cifar10Rec():
         assert(labelcount[i] == 5000)
 
 
-def _init_NDArrayIter_data():
+def test_NDArrayIter():
     data = np.ones([1000, 2, 2])
-    labels = np.ones([1000, 1])
+    label = np.ones([1000, 1])
     for i in range(1000):
         data[i] = i / 100
-        labels[i] = i / 100
-    return data, labels
-
-
-def _test_last_batch_handle(data, labels):
-    # Test the three parameters 'pad', 'discard', 'roll_over'
-    last_batch_handle_list = ['pad', 'discard' , 'roll_over']
-    labelcount_list = [(124, 100), (100, 96), (100, 96)]
-    batch_count_list = [8, 7, 7]
-    
-    for idx in range(len(last_batch_handle_list)):
-        dataiter = mx.io.NDArrayIter(
-            data, labels, 128, False, 
last_batch_handle=last_batch_handle_list[idx])
-        batch_count = 0
-        labelcount = [0 for i in range(10)]
-        for batch in dataiter:
-            label = batch.label[0].asnumpy().flatten()
-            # check data if it matches corresponding labels
-            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()), 
last_batch_handle_list[idx]
-            for i in range(label.shape[0]):
-                labelcount[int(label[i])] += 1
-            # keep the last batch of 'pad' to be used later 
-            # to test first batch of roll_over in second iteration
-            batch_count += 1
-            if last_batch_handle_list[idx] == 'pad' and \
-                batch_count == 8:
-                cache = batch.data[0].asnumpy()
-        # check if batchifying functionality work properly
-        assert labelcount[0] == labelcount_list[idx][0], 
last_batch_handle_list[idx]
-        assert labelcount[8] == labelcount_list[idx][1], 
last_batch_handle_list[idx]
-        assert batch_count == batch_count_list[idx]
-    # roll_over option
-    dataiter.reset()
-    assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
-
-
-def _test_shuffle(data, labels):
-    dataiter = mx.io.NDArrayIter(data, labels, 1, False)
-    batch_list = []
+        label[i] = i / 100
+    dataiter = mx.io.NDArrayIter(
+        data, label, 128, True, last_batch_handle='pad')
+    batchidx = 0
     for batch in dataiter:
-        # cache the original data
-        batch_list.append(batch.data[0].asnumpy())
-    dataiter = mx.io.NDArrayIter(data, labels, 1, True)
-    idx_list = dataiter.idx
-    i = 0
+        batchidx += 1
+    assert(batchidx == 8)
+    dataiter = mx.io.NDArrayIter(
+        data, label, 128, False, last_batch_handle='pad')
+    batchidx = 0
+    labelcount = [0 for i in range(10)]
     for batch in dataiter:
-        # check if each data point have been shuffled to corresponding 
positions
-        assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
-        i += 1
-
+        label = batch.label[0].asnumpy().flatten()
+        assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+        for i in range(label.shape[0]):
+            labelcount[int(label[i])] += 1
 
-def test_NDArrayIter():
-    data, labels = _init_NDArrayIter_data()
-    _test_last_batch_handle(data, labels)
-    _test_shuffle(data, labels)
+    for i in range(10):
+        if i == 0:
+            assert(labelcount[i] == 124)
+        else:
+            assert(labelcount[i] == 100)
 
 
 def test_NDArrayIter_h5py():
     if not h5py:
         return
 
-    data, labels = _init_NDArrayIter_data()
+    data = np.ones([1000, 2, 2])
+    label = np.ones([1000, 1])
+    for i in range(1000):
+        data[i] = i / 100
+        label[i] = i / 100
 
     try:
-        os.remove('ndarraytest.h5')
+        os.remove("ndarraytest.h5")
     except OSError:
         pass
-    with h5py.File('ndarraytest.h5') as f:
-        f.create_dataset('data', data=data)
-        f.create_dataset('label', data=labels)
+    with h5py.File("ndarraytest.h5") as f:
+        f.create_dataset("data", data=data)
+        f.create_dataset("label", data=label)
+
+        dataiter = mx.io.NDArrayIter(
+            f["data"], f["label"], 128, True, last_batch_handle='pad')
+        batchidx = 0
+        for batch in dataiter:
+            batchidx += 1
+        assert(batchidx == 8)
+
+        dataiter = mx.io.NDArrayIter(
+            f["data"], f["label"], 128, False, last_batch_handle='pad')
+        labelcount = [0 for i in range(10)]
+        for batch in dataiter:
+            label = batch.label[0].asnumpy().flatten()
+            assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+            for i in range(label.shape[0]):
+                labelcount[int(label[i])] += 1
 
-        _test_last_batch_handle(f['data'], f['label'])
     try:
         os.remove("ndarraytest.h5")
     except OSError:
         pass
 
+    for i in range(10):
+        if i == 0:
+            assert(labelcount[i] == 124)
+        else:
+            assert(labelcount[i] == 100)
+
 
 def test_NDArrayIter_csr():
     # creating toy data
@@ -190,20 +182,12 @@ def test_NDArrayIter_csr():
                      {'data': train_data}, dns, batch_size)
     except ImportError:
         pass
-    # scipy.sparse.csr_matrix with shuffle
-    num_batch = 0
-    csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
-                                      shuffle=True, 
last_batch_handle='discard'))
-    for _ in csr_iter:
-        num_batch += 1
-
-    assert(num_batch == num_rows // batch_size)
 
     # CSRNDArray with shuffle
     csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, 
batch_size,
                                       shuffle=True, 
last_batch_handle='discard'))
     num_batch = 0
-    for _ in csr_iter:
+    for batch in csr_iter:
         num_batch += 1
 
     assert(num_batch == num_rows // batch_size)

Reply via email to