zhreshold closed pull request #13550: Fixes infinite loop using imagedetiter
URL: https://github.com/apache/incubator-mxnet/pull/13550
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below; GitHub would not otherwise display it:

diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index b27917c8623..d5b5ecab528 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -658,19 +658,26 @@ class ImageDetIter(ImageIter):
         Data name for provided symbols.
     label_name : str
         Name for detection labels
+    last_batch_handle : str, optional
+        How to handle the last batch.
+        This parameter can be 'pad'(default), 'discard' or 'roll_over'.
+        If 'pad', the last batch will be padded with data starting from the 
begining
+        If 'discard', the last batch will be discarded
+        If 'roll_over', the remaining elements will be rolled over to the next 
iteration
     kwargs : ...
         More arguments for creating augmenter. See mx.image.CreateDetAugmenter.
     """
     def __init__(self, batch_size, data_shape,
                  path_imgrec=None, path_imglist=None, path_root=None, 
path_imgidx=None,
                  shuffle=False, part_index=0, num_parts=1, aug_list=None, 
imglist=None,
-                 data_name='data', label_name='label', **kwargs):
+                 data_name='data', label_name='label', 
last_batch_handle='pad', **kwargs):
         super(ImageDetIter, self).__init__(batch_size=batch_size, 
data_shape=data_shape,
                                            path_imgrec=path_imgrec, 
path_imglist=path_imglist,
                                            path_root=path_root, 
path_imgidx=path_imgidx,
                                            shuffle=shuffle, 
part_index=part_index,
                                            num_parts=num_parts, aug_list=[], 
imglist=imglist,
-                                           data_name=data_name, 
label_name=label_name)
+                                           data_name=data_name, 
label_name=label_name,
+                                           last_batch_handle=last_batch_handle)
 
         if aug_list is None:
             self.auglist = CreateDetAugmenter(data_shape, **kwargs)
@@ -751,14 +758,10 @@ def reshape(self, data_shape=None, label_shape=None):
             self.provide_label = [(self.provide_label[0][0], 
(self.batch_size,) + label_shape)]
             self.label_shape = label_shape
 
-    def next(self):
-        """Override the function for returning next batch."""
+    def _batchify(self, batch_data, batch_label, start=0):
+        """Override the helper function for batchifying data"""
+        i = start
         batch_size = self.batch_size
-        c, h, w = self.data_shape
-        batch_data = nd.zeros((batch_size, c, h, w))
-        batch_label = nd.empty(self.provide_label[0][1])
-        batch_label[:] = -1
-        i = 0
         try:
             while i < batch_size:
                 label, s = self.next_sample()
@@ -783,7 +786,48 @@ def next(self):
             if not i:
                 raise StopIteration
 
-        return io.DataBatch([batch_data], [batch_label], batch_size - i)
+        return i
+
+    def next(self):
+        """Override the function for returning next batch."""
+        batch_size = self.batch_size
+        c, h, w = self.data_shape
+        # if last batch data is rolled over
+        if self._cache_data is not None:
+            # check both the data and label have values
+            assert self._cache_label is not None, "_cache_label didn't have 
values"
+            assert self._cache_idx is not None, "_cache_idx didn't have values"
+            batch_data = self._cache_data
+            batch_label = self._cache_label
+            i = self._cache_idx
+        else:
+            batch_data = nd.zeros((batch_size, c, h, w))
+            batch_label = nd.empty(self.provide_label[0][1])
+            batch_label[:] = -1
+            i = self._batchify(batch_data, batch_label)
+        # calculate the padding
+        pad = batch_size - i
+        # handle padding for the last batch
+        if pad != 0:
+            if self.last_batch_handle == 'discard':
+                raise StopIteration
+            # if the option is 'roll_over', throw StopIteration and cache the 
data
+            elif self.last_batch_handle == 'roll_over' and \
+                self._cache_data is None:
+                self._cache_data = batch_data
+                self._cache_label = batch_label
+                self._cache_idx = i
+                raise StopIteration
+            else:
+                _ = self._batchify(batch_data, batch_label, i)
+                if self.last_batch_handle == 'pad':
+                    self._allow_read = False
+                else:
+                    self._cache_data = None
+                    self._cache_label = None
+                    self._cache_idx = None
+
+        return io.DataBatch([batch_data], [batch_label], pad=pad)
 
     def augmentation_transform(self, data, label):  # pylint: 
disable=arguments-differ
         """Override Transforms input data with specified augmentations."""
diff --git a/python/mxnet/image/image.py b/python/mxnet/image/image.py
index c9a457f5b7e..9c2a1cbfba2 100644
--- a/python/mxnet/image/image.py
+++ b/python/mxnet/image/image.py
@@ -1145,7 +1145,7 @@ def __init__(self, batch_size, data_shape, label_width=1,
         self.shuffle = shuffle
         if self.imgrec is None:
             self.seq = imgkeys
-        elif shuffle or num_parts > 1:
+        elif shuffle or num_parts > 1 or path_imgidx:
             assert self.imgidx is not None
             self.seq = self.imgidx
         else:
@@ -1261,7 +1261,7 @@ def next(self):
             i = self._cache_idx
             # clear the cache data
         else:
-            batch_data = nd.empty((batch_size, c, h, w))
+            batch_data = nd.zeros((batch_size, c, h, w))
             batch_label = nd.empty(self.provide_label[0][1])
             i = self._batchify(batch_data, batch_label)
         # calculate the padding
@@ -1285,6 +1285,7 @@ def next(self):
                     self._cache_data = None
                     self._cache_label = None
                     self._cache_idx = None
+
         return io.DataBatch([batch_data], [batch_label], pad=pad)
 
     def check_data_shape(self, data_shape):
diff --git a/tests/python/unittest/test_image.py 
b/tests/python/unittest/test_image.py
index 4f66823cdbf..4063027cc1e 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -25,6 +25,7 @@
 
 from nose.tools import raises
 
+
 def _get_data(url, dirname):
     import os, tarfile
     download(url, dirname=dirname, overwrite=False)
@@ -50,6 +51,62 @@ def _generate_objects():
     label = np.hstack((cid[:, np.newaxis], boxes)).ravel().tolist()
     return [2, 5] + label
 
+def _test_imageiter_last_batch(imageiter_list, assert_data_shape):
+    test_iter = imageiter_list[0]
+    # test batch data shape
+    for _ in range(3):
+        for batch in test_iter:
+            assert batch.data[0].shape == assert_data_shape
+        test_iter.reset()
+    # test last batch handle(discard)
+    test_iter = imageiter_list[1]
+    i = 0
+    for batch in test_iter:
+        i += 1
+    assert i == 5
+    # test last_batch_handle(pad)
+    test_iter = imageiter_list[2]
+    i = 0
+    for batch in test_iter:
+        if i == 0:
+            first_three_data = batch.data[0][:2]
+        if i == 5:
+            last_three_data = batch.data[0][1:]
+        i += 1
+    assert i == 6
+    assert np.array_equal(first_three_data.asnumpy(), 
last_three_data.asnumpy())
+    # test last_batch_handle(roll_over)
+    test_iter = imageiter_list[3]
+    i = 0
+    for batch in test_iter:
+        if i == 0:
+            first_image = batch.data[0][0]
+        i += 1
+    assert i == 5
+    test_iter.reset()
+    first_batch_roll_over = test_iter.next()
+    assert np.array_equal(
+        first_batch_roll_over.data[0][1].asnumpy(), first_image.asnumpy())
+    assert first_batch_roll_over.pad == 2
+    # test iteratopr work properly after calling reset several times when 
last_batch_handle is roll_over
+    for _ in test_iter:
+        pass
+    test_iter.reset()
+    first_batch_roll_over_twice = test_iter.next()
+    assert np.array_equal(
+        first_batch_roll_over_twice.data[0][2].asnumpy(), 
first_image.asnumpy())
+    assert first_batch_roll_over_twice.pad == 1
+    # we've called next once
+    i = 1
+    for _ in test_iter:
+        i += 1
+    # test the third epoch with size 6
+    assert i == 6
+    # test shuffle option for sanity test
+    test_iter = imageiter_list[4]
+    for _ in test_iter:
+        pass
+
 
 class TestImage(unittest.TestCase):
     IMAGES_URL = "http://data.mxnet.io/data/test_images.tar.gz";
@@ -151,86 +208,32 @@ def test_color_normalize(self):
             assert_almost_equal(mx_result.asnumpy(), (src - mean) / std, 
atol=1e-3)
 
     def test_imageiter(self):
-        def check_imageiter(dtype='float32'):
-            im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
-            fname = './data/test_imageiter.lst'
-            file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
-                         for k, x in enumerate(TestImage.IMAGES)]
-            with open(fname, 'w') as f:
-                for line in file_list:
-                    f.write(line + '\n')
-
-            test_list = ['imglist', 'path_imglist']
+        im_list = [[np.random.randint(0, 5), x] for x in TestImage.IMAGES]
+        fname = './data/test_imageiter.lst'
+        file_list = ['\t'.join([str(k), str(np.random.randint(0, 5)), x])
+                        for k, x in enumerate(TestImage.IMAGES)]
+        with open(fname, 'w') as f:
+            for line in file_list:
+                f.write(line + '\n')
 
+        test_list = ['imglist', 'path_imglist']
+        for dtype in ['int32', 'float32', 'int64', 'float64']:
             for test in test_list:
                 imglist = im_list if test == 'imglist' else None
                 path_imglist = fname if test == 'path_imglist' else None
-
-                test_iter = mx.image.ImageIter(2, (3, 224, 224), 
label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype)
-                # test batch data shape
-                for _ in range(3):
-                    for batch in test_iter:
-                        assert batch.data[0].shape == (2, 3, 224, 224)
-                    test_iter.reset()
-                # test last batch handle(discard)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='discard')
-                i = 0
-                for batch in test_iter:
-                    i += 1
-                assert i == 5
-                # test last_batch_handle(pad)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='pad')
-                i = 0
-                for batch in test_iter:
-                    if i == 0:
-                        first_three_data = batch.data[0][:2]
-                    if i == 5:
-                        last_three_data = batch.data[0][1:]
-                    i += 1
-                assert i == 6
-                assert np.array_equal(first_three_data.asnumpy(), 
last_three_data.asnumpy())
-                # test last_batch_handle(roll_over)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
-                    path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='roll_over')
-                i = 0
-                for batch in test_iter:
-                    if i == 0:
-                        first_image = batch.data[0][0]
-                    i += 1
-                assert i == 5
-                test_iter.reset()
-                first_batch_roll_over = test_iter.next()
-                assert np.array_equal(
-                    first_batch_roll_over.data[0][1].asnumpy(), 
first_image.asnumpy())
-                assert first_batch_roll_over.pad == 2
-                # test iteratopr work properly after calling reset several 
times when last_batch_handle is roll_over
-                for _ in test_iter:
-                    pass
-                test_iter.reset()
-                first_batch_roll_over_twice = test_iter.next()
-                assert np.array_equal(
-                    first_batch_roll_over_twice.data[0][2].asnumpy(), 
first_image.asnumpy())
-                assert first_batch_roll_over_twice.pad == 1
-                # we've called next once
-                i = 1
-                for _ in test_iter:
-                    i += 1
-                # test the third epoch with size 6
-                assert i == 6
-                # test shuffle option for sanity test
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist, shuffle=True,
-                                               path_imglist=path_imglist, 
path_root='', dtype=dtype, last_batch_handle='pad')
-                for _ in test_iter:
-                    pass
-
-        for dtype in ['int32', 'float32', 'int64', 'float64']:
-            check_imageiter(dtype)
-
-        # test with default dtype
-        check_imageiter()
+                imageiter_list = [
+                    mx.image.ImageIter(2, (3, 224, 224), label_width=1, 
imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, 
imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='discard'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, 
imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='pad'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, 
imglist=imglist,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='roll_over'),
+                    mx.image.ImageIter(3, (3, 224, 224), label_width=1, 
imglist=imglist, shuffle=True,
+                        path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='pad')
+                ]
+                _test_imageiter_last_batch(imageiter_list, (2, 3, 224, 224))
 
     @with_seed()
     def test_augmenters(self):
@@ -259,16 +262,20 @@ def test_image_detiter(self):
         im_list = [_generate_objects() + [x] for x in TestImage.IMAGES]
         det_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, 
path_root='')
         for _ in range(3):
-            for batch in det_iter:
+            for _ in det_iter:
                 pass
-            det_iter.reset()
-
+        det_iter.reset()
         val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, 
path_root='')
         det_iter = val_iter.sync_label_shape(det_iter)
         assert det_iter.data_shape == val_iter.data_shape
         assert det_iter.label_shape == val_iter.label_shape
 
-        # test file list
+        # test batch_size is not divisible by number of images
+        det_iter = mx.image.ImageDetIter(4, (3, 300, 300), imglist=im_list, 
path_root='')
+        for _ in det_iter:
+            pass
+
+        # test file list with last batch handle
         fname = './data/test_imagedetiter.lst'
         im_list = [[k] + _generate_objects() + [x] for k, x in 
enumerate(TestImage.IMAGES)]
         with open(fname, 'w') as f:
@@ -276,10 +283,19 @@ def test_image_detiter(self):
                 line = '\t'.join([str(k) for k in line])
                 f.write(line + '\n')
 
-        det_iter = mx.image.ImageDetIter(2, (3, 400, 400), path_imglist=fname,
-            path_root='')
-        for batch in det_iter:
-            pass
+        imageiter_list = [
+            mx.image.ImageDetIter(2, (3, 400, 400),
+                path_imglist=fname, path_root=''),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', last_batch_handle='discard'),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', last_batch_handle='pad'),
+            mx.image.ImageDetIter(3, (3, 400, 400),
+                path_imglist=fname, path_root='', 
last_batch_handle='roll_over'),
+            mx.image.ImageDetIter(3, (3, 400, 400), shuffle=True,
+                path_imglist=fname, path_root='', last_batch_handle='pad')
+        ]
+        _test_imageiter_last_batch(imageiter_list, (2, 3, 400, 400))
 
     def test_det_augmenters(self):
         # only test if all augmenters will work


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to