This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch v1.4.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.4.x by this push:
new b7f0122 fix the situation where idx didn't align with rec (#13550)
b7f0122 is described below
commit b7f0122c8dd741ca71803ff81bd3ebd7a4e1070e
Author: Jake Lee <[email protected]>
AuthorDate: Fri Dec 7 17:06:05 2018 -0800
fix the situation where idx didn't align with rec (#13550)
minor fix to image.py
add last_batch_handle to ImageDetIter
remove the label type
refactor the ImageIter unit test
fix the trailing whitespace
fix coding style
add new line
move helper function to the top of the file
---
python/mxnet/image/detection.py | 64 +++++++++++--
python/mxnet/image/image.py | 5 +-
tests/python/unittest/test_image.py | 184 ++++++++++++++++++++----------------
3 files changed, 157 insertions(+), 96 deletions(-)
diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index b27917c..d5b5eca 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
Data name for provided symbols.
label_name : str
Name for detection labels
+ last_batch_handle : str, optional
+ How to handle the last batch.
+ This parameter can be 'pad'(default), 'discard' or 'roll_over'.
+ If 'pad', the last batch will be padded with data starting from the
begining
+ If 'discard', the last batch will be discarded
+ If 'roll_over', the remaining elements will be rolled over to the next
iteration
kwargs : ...
More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
"""
def __init__(self, batch_size, data_shape,
path_imgrec=None, path_imglist=None, path_root=None,
path_imgidx=None,
shuffle=False, part_index=0, num_parts=1, aug_list=None,
imglist=None,
- data_name='data', label_name='label', **kwargs):
+ data_name='data', label_name='label',
last_batch_handle='pad', **kwargs):
super(ImageDetIter, self).__init__(batch_size=batch_size,
data_shape=data_shape,
path_imgrec=path_imgrec,
path_imglist=path_imglist,
path_root=path_root,
path_imgidx=path_imgidx,
shuffle=shuffle,
part_index=part_index,
num_parts=num_parts, aug_list=[],
imglist=imglist,
- data_name=data_name,
label_name=label_name)
+ data_name=data_name,
label_name=label_name,
+ last_batch_handle=last_batch_handle)
if aug_list is None:
self.auglist = CreateDetAugmenter(data_shape, **kwargs)
@@ -751,14 +758,10 @@ class ImageDetIter(ImageIter):
self.provide_label = [(self.provide_label[0][0],
(self.batch_size,) + label_shape)]
self.label_shape = label_shape
- def next(self):
- """Override the function for returning next batch."""
+ def _batchify(self, batch_data, batch_label, start=0):
+ """Override the helper function for batchifying data"""
+ i = start
batch_size = self.batch_size
- c, h, w = self.data_shape
- batch_data = nd.zeros((batch_size, c, h, w))
- batch_label = nd.empty(self.provide_label[0][1])
- batch_label[:] = -1
- i = 0
try:
while i < batch_size:
label, s = self.next_sample()
@@ -783,7 +786,48 @@ class ImageDetIter(ImageIter):
if not i:
raise StopIteration
- return io.DataBatch([batch_data], [batch_label], batch_size - i)
+ return i
+
+ def next(self):
+ """Override the function for returning next batch."""
+ batch_size = self.batch_size
+ c, h, w = self.data_shape
+ # if last batch data is rolled over
+ if self._cache_data is not None:
+ # check both the data and label have values
+ assert self._cache_label is not None, "_cache_label didn't have
values"
+ assert self._cache_idx is not None, "_cache_idx didn't have values"
+ batch_data = self._cache_data
+ batch_label = self._cache_label
+ i = self._cache_idx
+ else:
+ batch_data = nd.zeros((batch_size, c, h, w))
+ batch_label = nd.empty(self.provide_label[0][1])
+ batch_label[:] = -1
+ i = self._batchify(batch_data, batch_label)
+ # calculate the padding
+ pad = batch_size - i
+ # handle padding for the last batch
+ if pad != 0:
+ if self.last_batch_handle == 'discard':
+ raise StopIteration
+ # if the option is 'roll_over', throw StopIteration and cache the
data
+ elif self.last_batch_handle == 'roll_over' and \
+ self._cache_data is None:
+ self._cache_data = batch_data
+ self._cache_label = batch_label
+ self._cache_idx = i
+ raise StopIteration
+ else:
+ _ = self._batchify(batch_data, batch_label, i)
+ if self.last_batch_handle == 'pad':
+ self._allow_read = False
+ else:
+ self._cache_data = None
+ self._cache_label = None
+ self._cache_idx = None
+
+ return io.DataBatch([batch_data], [batch_label], pad=pad)
def augmentation_transform(self, data, label): # pylint:
disable=arguments-differ
"""Override Transforms input data with specified augmentations."""
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c9a457f..9c2a1cb 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1145,7 +1145,7 @@ class ImageIter(io.DataIter):
self.shuffle = shuffle
if self.imgrec is None:
self.seq = imgkeys
- elif shuffle or num_parts > 1:
+ elif shuffle or num_parts > 1 or path_imgidx:
assert self.imgidx is not None
self.seq = self.imgidx
else:
@@ -1261,7 +1261,7 @@ class ImageIter(io.DataIter):
i = self._cache_idx
# clear the cache data
else:
- batch_data = nd.empty((batch_size, c, h, w))
+ batch_data = nd.zeros((batch_size, c, h, w))
batch_label = nd.empty(self.provide_label[0][1])
i = self._batchify(batch_data, batch_label)
# calculate the padding
@@ -1285,6 +1285,7 @@ class ImageIter(io.DataIter):
self._cache_data = None
self._cache_label = None
self._cache_idx = None
+
return io.DataBatch([batch_data], [batch_label], pad=pad)
def check_data_shape(self, data_shape):
diff --git a/tests/python/unittest/test_image.py
b/tests/python/unittest/test_image.py
index 4f66823..4063027 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -25,6 +25,7 @@ import unittest
from nose.tools import raises
+
def _get_data(url, dirname):
import os, tarfile
download(url, dirname=dirname, overwrite=False)
@@ -50,6 +51,62 @@ def _generate_objects():
label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
return [2, 5] + label
+def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
+ test_iter = imageiter_list[0]
+ # test batch data shape
+ for _ in range(3):
+ for batch in test_iter:
+ assert batch.data[0].shape == assert_data_shape
+ test_iter.reset()
+ # test last batch handle(discard)
+ test_iter = imageiter_list[1]
+ i = 0
+ for batch in test_iter:
+ i += 1
+ assert i == 5
+ # test last_batch_handle(pad)
+ test_iter = imageiter_list[2]
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_three_data = batch.data[0][:2]
+ if i == 5:
+ last_three_data = batch.data[0][1:]
+ i += 1
+ assert i == 6
+ assert np.array_equal(first_three_data.asnumpy(),
last_three_data.asnumpy())
+ # test last_batch_handle(roll_over)
+ test_iter = imageiter_list[3]
+ i = 0
+ for batch in test_iter:
+ if i == 0:
+ first_image = batch.data[0][0]
+ i += 1
+ assert i == 5
+ test_iter.reset()
+ first_batch_roll_over = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+ assert first_batch_roll_over.pad == 2
+ # test iteratopr work properly after calling reset several times when
last_batch_handle is roll_over
+ for _ in test_iter:
+ pass
+ test_iter.reset()
+ first_batch_roll_over_twice = test_iter.next()
+ assert np.array_equal(
+ first_batch_roll_over_twice.data[0][2].asnumpy(),
first_image.asnumpy())
+ assert first_batch_roll_over_twice.pad == 1
+ # we've called next once
+ i = 1
+ for _ in test_iter:
+ i += 1
+ # test the third epoch with size 6
+ assert i == 6
+ # test shuffle option for sanity test
+ test_iter = imageiter_list[4]
+ for _ in test_iter:
+ pass
+
class TestImage(unittest.TestCase):
IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz"
@@ -151,86 +208,32 @@ class TestImage(unittest.TestCase):
assert_almost_equal(mx_result.asnumpy(), (src - mean) / std,
atol=1e-3)
def test_imageiter(self):
- def check_imageiter(dtype='float32'):
- im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
- fname = './data/test_imageiter.lst'
- file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
- for k, x in enumerate(TestImage.IMAGES)]
- with open(fname, 'w') as f:
- for line in file_list:
- f.write(line + '\n')
-
- test_list = ['imglist', 'path_imglist']
+ im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
+ fname = './data/test_imageiter.lst'
+ file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+ for k, x in enumerate(TestImage.IMAGES)]
+ with open(fname, 'w') as f:
+ for line in file_list:
+ f.write(line + '\n')
+ test_list = ['imglist', 'path_imglist']
+ for dtype in ['int32', 'float32', 'int64', 'float64']:
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None
-
- test_iter = mx.image.ImageIter(2, (3, 224, 224),
label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype)
- # test batch data shape
- for _ in range(3):
- for batch in test_iter:
- assert batch.data[0].shape == (2, 3, 224, 224)
- test_iter.reset()
- # test last batch handle(discard)
- test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='discard')
- i = 0
- for batch in test_iter:
- i += 1
- assert i == 5
- # test last_batch_handle(pad)
- test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='pad')
- i = 0
- for batch in test_iter:
- if i == 0:
- first_three_data = batch.data[0][:2]
- if i == 5:
- last_three_data = batch.data[0][1:]
- i += 1
- assert i == 6
- assert np.array_equal(first_three_data.asnumpy(),
last_three_data.asnumpy())
- # test last_batch_handle(roll_over)
- test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist,
- path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='roll_over')
- i = 0
- for batch in test_iter:
- if i == 0:
- first_image = batch.data[0][0]
- i += 1
- assert i == 5
- test_iter.reset()
- first_batch_roll_over = test_iter.next()
- assert np.array_equal(
- first_batch_roll_over.data[0][1].asnumpy(),
first_image.asnumpy())
- assert first_batch_roll_over.pad == 2
- # test iteratopr work properly after calling reset several
times when last_batch_handle is roll_over
- for _ in test_iter:
- pass
- test_iter.reset()
- first_batch_roll_over_twice = test_iter.next()
- assert np.array_equal(
- first_batch_roll_over_twice.data[0][2].asnumpy(),
first_image.asnumpy())
- assert first_batch_roll_over_twice.pad == 1
- # we've called next once
- i = 1
- for _ in test_iter:
- i += 1
- # test the third epoch with size 6
- assert i == 6
- # test shuffle option for sanity test
- test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist, shuffle=True,
- path_imglist=path_imglist,
path_root='', dtype=dtype, last_batch_handle='pad')
- for _ in test_iter:
- pass
-
- for dtype in ['int32', 'float32', 'int64', 'float64']:
- check_imageiter(dtype)
-
- # test with default dtype
- check_imageiter()
+ imageiter_list = [
+ mx.image.ImageIter(2, (3, 224, 224), label_width=1,
imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1,
imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='discard'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1,
imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='pad'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1,
imglist=imglist,
+ path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='roll_over'),
+ mx.image.ImageIter(3, (3, 224, 224), label_width=1,
imglist=imglist, shuffle=True,
+ path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='pad')
+ ]
+ _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))
@with_seed()
def test_augmenters(self):
@@ -259,16 +262,20 @@ class TestImage(unittest.TestCase):
im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list,
path_root='')
for _ in range(3):
- for batch in det_iter:
+ for _ in det_iter:
pass
- det_iter.reset()
-
+ det_iter.reset()
val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list,
path_root='')
det_iter = val_iter.sync_label_shape(det_iter)
assert det_iter.data_shape == val_iter.data_shape
assert det_iter.label_shape == val_iter.label_shape
- # test file list
+ # test batch_size is not divisible by number of images
+ det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list,
path_root='')
+ for _ in det_iter:
+ pass
+
+ # test file list with last batch handle
fname = './data/test_imagedetiter.lst'
im_list = [[k] + _generate_objects() + [x] for k, x in
enumerate(TestImage.IMAGES)]
with open(fname, 'w') as f:
@@ -276,10 +283,19 @@ class TestImage(unittest.TestCase):
line = '\t'.join([str(k) for k in line])
f.write(line + '\n')
- det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
- path_root='')
- for batch in det_iter:
- pass
+ imageiter_list = [
+ mx.image.ImageDetIter(2, (3, 400, 400),
+ path_imglist=fname, path_root=''),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='', last_batch_handle='discard'),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='', last_batch_handle='pad'),
+ mx.image.ImageDetIter(3, (3, 400, 400),
+ path_imglist=fname, path_root='',
last_batch_handle='roll_over'),
+ mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
+ path_imglist=fname, path_root='', last_batch_handle='pad')
+ ]
+ _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))
def test_det_augmenters(self):
# only test if all augmenters will work