This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 8ff50c9 Revert "Change the way NDArrayIter handle the last batch"
(#12537)
8ff50c9 is described below
commit 8ff50c95201e02e849e0592de5fb7af87489be53
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Wed Sep 12 13:32:57 2018 -0700
Revert "Change the way NDArrayIter handle the last batch" (#12537)
* Revert "Removing the re-size for validation data, which breaking the
validation accuracy of CIFAR training (#12362)"
This reverts commit ceabcaac77543d99246415b2fb2d8c973a830453.
* Revert "[MXNET-580] Add SN-GAN example (#12419)"
This reverts commit 46a5cee2515a1ac0a1ae5afbe7e639debb998587.
* Revert "Remove regression checks for website links (#12507)"
This reverts commit 619bc3ea3c9093b72634d16e91596b3a65f3f1fc.
* Revert "Revert "Fix flaky test: test_mkldnn.test_activation #12377
(#12418)" (#12516)"
This reverts commit 7ea05333efc8ca868443b89233b101d068f6af9f.
* Revert "further bump up tolerance for sparse dot (#12527)"
This reverts commit 90599e1038a4ff6604e9ed0d55dc274c2df635f8.
* Revert "Fix broken URLs (#12508)"
This reverts commit 3d83c896fd8b237c53003888e35a4d792c1e5389.
* Revert "Temporarily disable flaky tests (#12520)"
This reverts commit 35ca13c3b5a0e57d904d1fead079152a15dfeac4.
* Revert "Add support for more req patterns for bilinear sampler backward
(#12386)"
This reverts commit 4ee866fc75307b284cc0eae93d0cf4dad3b62533.
* Revert "Change the way NDArrayIter handle the last batch (#12285)"
This reverts commit 597a637fb1b8fa5b16331218cda8be61ce0ee202.
---
CONTRIBUTORS.md | 1 -
python/mxnet/{io => }/io.py | 280 +++++++++++++++++++--------------------
python/mxnet/io/__init__.py | 29 ----
python/mxnet/io/utils.py | 86 ------------
tests/python/unittest/test_io.py | 122 ++++++++---------
5 files changed, 190 insertions(+), 328 deletions(-)
diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index 1c005d5..8d8aeac 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -178,4 +178,3 @@ List of Contributors
* [Aaron Markham](https://github.com/aaronmarkham)
* [Sam Skalicky](https://github.com/samskalicky)
* [Per Goncalves da Silva](https://github.com/perdasilva)
-* [Cheng-Che Lee](https://github.com/stu1130)
diff --git a/python/mxnet/io/io.py b/python/mxnet/io.py
similarity index 82%
rename from python/mxnet/io/io.py
rename to python/mxnet/io.py
index 2ae3e70..884e929 100644
--- a/python/mxnet/io/io.py
+++ b/python/mxnet/io.py
@@ -17,26 +17,30 @@
"""Data iterators for common data formats."""
from __future__ import absolute_import
-from collections import namedtuple
+from collections import OrderedDict, namedtuple
import sys
import ctypes
import logging
import threading
+try:
+ import h5py
+except ImportError:
+ h5py = None
import numpy as np
-
-from ..base import _LIB
-from ..base import c_str_array, mx_uint, py_str
-from ..base import DataIterHandle, NDArrayHandle
-from ..base import mx_real_t
-from ..base import check_call, build_param_doc as _build_param_doc
-from ..ndarray import NDArray
-from ..ndarray.sparse import CSRNDArray
-from ..ndarray import _ndarray_cls
-from ..ndarray import array
-from ..ndarray import concat
-
-from .utils import init_data, has_instance, getdata_by_idx
+from .base import _LIB
+from .base import c_str_array, mx_uint, py_str
+from .base import DataIterHandle, NDArrayHandle
+from .base import mx_real_t
+from .base import check_call, build_param_doc as _build_param_doc
+from .ndarray import NDArray
+from .ndarray.sparse import CSRNDArray
+from .ndarray.sparse import array as sparse_array
+from .ndarray import _ndarray_cls
+from .ndarray import array
+from .ndarray import concatenate
+from .ndarray import arange
+from .ndarray.random import shuffle as random_shuffle
class DataDesc(namedtuple('DataDesc', ['name', 'shape'])):
"""DataDesc is used to store name, shape, type and layout
@@ -485,6 +489,59 @@ class PrefetchingIter(DataIter):
def getpad(self):
return self.current_batch.pad
+def _init_data(data, allow_empty, default_name):
+ """Convert data into canonical form."""
+ assert (data is not None) or allow_empty
+ if data is None:
+ data = []
+
+ if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
+ if h5py else (np.ndarray, NDArray)):
+ data = [data]
+ if isinstance(data, list):
+ if not allow_empty:
+ assert(len(data) > 0)
+ if len(data) == 1:
+ data = OrderedDict([(default_name, data[0])]) # pylint:
disable=redefined-variable-type
+ else:
+ data = OrderedDict( # pylint: disable=redefined-variable-type
+ [('_%d_%s' % (i, default_name), d) for i, d in
enumerate(data)])
+ if not isinstance(data, dict):
+ raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset "
+ \
+ "a list of them or dict with them as values")
+ for k, v in data.items():
+ if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
+ try:
+ data[k] = array(v)
+ except:
+ raise TypeError(("Invalid type '%s' for %s, " % (type(v), k))
+ \
+ "should be NDArray, numpy.ndarray or
h5py.Dataset")
+
+ return list(sorted(data.items()))
+
+def _has_instance(data, dtype):
+ """Return True if ``data`` has instance of ``dtype``.
+ This function is called after _init_data.
+ ``data`` is a list of (str, NDArray)"""
+ for item in data:
+ _, arr = item
+ if isinstance(arr, dtype):
+ return True
+ return False
+
+def _shuffle(data, idx):
+ """Shuffle the data."""
+ shuffle_data = []
+
+ for k, v in data:
+ if (isinstance(v, h5py.Dataset) if h5py else False):
+ shuffle_data.append((k, v))
+ elif isinstance(v, CSRNDArray):
+ shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
+ else:
+ shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
+
+ return shuffle_data
class NDArrayIter(DataIter):
"""Returns an iterator for ``mx.nd.NDArray``, ``numpy.ndarray``,
``h5py.Dataset``
@@ -544,22 +601,6 @@ class NDArrayIter(DataIter):
...
>>> batchidx # Remaining examples are discarded. So, 10/3 batches are
created.
3
- >>> dataiter = mx.io.NDArrayIter(data, labels, 3, False,
last_batch_handle='roll_over')
- >>> batchidx = 0
- >>> for batch in dataiter:
- ... batchidx += 1
- ...
- >>> batchidx # Remaining examples are rolled over to the next iteration.
- 3
- >>> dataiter.reset()
- >>> dataiter.next().data[0].asnumpy()
- [[[ 36. 37.]
- [ 38. 39.]]
- [[ 0. 1.]
- [ 2. 3.]]
- [[ 4. 5.]
- [ 6. 7.]]]
- (3L, 2L, 2L)
`NDArrayIter` also supports multiple input and labels.
@@ -592,11 +633,8 @@ class NDArrayIter(DataIter):
Only supported if no h5py.Dataset inputs are used.
last_batch_handle : str, optional
How to handle the last batch. This parameter can be 'pad', 'discard' or
- 'roll_over'.
- If 'pad', the last batch will be padded with data starting from the
begining
- If 'discard', the last batch will be discarded
- If 'roll_over', the remaining elements will be rolled over to the next
iteration and
- note that it is intended for training and can cause problems if used
for prediction.
+ 'roll_over'. 'roll_over' is intended for training and can cause
problems
+ if used for prediction.
data_name : str, optional
The data name.
label_name : str, optional
@@ -607,28 +645,36 @@ class NDArrayIter(DataIter):
label_name='softmax_label'):
super(NDArrayIter, self).__init__(batch_size)
- self.data = init_data(data, allow_empty=False, default_name=data_name)
- self.label = init_data(label, allow_empty=True,
default_name=label_name)
+ self.data = _init_data(data, allow_empty=False, default_name=data_name)
+ self.label = _init_data(label, allow_empty=True,
default_name=label_name)
- if ((has_instance(self.data, CSRNDArray) or has_instance(self.label,
CSRNDArray)) and
+ if ((_has_instance(self.data, CSRNDArray) or _has_instance(self.label,
CSRNDArray)) and
(last_batch_handle != 'discard')):
raise NotImplementedError("`NDArrayIter` only supports
``CSRNDArray``" \
" with `last_batch_handle` set to
`discard`.")
- self.idx = np.arange(self.data[0][1].shape[0])
- self.shuffle = shuffle
- self.last_batch_handle = last_batch_handle
- self.batch_size = batch_size
- self.cursor = -self.batch_size
- self.num_data = self.idx.shape[0]
- # shuffle
- self.reset()
+ # shuffle data
+ if shuffle:
+ tmp_idx = arange(self.data[0][1].shape[0], dtype=np.int32)
+ self.idx = random_shuffle(tmp_idx, out=tmp_idx).asnumpy()
+ self.data = _shuffle(self.data, self.idx)
+ self.label = _shuffle(self.label, self.idx)
+ else:
+ self.idx = np.arange(self.data[0][1].shape[0])
+
+ # batching
+ if last_batch_handle == 'discard':
+ new_n = self.data[0][1].shape[0] - self.data[0][1].shape[0] %
batch_size
+ self.idx = self.idx[:new_n]
self.data_list = [x[1] for x in self.data] + [x[1] for x in self.label]
self.num_source = len(self.data_list)
- # used for 'roll_over'
- self._cache_data = None
- self._cache_label = None
+ self.num_data = self.idx.shape[0]
+ assert self.num_data >= batch_size, \
+ "batch_size needs to be smaller than data size."
+ self.cursor = -batch_size
+ self.batch_size = batch_size
+ self.last_batch_handle = last_batch_handle
@property
def provide_data(self):
@@ -648,126 +694,74 @@ class NDArrayIter(DataIter):
def hard_reset(self):
"""Ignore roll over data and set to start."""
- if self.shuffle:
- self._shuffle_data()
self.cursor = -self.batch_size
- self._cache_data = None
- self._cache_label = None
def reset(self):
- """Resets the iterator to the beginning of the data."""
- if self.shuffle:
- self._shuffle_data()
- # the range below indicate the last batch
- if self.last_batch_handle == 'roll_over' and \
- self.num_data - self.batch_size < self.cursor < self.num_data:
- # (self.cursor - self.num_data) represents the data we have for
the last batch
- self.cursor = self.cursor - self.num_data - self.batch_size
+ if self.last_batch_handle == 'roll_over' and self.cursor >
self.num_data:
+ self.cursor = -self.batch_size +
(self.cursor%self.num_data)%self.batch_size
else:
self.cursor = -self.batch_size
def iter_next(self):
- """Increments the coursor by batch_size for next batch
- and check current cursor if it exceed the number of data points."""
self.cursor += self.batch_size
return self.cursor < self.num_data
def next(self):
- """Returns the next batch of data."""
- if not self.iter_next():
- raise StopIteration
- data = self.getdata()
- label = self.getlabel()
- # iter should stop when last batch is not complete
- if data[0].shape[0] != self.batch_size:
- # in this case, cache it for next epoch
- self._cache_data = data
- self._cache_label = label
+ if self.iter_next():
+ return DataBatch(data=self.getdata(), label=self.getlabel(), \
+ pad=self.getpad(), index=None)
+ else:
raise StopIteration
- return DataBatch(data=data, label=label, \
- pad=self.getpad(), index=None)
-
- def _getdata(self, data_source, start=None, end=None):
- """Load data from underlying arrays."""
- assert start is not None or end is not None, 'should at least specify
start or end'
- start = start if start is not None else 0
- end = end if end is not None else data_source[0][1].shape[0]
- s = slice(start, end)
- return [
- x[1][s]
- if isinstance(x[1], (np.ndarray, NDArray)) else
- # h5py (only supports indices in increasing order)
- array(x[1][sorted(self.idx[s])][[
- list(self.idx[s]).index(i)
- for i in sorted(self.idx[s])
- ]]) for x in data_source
- ]
- def _concat(self, first_data, second_data):
- """Helper function to concat two NDArrays."""
- return [
- concat(first_data[0], second_data[0], dim=0)
- ]
-
- def _batchify(self, data_source):
+ def _getdata(self, data_source):
"""Load data from underlying arrays, internal use only."""
- assert self.cursor < self.num_data, 'DataIter needs reset.'
- # first batch of next epoch with 'roll_over'
- if self.last_batch_handle == 'roll_over' and \
- -self.batch_size < self.cursor < 0:
- assert self._cache_data is not None or self._cache_label is not
None, \
- 'next epoch should have cached data'
- cache_data = self._cache_data if self._cache_data is not None else
self._cache_label
- second_data = self._getdata(
- data_source, end=self.cursor + self.batch_size)
- if self._cache_data is not None:
- self._cache_data = None
- else:
- self._cache_label = None
- return self._concat(cache_data, second_data)
- # last batch with 'pad'
- elif self.last_batch_handle == 'pad' and \
- self.cursor + self.batch_size > self.num_data:
- pad = self.batch_size - self.num_data + self.cursor
- first_data = self._getdata(data_source, start=self.cursor)
- second_data = self._getdata(data_source, end=pad)
- return self._concat(first_data, second_data)
- # normal case
+ assert(self.cursor < self.num_data), "DataIter needs reset."
+ if self.cursor + self.batch_size <= self.num_data:
+ return [
+ # np.ndarray or NDArray case
+ x[1][self.cursor:self.cursor + self.batch_size]
+ if isinstance(x[1], (np.ndarray, NDArray)) else
+ # h5py (only supports indices in increasing order)
+ array(x[1][sorted(self.idx[
+ self.cursor:self.cursor + self.batch_size])][[
+ list(self.idx[self.cursor:
+ self.cursor + self.batch_size]).index(i)
+ for i in sorted(self.idx[
+ self.cursor:self.cursor + self.batch_size])
+ ]]) for x in data_source
+ ]
else:
- if self.cursor + self.batch_size < self.num_data:
- end_idx = self.cursor + self.batch_size
- # get incomplete last batch
- else:
- end_idx = self.num_data
- return self._getdata(data_source, self.cursor, end_idx)
+ pad = self.batch_size - self.num_data + self.cursor
+ return [
+ # np.ndarray or NDArray case
+ concatenate([x[1][self.cursor:], x[1][:pad]])
+ if isinstance(x[1], (np.ndarray, NDArray)) else
+ # h5py (only supports indices in increasing order)
+ concatenate([
+ array(x[1][sorted(self.idx[self.cursor:])][[
+ list(self.idx[self.cursor:]).index(i)
+ for i in sorted(self.idx[self.cursor:])
+ ]]),
+ array(x[1][sorted(self.idx[:pad])][[
+ list(self.idx[:pad]).index(i)
+ for i in sorted(self.idx[:pad])
+ ]])
+ ]) for x in data_source
+ ]
def getdata(self):
- """Get data."""
- return self._batchify(self.data)
+ return self._getdata(self.data)
def getlabel(self):
- """Get label."""
- return self._batchify(self.label)
+ return self._getdata(self.label)
def getpad(self):
- """Get pad value of DataBatch."""
if self.last_batch_handle == 'pad' and \
self.cursor + self.batch_size > self.num_data:
return self.cursor + self.batch_size - self.num_data
- # check the first batch
- elif self.last_batch_handle == 'roll_over' and \
- -self.batch_size < self.cursor < 0:
- return -self.cursor
else:
return 0
- def _shuffle_data(self):
- """Shuffle the data."""
- # shuffle index
- np.random.shuffle(self.idx)
- # get the data by corresponding index
- self.data = getdata_by_idx(self.data, self.idx)
- self.label = getdata_by_idx(self.label, self.idx)
class MXDataIter(DataIter):
"""A python wrapper a C++ data iterator.
@@ -779,7 +773,7 @@ class MXDataIter(DataIter):
underlying C++ data iterators.
Usually you don't need to interact with `MXDataIter` directly unless you
are
- implementing your own data iterators in C+ +. To do that, please refer to
+ implementing your own data iterators in C++. To do that, please refer to
examples under the `src/io` folder.
Parameters
diff --git a/python/mxnet/io/__init__.py b/python/mxnet/io/__init__.py
deleted file mode 100644
index 5c5e2e6..0000000
--- a/python/mxnet/io/__init__.py
+++ /dev/null
@@ -1,29 +0,0 @@
-#!/usr/bin/env python
-
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-# coding: utf-8
-# pylint: disable=wildcard-import
-""" Data iterators for common data formats and utility functions."""
-from __future__ import absolute_import
-
-from . import io
-from .io import *
-
-from . import utils
-from .utils import *
diff --git a/python/mxnet/io/utils.py b/python/mxnet/io/utils.py
deleted file mode 100644
index 872e641..0000000
--- a/python/mxnet/io/utils.py
+++ /dev/null
@@ -1,86 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""utility functions for io.py"""
-from collections import OrderedDict
-
-import numpy as np
-try:
- import h5py
-except ImportError:
- h5py = None
-
-from ..ndarray.sparse import CSRNDArray
-from ..ndarray.sparse import array as sparse_array
-from ..ndarray import NDArray
-from ..ndarray import array
-
-def init_data(data, allow_empty, default_name):
- """Convert data into canonical form."""
- assert (data is not None) or allow_empty
- if data is None:
- data = []
-
- if isinstance(data, (np.ndarray, NDArray, h5py.Dataset)
- if h5py else (np.ndarray, NDArray)):
- data = [data]
- if isinstance(data, list):
- if not allow_empty:
- assert(len(data) > 0)
- if len(data) == 1:
- data = OrderedDict([(default_name, data[0])]) # pylint:
disable=redefined-variable-type
- else:
- data = OrderedDict( # pylint: disable=redefined-variable-type
- [('_%d_%s' % (i, default_name), d) for i, d in
enumerate(data)])
- if not isinstance(data, dict):
- raise TypeError("Input must be NDArray, numpy.ndarray, h5py.Dataset " +
- "a list of them or dict with them as values")
- for k, v in data.items():
- if not isinstance(v, (NDArray, h5py.Dataset) if h5py else NDArray):
- try:
- data[k] = array(v)
- except:
- raise TypeError(("Invalid type '%s' for %s, " % (type(v), k)) +
- "should be NDArray, numpy.ndarray or
h5py.Dataset")
-
- return list(sorted(data.items()))
-
-
-def has_instance(data, dtype):
- """Return True if ``data`` has instance of ``dtype``.
- This function is called after _init_data.
- ``data`` is a list of (str, NDArray)"""
- for item in data:
- _, arr = item
- if isinstance(arr, dtype):
- return True
- return False
-
-
-def getdata_by_idx(data, idx):
- """Shuffle the data."""
- shuffle_data = []
-
- for k, v in data:
- if (isinstance(v, h5py.Dataset) if h5py else False):
- shuffle_data.append((k, v))
- elif isinstance(v, CSRNDArray):
- shuffle_data.append((k, sparse_array(v.asscipy()[idx], v.context)))
- else:
- shuffle_data.append((k, array(v.asnumpy()[idx], v.context)))
-
- return shuffle_data
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index ae68626..4dfa69c 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -88,88 +88,80 @@ def test_Cifar10Rec():
assert(labelcount[i] == 5000)
-def _init_NDArrayIter_data():
+def test_NDArrayIter():
data = np.ones([1000, 2, 2])
- labels = np.ones([1000, 1])
+ label = np.ones([1000, 1])
for i in range(1000):
data[i] = i / 100
- labels[i] = i / 100
- return data, labels
-
-
-def _test_last_batch_handle(data, labels):
- # Test the three parameters 'pad', 'discard', 'roll_over'
- last_batch_handle_list = ['pad', 'discard' , 'roll_over']
- labelcount_list = [(124, 100), (100, 96), (100, 96)]
- batch_count_list = [8, 7, 7]
-
- for idx in range(len(last_batch_handle_list)):
- dataiter = mx.io.NDArrayIter(
- data, labels, 128, False,
last_batch_handle=last_batch_handle_list[idx])
- batch_count = 0
- labelcount = [0 for i in range(10)]
- for batch in dataiter:
- label = batch.label[0].asnumpy().flatten()
- # check data if it matches corresponding labels
- assert((batch.data[0].asnumpy()[:, 0, 0] == label).all()),
last_batch_handle_list[idx]
- for i in range(label.shape[0]):
- labelcount[int(label[i])] += 1
- # keep the last batch of 'pad' to be used later
- # to test first batch of roll_over in second iteration
- batch_count += 1
- if last_batch_handle_list[idx] == 'pad' and \
- batch_count == 8:
- cache = batch.data[0].asnumpy()
- # check if batchifying functionality work properly
- assert labelcount[0] == labelcount_list[idx][0],
last_batch_handle_list[idx]
- assert labelcount[8] == labelcount_list[idx][1],
last_batch_handle_list[idx]
- assert batch_count == batch_count_list[idx]
- # roll_over option
- dataiter.reset()
- assert np.array_equal(dataiter.next().data[0].asnumpy(), cache)
-
-
-def _test_shuffle(data, labels):
- dataiter = mx.io.NDArrayIter(data, labels, 1, False)
- batch_list = []
+ label[i] = i / 100
+ dataiter = mx.io.NDArrayIter(
+ data, label, 128, True, last_batch_handle='pad')
+ batchidx = 0
for batch in dataiter:
- # cache the original data
- batch_list.append(batch.data[0].asnumpy())
- dataiter = mx.io.NDArrayIter(data, labels, 1, True)
- idx_list = dataiter.idx
- i = 0
+ batchidx += 1
+ assert(batchidx == 8)
+ dataiter = mx.io.NDArrayIter(
+ data, label, 128, False, last_batch_handle='pad')
+ batchidx = 0
+ labelcount = [0 for i in range(10)]
for batch in dataiter:
- # check if each data point have been shuffled to corresponding
positions
- assert np.array_equal(batch.data[0].asnumpy(), batch_list[idx_list[i]])
- i += 1
-
+ label = batch.label[0].asnumpy().flatten()
+ assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+ for i in range(label.shape[0]):
+ labelcount[int(label[i])] += 1
-def test_NDArrayIter():
- data, labels = _init_NDArrayIter_data()
- _test_last_batch_handle(data, labels)
- _test_shuffle(data, labels)
+ for i in range(10):
+ if i == 0:
+ assert(labelcount[i] == 124)
+ else:
+ assert(labelcount[i] == 100)
def test_NDArrayIter_h5py():
if not h5py:
return
- data, labels = _init_NDArrayIter_data()
+ data = np.ones([1000, 2, 2])
+ label = np.ones([1000, 1])
+ for i in range(1000):
+ data[i] = i / 100
+ label[i] = i / 100
try:
- os.remove('ndarraytest.h5')
+ os.remove("ndarraytest.h5")
except OSError:
pass
- with h5py.File('ndarraytest.h5') as f:
- f.create_dataset('data', data=data)
- f.create_dataset('label', data=labels)
+ with h5py.File("ndarraytest.h5") as f:
+ f.create_dataset("data", data=data)
+ f.create_dataset("label", data=label)
+
+ dataiter = mx.io.NDArrayIter(
+ f["data"], f["label"], 128, True, last_batch_handle='pad')
+ batchidx = 0
+ for batch in dataiter:
+ batchidx += 1
+ assert(batchidx == 8)
+
+ dataiter = mx.io.NDArrayIter(
+ f["data"], f["label"], 128, False, last_batch_handle='pad')
+ labelcount = [0 for i in range(10)]
+ for batch in dataiter:
+ label = batch.label[0].asnumpy().flatten()
+ assert((batch.data[0].asnumpy()[:, 0, 0] == label).all())
+ for i in range(label.shape[0]):
+ labelcount[int(label[i])] += 1
- _test_last_batch_handle(f['data'], f['label'])
try:
os.remove("ndarraytest.h5")
except OSError:
pass
+ for i in range(10):
+ if i == 0:
+ assert(labelcount[i] == 124)
+ else:
+ assert(labelcount[i] == 100)
+
def test_NDArrayIter_csr():
# creating toy data
@@ -190,20 +182,12 @@ def test_NDArrayIter_csr():
{'data': train_data}, dns, batch_size)
except ImportError:
pass
- # scipy.sparse.csr_matrix with shuffle
- num_batch = 0
- csr_iter = iter(mx.io.NDArrayIter({'data': train_data}, dns, batch_size,
- shuffle=True,
last_batch_handle='discard'))
- for _ in csr_iter:
- num_batch += 1
-
- assert(num_batch == num_rows // batch_size)
# CSRNDArray with shuffle
csr_iter = iter(mx.io.NDArrayIter({'csr_data': csr, 'dns_data': dns}, dns,
batch_size,
shuffle=True,
last_batch_handle='discard'))
num_batch = 0
- for _ in csr_iter:
+ for batch in csr_iter:
num_batch += 1
assert(num_batch == num_rows // batch_size)