This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 13030b6 Change the way NDArrayIter handle the last batch (#12545)
13030b6 is described below
commit 13030b6e35ee9bbdaf674e6e5828069edf9b85f5
Author: Jake Lee <[email protected]>
AuthorDate: Tue Oct 16 10:00:47 2018 -0700
Change the way NDArrayIter handle the last batch (#12545)
* 1. move the shuffle to the reset 2. modify the roll_over behavior accordingly
* refactor the concat part
* refactor the code
* implement unit test for last_batch_handle
* refactor the getdata part
* add docstring and refine the code according to linter
* 1. add test case for NDArrayIter_h5py 2. refactor the implementation
* update contributions doc
* fix wording
* update doc for roll_over
* 1. add test for second iteration of roll_over 2. add shuffle test case
* fix some wording and refine the variables naming
* move utility function to new file
* move utility function to io_utils.py
* change shuffle function name to avoid redefining name
* make io as a module
* rename the utility functions
* disable wildcard-import
* fix the algorithm
* refactor the code
* test the NDArrayIter with different combinations of shuffle=True, data_source type and labels
* add edge case of label data for csr NDArrayIter
* trigger Travis CI
* handle the 'list' of data source
* check the list of data source
* fix the extra blank
* Trigger CI
* add _ to the utility functions
* Trigger CI
* update several test cases
* add test case for airbnb
* fix the typo
* fix wrong labels data shape
* switch the order of condition to make more sense
---
python/mxnet/io/__init__.py | 29 ++++
python/mxnet/{ => io}/io.py | 284 +++++++++++++++++++++------------------
python/mxnet/io/utils.py | 86 ++++++++++++
tests/python/unittest/test_io.py | 188 ++++++++++++++++++--------
4 files changed, 398 insertions(+), 189 deletions(-)
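For context, the user-visible effect of this change is in how the three last_batch_handle modes treat an incomplete final batch. The following is a minimal sketch based on the updated NDArrayIter docstring further below; it is illustrative and not part of the patch:

    import numpy as np
    import mxnet as mx

    data = np.arange(40).reshape((10, 2, 2))    # 10 samples, batch_size 3
    labels = np.ones((10, 1))

    for mode in ('pad', 'discard', 'roll_over'):
        it = mx.io.NDArrayIter(data, labels, batch_size=3, shuffle=False,
                               last_batch_handle=mode)
        n = sum(1 for _ in it)
        # 'pad'       -> 4 batches; the last batch is padded with samples
        #                taken from the beginning of the data
        # 'discard'   -> 3 batches; the incomplete batch is dropped
        # 'roll_over' -> 3 batches; the remainder is cached and prepended
        #                to the first batch of the next epoch after reset()
        print(mode, n)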
diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
new file mode 100644
index 0000000..5c5e2e6
--- /dev/null
+++ b/python/mxnet/io/__init__.py
@@ -0,0 +1,29 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# coding: utf-8
+# pylint: disable=wildcard-import
+""" Data iterators for common data formats and utility functions."""
+from __future__ import absolute_import
+
+from . import io
+from .io import *
+
+from . import utils
+from .utils import *
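The new package keeps the public import surface unchanged: __init__.py re-exports everything from the io and utils submodules. A hypothetical smoke check (not part of the patch) of what that implies:

    from mxnet.io import NDArrayIter          # re-exported via `from .io import *`
    from mxnet.io.io import NDArrayIter as Direct

    # both paths resolve to the same class object after the refactor
    assert NDArrayIter is Direct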
diff --git a/python/mxnet/io.py b/python/mxnet/io/io.py
similarity index 81%
rename from python/mxnet/io.py
rename to python/mxnet/io/io.py
index 884e929..20da2ea 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io/io.py
@@ -17,30 +17,26 @@
"""Data iterators for common data formats."""
from __future__ import absolute_import
-from collections import OrderedDict, namedtuple
+from collections import namedtuple
import sys
import ctypes
import logging
import threading
-try:
- import h5py
-except ImportError:
- h5py = None
import numpy as np
-from .base import _LIB
-from .base import c_str_array, mx_uint, py_str
-from .base import DataIterHandle, NDArrayHandle
-from .base import mx_real_t
-from .base import check_call, build_param_doc as _build_param_doc
-from .ndarray import NDArray
-from .ndarray.sparse import CSRNDArray
-from .ndarray.sparse import array as sparse_array
-from .ndarray import _ndarray_cls
-from .ndarray import array
-from .ndarray import concatenate
-from .ndarray import arange
-from .ndarray.random import shuffle as random_shuffle
+
+from ..base import _LIB
+from ..base import c_str_array, mx_uint, py_str
+from ..base import DataIterHandle, NDArrayHandle
+from ..base import mx_real_t
+from ..base import check_call, build_param_doc as _build_param_doc
+from ..ndarray import NDArray
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray import _ndarray_cls
+from ..ndarray import array
+from ..ndarray import concat
+
+from .utils import _init_data, _has_instance, _getdata_by_idx
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
@@ -489,59 +485,6 @@ class PrefetchingIter(DataIter):
def getpad(self):
return self.current_batch.pad
-def _init_data(data, allow_empty, default_name):
- """Convert data into canonical form."""
- assert (data is not None) or allow_empty
- if data is None:
- data = []
-
- if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
- if h5py else (np.ndarray, NDArray)):
- data = [data]
- if isinstance(data, list):
- if not allow_empty:
- assert(len(data) > 0)
- if len(data) == 1:
- data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
- else:
- data = OrderedDict( # pylint: disable=redefined-variable-type
- [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
- if not isinstance(data, dict):
- raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset "
+ \
- "a list of them or dict with them as values")
- for k, v in data.items():
- if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
- try:
- data[k] = array(v)
- except:
- raise TypeError(("Invalid type '%s' for %s, " % (type(v), k))
+ \
- "should be NDArray, numpy.ndarray or
h5py.Dataset")
-
- return list(sorted(data.items()))
-
-def _has_instance(data, dtype):
- """Return True if ``data`` has instance of ``dtype``.
- This function is called after _init_data.
- ``data`` is a list of (str, NDArray)"""
- for item in data:
- _, arr = item
- if isinstance(arr, dtype):
- return True
- return False
-
-def _shuffle(data, idx):
- """Shuffle the data."""
- shuffle_data = []
-
- for k, v in data:
- if (isinstance(v, h5py.Dataset) if h5py else False):
- shuffle_data.append((k, v))
- elif isinstance(v, CSRNDArray):
- shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
- else:
- shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
- return shuffle_data
class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``,
``h5py.Dataset``
@@ -601,6 +544,22 @@ class NDArrayIter(DataIter):
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are created.
3
+ >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False, last_batch_handle='roll_over')
+ >>> batchidx = 0
+ >>> for batch in dataiter:
+ ... batchidx += 1
+ ...
+ >>> batchidx # Remaining examples are rolled over to the next iteration.
+ 3
+ >>> dataiter.reset()
+ >>> dataiter.next().data[0].asnumpy()
+ [[[ 36. 37.]
+ [ 38. 39.]]
+ [[ 0. 1.]
+ [ 2. 3.]]
+ [[ 4. 5.]
+ [ 6. 7.]]]
+ (3L, 2L, 2L)
`NDArrayIter` also supports multiple input and labels.
@@ -633,8 +592,11 @@ class NDArrayIter(DataIter):
Only supported if no h5py.Dataset inputs are used.
last_batch_handle : str, optional
How to handle the last batch. This parameter can be 'pad', 'discard' or
- 'roll_over'. 'roll_over' is intended for training and can cause problems
- if used for prediction.
+ 'roll_over'.
+ If 'pad', the last batch will be padded with data starting from the beginning.
+ If 'discard', the last batch will be discarded.
+ If 'roll_over', the remaining elements will be rolled over to the next iteration;
+ note that it is intended for training and can cause problems if used for prediction.
data_name : str, optional
The data name.
label_name : str, optional
@@ -648,33 +610,26 @@ class NDArrayIter(DataIter):
self.data = _init_data(data, allow_empty=False, default_name=data_name)
self.label = _init_data(label, allow_empty=True, default_name=label_name)
- if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label, CSRNDArray)) and
+ if ((_has_instance(self.data, CSRNDArray) or
+ _has_instance(self.label, CSRNDArray)) and
(last_batch_handle != 'discard')):
raise NotImplementedError("`NDArrayIter` only supports ``CSRNDArray``" \
" with `last_batch_handle` set to `discard`.")
- # shuffle data
- if shuffle:
- tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
- self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
- self.data = _shuffle(self.data, self.idx)
- self.label = _shuffle(self.label, self.idx)
- else:
- self.idx = np.arange(self.data[0][1].shape[0])
-
- # batching
- if last_batch_handle == 'discard':
- new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] % batch_size
- self.idx = self.idx[:new_n]
+ self.idx = np.arange(self.data[0][1].shape[0])
+ self.shuffle = shuffle
+ self.last_batch_handle = last_batch_handle
+ self.batch_size = batch_size
+ self.cursor = -self.batch_size
+ self.num_data = self.idx.shape[0]
+ # shuffle
+ self.reset()
self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
self.num_source = len(self.data_list)
- self.num_data = self.idx.shape[0]
- assert self.num_data >= batch_size, \
- "batch_size needs to be smaller than data size."
- self.cursor = -batch_size
- self.batch_size = batch_size
- self.last_batch_handle = last_batch_handle
+ # used for 'roll_over'
+ self._cache_data = None
+ self._cache_label = None
@property
def provide_data(self):
@@ -694,74 +649,141 @@ class NDArrayIter(DataIter):
def hard_reset(self):
"""Ignore roll over data and set to start."""
+ if self.shuffle:
+ self._shuffle_data()
self.cursor = -self.batch_size
+ self._cache_data = None
+ self._cache_label = None
def reset(self):
- if self.last_batch_handle == 'roll_over' and self.cursor > self.num_data:
- self.cursor = -self.batch_size + (self.cursor%self.num_data)%self.batch_size
+ """Resets the iterator to the beginning of the data."""
+ if self.shuffle:
+ self._shuffle_data()
+ # the range below indicates the last batch
+ if self.last_batch_handle == 'roll_over' and \
+ self.num_data - self.batch_size < self.cursor < self.num_data:
+ # (self.cursor - self.num_data) represents the data we have for the last batch
+ self.cursor = self.cursor - self.num_data - self.batch_size
else:
self.cursor = -self.batch_size
def iter_next(self):
+ """Increments the coursor by batch_size for next batch
+ and check current cursor if it exceed the number of data points."""
self.cursor += self.batch_size
return self.cursor < self.num_data
def next(self):
- if self.iter_next():
- return DataBatch(data=self.getdata(), label=self.getlabel(), \
- pad=self.getpad(), index=None)
- else:
+ """Returns the next batch of data."""
+ if not self.iter_next():
raise StopIteration
+ data = self.getdata()
+ label = self.getlabel()
+ # iter should stop when last batch is not complete
+ if data[0].shape[0] != self.batch_size:
+ # in this case, cache it for next epoch
+ self._cache_data = data
+ self._cache_label = label
+ raise StopIteration
+ return DataBatch(data=data, label=label, \
+ pad=self.getpad(), index=None)
+
+ def _getdata(self, data_source, start=None, end=None):
+ """Load data from underlying arrays."""
+ assert start is not None or end is not None, 'should at least specify start or end'
+ start = start if start is not None else 0
+ if end is None:
+ end = data_source[0][1].shape[0] if data_source else 0
+ s = slice(start, end)
+ return [
+ x[1][s]
+ if isinstance(x[1], (np.ndarray, NDArray)) else
+ # h5py (only supports indices in increasing order)
+ array(x[1][sorted(self.idx[s])][[
+ list(self.idx[s]).index(i)
+ for i in sorted(self.idx[s])
+ ]]) for x in data_source
+ ]
- def _getdata(self, data_source):
- """Load data from underlying arrays, internal use only."""
- assert(self.cursor < self.num_data), "DataIter needs reset."
- if self.cursor + self.batch_size <= self.num_data:
+ def _concat(self, first_data, second_data):
+ """Helper function to concat two NDArrays."""
+ assert len(first_data) == len(
+ second_data), 'data source should contain the same size'
+ if first_data and second_data:
return [
- # np.ndarray or NDArray case
- x[1][self.cursor:self.cursor + self.batch_size]
- if isinstance(x[1], (np.ndarray, NDArray)) else
- # h5py (only supports indices in increasing order)
- array(x[1][sorted(self.idx[
- self.cursor:self.cursor + self.batch_size])][[
- list(self.idx[self.cursor:
- self.cursor + self.batch_size]).index(i)
- for i in sorted(self.idx[
- self.cursor:self.cursor + self.batch_size])
- ]]) for x in data_source
+ concat(
+ first_data[x],
+ second_data[x],
+ dim=0
+ ) for x in range(len(first_data))
]
+ elif (not first_data) and (not second_data):
+ return []
else:
- pad = self.batch_size - self.num_data + self.cursor
return [
- # np.ndarray or NDArray case
- concatenate([x[1][self.cursor:], x[1][:pad]])
- if isinstance(x[1], (np.ndarray, NDArray)) else
- # h5py (only supports indices in increasing order)
- concatenate([
- array(x[1][sorted(self.idx[self.cursor:])][[
- list(self.idx[self.cursor:]).index(i)
- for i in sorted(self.idx[self.cursor:])
- ]]),
- array(x[1][sorted(self.idx[:pad])][[
- list(self.idx[:pad]).index(i)
- for i in sorted(self.idx[:pad])
- ]])
- ]) for x in data_source
+ first_data[0] if first_data else second_data[0]
+ for x in range(len(first_data))
]
+ def _batchify(self, data_source):
+ """Load data from underlying arrays, internal use only."""
+ assert self.cursor < self.num_data, 'DataIter needs reset.'
+ # first batch of next epoch with 'roll_over'
+ if self.last_batch_handle == 'roll_over' and \
+ -self.batch_size < self.cursor < 0:
+ assert self._cache_data is not None or self._cache_label is not None, \
+ 'next epoch should have cached data'
+ cache_data = self._cache_data if self._cache_data is not None else self._cache_label
+ second_data = self._getdata(data_source, end=self.cursor + self.batch_size)
+ if self._cache_data is not None:
+ self._cache_data = None
+ else:
+ self._cache_label = None
+ return self._concat(cache_data, second_data)
+ # last batch with 'pad'
+ elif self.last_batch_handle == 'pad' and \
+ self.cursor + self.batch_size > self.num_data:
+ pad = self.batch_size - self.num_data + self.cursor
+ first_data = self._getdata(data_source, start=self.cursor)
+ second_data = self._getdata(data_source, end=pad)
+ return self._concat(first_data, second_data)
+ # normal case
+ else:
+ if self.cursor + self.batch_size < self.num_data:
+ end_idx = self.cursor + self.batch_size
+ # get incomplete last batch
+ else:
+ end_idx = self.num_data
+ return self._getdata(data_source, self.cursor, end_idx)
+
def getdata(self):
- return self._getdata(self.data)
+ """Get data."""
+ return self._batchify(self.data)
def getlabel(self):
- return self._getdata(self.label)
+ """Get label."""
+ return self._batchify(self.label)
def getpad(self):
+ """Get pad value of DataBatch."""
if self.last_batch_handle == 'pad' and \
self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
+ # check the first batch
+ elif self.last_batch_handle == 'roll_over' and \
+ -self.batch_size < self.cursor < 0:
+ return -self.cursor
else:
return 0
+ def _shuffle_data(self):
+ """Shuffle the data."""
+ # shuffle index
+ np.random.shuffle(self.idx)
+ # get the data by corresponding index
+ self.data = _getdata_by_idx(self.data, self.idx)
+ self.label = _getdata_by_idx(self.label, self.idx)
class MXDataIter(DataIter):
"""A python wrapper a C++ data iterator.
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
new file mode 100644
index 0000000..55ba34a
--- /dev/null
+++ b/python/mxnet/io/utils.py
@@ -0,0 +1,86 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""utility functions for io.py"""
+from collections import OrderedDict
+
+import numpy as np
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+from ..ndarray.sparse import CSRNDArray
+from ..ndarray.sparse import array as sparse_array
+from ..ndarray import NDArray
+from ..ndarray import array
+
+def _init_data(data, allow_empty, default_name):
+ """Convert data into canonical form."""
+ assert (data is not None) or allow_empty
+ if data is None:
+ data = []
+
+ if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+ if h5py else (np.ndarray, NDArray)):
+ data = [data]
+ if isinstance(data, list):
+ if not allow_empty:
+ assert(len(data) > 0)
+ if len(data) == 1:
+ data = OrderedDict([(default_name, data[0])]) # pylint: disable=redefined-variable-type
+ else:
+ data = OrderedDict( # pylint: disable=redefined-variable-type
+ [('_%d_%s' % (i, default_name), d) for i, d in enumerate(data)])
+ if not isinstance(data, dict):
+ raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
+ "a list of them or dict with them as values")
+ for k, v in data.items():
+ if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+ try:
+ data[k] = array(v)
+ except:
+ raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
+ "should be NDArray, numpy.ndarray or
h5py.Dataset")
+
+ return list(sorted(data.items()))
+
+
+def _has_instance(data, dtype):
+ """Return True if ``data`` has instance of ``dtype``.
+ This function is called after _init_data.
+ ``data`` is a list of (str, NDArray)"""
+ for item in data:
+ _, arr = item
+ if isinstance(arr, dtype):
+ return True
+ return False
+
+
+def _getdata_by_idx(data, idx):
+ """Shuffle the data."""
+ shuffle_data = []
+
+ for k, v in data:
+ if (isinstance(v, h5py.Dataset) if h5py else False):
+ shuffle_data.append((k, v))
+ elif isinstance(v, CSRNDArray):
+ shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+ else:
+ shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+ return shuffle_data
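For reference, _init_data normalizes every accepted input shape into a sorted list of (name, NDArray) pairs. A small illustration of the canonical forms, assuming the private helper is imported directly (for demonstration only):

    import numpy as np
    from mxnet.io.utils import _init_data   # private helper, shown for illustration

    a = np.zeros((4, 2))
    print(_init_data(a, allow_empty=False, default_name='data'))
    # -> [('data', <NDArray 4x2>)]            a single array keeps the default name
    print(_init_data([a, a], allow_empty=False, default_name='data'))
    # -> [('_0_data', ...), ('_1_data', ...)] list entries get indexed names
    print(_init_data(None, allow_empty=True, default_name='label'))
    # -> []                                   empty labels are allowed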
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index 872763f..0641f23 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -17,6 +17,7 @@
# pylint: skip-file
import mxnet as mx
+import mxnet.ndarray as nd
from mxnet.test_utils import *
from mxnet.base import MXNetError
import numpy as np
@@ -106,79 +107,139 @@ def test_image_iter_exception():
pass
assertRaises(MXNetError, check_cifar10_exception)
-def test_NDArrayIter():
- data = np.ones([1000, 2, 2])
- label = np.ones([1000, 1])
+def _init_NDArrayIter_data(data_type, is_image=False):
+ if is_image:
+ data = nd.random.uniform(0, 255, shape=(5000, 1, 28, 28))
+ labels = nd.ones((5000, 1))
+ return data, labels
+ if data_type == 'NDArray':
+ data = nd.ones((1000, 2, 2))
+ labels = nd.ones((1000, 1))
+ else:
+ data = np.ones((1000, 2, 2))
+ labels = np.ones((1000, 1))
for i in range(1000):
data[i] = i / 100
- label[i] = i / 100
- dataiter = mx.io.NDArrayIter(
- data, label, 128, True, last_batch_handle='pad')
- batchidx = 0
+ labels[i] = i / 100
+ return data, labels
+
+
+def _test_last_batch_handle(data, labels=None, is_image=False):
+ # Test the three parameters 'pad', 'discard', 'roll_over'
+ last_batch_handle_list = ['pad', 'discard', 'roll_over']
+ if labels is not None and not is_image and len(labels) != 0:
+ labelcount_list = [(124, 100), (100, 96), (100, 96)]
+ if is_image:
+ batch_count_list = [40, 39, 39]
+ else:
+ batch_count_list = [8, 7, 7]
+
+ for idx in range(len(last_batch_handle_list)):
+ dataiter = mx.io.NDArrayIter(
+ data, labels, 128, False, last_batch_handle=last_batch_handle_list[idx])
+ batch_count = 0
+ if labels is not None and len(labels) != 0 and not is_image:
+ labelcount = [0 for i in range(10)]
+ for batch in dataiter:
+ if len(data) == 2:
+ assert len(batch.data) == 2
+ if labels is not None and len(labels) != 0:
+ if not is_image:
+ label = batch.label[0].asnumpy().flatten()
+ # check that the data matches the corresponding labels
+ assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+ for i in range(label.shape[0]):
+ labelcount[int(label[i])] += 1
+ else:
+ assert not batch.label, 'label is not empty list'
+ # keep the last batch of 'pad' to be used later
+ # to test first batch of roll_over in second iteration
+ batch_count += 1
+ if last_batch_handle_list[idx] == 'pad' and \
+ batch_count == batch_count_list[0]:
+ cache = batch.data[0].asnumpy()
+ # check if batchifying functionality work properly
+ if labels is not None and len(labels) != 0 and not is_image:
+ assert labelcount[0] == labelcount_list[idx][0], last_batch_handle_list[idx]
+ assert labelcount[8] == labelcount_list[idx][1], last_batch_handle_list[idx]
+ assert batch_count == batch_count_list[idx]
+ # roll_over option
+ dataiter.reset()
+ assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
+
+
+def _test_shuffle(data, labels=None):
+ dataiter = mx.io.NDArrayIter(data, labels, 1, False)
+ batch_list = []
for batch in dataiter:
- batchidx += 1
- assert(batchidx == 8)
- dataiter = mx.io.NDArrayIter(
- data, label, 128, False, last_batch_handle='pad')
- batchidx = 0
- labelcount = [0 for i in range(10)]
+ # cache the original data
+ batch_list.append(batch.data[0].asnumpy())
+ dataiter = mx.io.NDArrayIter(data, labels, 1, True)
+ idx_list = dataiter.idx
+ i = 0
for batch in dataiter:
- label = batch.label[0].asnumpy().flatten()
- assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
- for i in range(label.shape[0]):
- labelcount[int(label[i])] += 1
+ # check that each data point has been shuffled to its corresponding position
+ assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
+ i += 1
- for i in range(10):
- if i == 0:
- assert(labelcount[i] == 124)
- else:
- assert(labelcount[i] == 100)
+
+def test_NDArrayIter():
+ dtype_list = ['NDArray', 'ndarray']
+ tested_data_type = [False, True]
+ for dtype in dtype_list:
+ for is_image in tested_data_type:
+ data, labels = _init_NDArrayIter_data(dtype, is_image)
+ _test_last_batch_handle(data, labels, is_image)
+ _test_last_batch_handle([data, data], labels, is_image)
+ _test_last_batch_handle(data=[data, data], is_image=is_image)
+ _test_last_batch_handle(
+ {'data1': data, 'data2': data}, labels, is_image)
+ _test_last_batch_handle(data={'data1': data, 'data2': data}, is_image=is_image)
+ _test_last_batch_handle(data, [], is_image)
+ _test_last_batch_handle(data=data, is_image=is_image)
+ _test_shuffle(data, labels)
+ _test_shuffle([data, data], labels)
+ _test_shuffle([data, data])
+ _test_shuffle({'data1': data, 'data2': data}, labels)
+ _test_shuffle({'data1': data, 'data2': data})
+ _test_shuffle(data, [])
+ _test_shuffle(data)
def test_NDArrayIter_h5py():
if not h5py:
return
- data = np.ones([1000, 2, 2])
- label = np.ones([1000, 1])
- for i in range(1000):
- data[i] = i / 100
- label[i] = i / 100
+ data, labels = _init_NDArrayIter_data('ndarray')
try:
- os.remove("ndarraytest.h5")
+ os.remove('ndarraytest.h5')
except OSError:
pass
- with h5py.File("ndarraytest.h5") as f:
- f.create_dataset("data", data=data)
- f.create_dataset("label", data=label)
-
- dataiter = mx.io.NDArrayIter(
- f["data"], f["label"], 128, True, last_batch_handle='pad')
- batchidx = 0
- for batch in dataiter:
- batchidx += 1
- assert(batchidx == 8)
-
- dataiter = mx.io.NDArrayIter(
- f["data"], f["label"], 128, False, last_batch_handle='pad')
- labelcount = [0 for i in range(10)]
- for batch in dataiter:
- label = batch.label[0].asnumpy().flatten()
- assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
- for i in range(label.shape[0]):
- labelcount[int(label[i])] += 1
-
+ with h5py.File('ndarraytest.h5') as f:
+ f.create_dataset('data', data=data)
+ f.create_dataset('label', data=labels)
+
+ _test_last_batch_handle(f['data'], f['label'])
+ _test_last_batch_handle(f['data'], [])
+ _test_last_batch_handle(f['data'])
try:
os.remove("ndarraytest.h5")
except OSError:
pass
- for i in range(10):
- if i == 0:
- assert(labelcount[i] == 124)
- else:
- assert(labelcount[i] == 100)
+
+def _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list, csr_iter_None, num_rows, batch_size):
+ num_batch = 0
+ for _, batch_empty_list, batch_empty_None in zip(csr_iter, csr_iter_empty_list, csr_iter_None):
+ assert not batch_empty_list.label, 'label is not empty list'
+ assert not batch_empty_None.label, 'label is not empty list'
+ num_batch += 1
+
+ assert(num_batch == num_rows // batch_size)
+ assertRaises(StopIteration, csr_iter.next)
+ assertRaises(StopIteration, csr_iter_empty_list.next)
+ assertRaises(StopIteration, csr_iter_None.next)
def test_NDArrayIter_csr():
@@ -200,15 +261,26 @@ def test_NDArrayIter_csr():
{'data': train_data}, dns, batch_size)
except ImportError:
pass
+
+ # scipy.sparse.csr_matrix with shuffle
+ csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
+ shuffle=True, last_batch_handle='discard'))
+ csr_iter_empty_list = iter(mx.io.NDArrayIter({'data': train_data}, [], batch_size,
+ shuffle=True, last_batch_handle='discard'))
+ csr_iter_None = iter(mx.io.NDArrayIter({'data': train_data}, None, batch_size,
+ shuffle=True, last_batch_handle='discard'))
+ _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+ csr_iter_None, num_rows, batch_size)
# CSRNDArray with shuffle
csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns, batch_size,
shuffle=True, last_batch_handle='discard'))
- num_batch = 0
- for batch in csr_iter:
- num_batch += 1
-
- assert(num_batch == num_rows // batch_size)
+ csr_iter_empty_list = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, [], batch_size,
+ shuffle=True, last_batch_handle='discard'))
+ csr_iter_None = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, None, batch_size,
+ shuffle=True, last_batch_handle='discard'))
+ _test_NDArrayIter_csr(csr_iter, csr_iter_empty_list,
+ csr_iter_None, num_rows, batch_size)
# make iterators
csr_iter = iter(mx.io.NDArrayIter(