This is an automated email from the ASF dual-hosted git repository.
zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e434251 fix property not updating bug (#13085)
e434251 is described below
commit e43425154419a384d3320d830131e907e837af06
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Thu Nov 29 13:21:09 2018 -0800
fix property not updating bug (#13085)
---
python/mxnet/image/detection.py | 2 ++
tests/python/unittest/test_image.py | 10 ++++++----
2 files changed, 8 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index 3b9f64e..b27917c 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -745,9 +745,11 @@ class ImageDetIter(ImageIter):
if data_shape is not None:
self.check_data_shape(data_shape)
self.provide_data = [(self.provide_data[0][0], (self.batch_size,)
+ data_shape)]
+ self.data_shape = data_shape
if label_shape is not None:
self.check_label_shape(label_shape)
self.provide_label = [(self.provide_label[0][0],
(self.batch_size,) + label_shape)]
+ self.label_shape = label_shape
def next(self):
"""Override the function for returning next batch."""
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 3e35d6d..4f66823 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -159,14 +159,14 @@ class TestImage(unittest.TestCase):
with open(fname, 'w') as f:
for line in file_list:
f.write(line + '\n')
-
+
test_list = ['imglist', 'path_imglist']
for test in test_list:
imglist = im_list if test == 'imglist' else None
path_imglist = fname if test == 'path_imglist' else None
-
- test_iter = mx.image.ImageIter(2, (3, 224, 224),
label_width=1, imglist=imglist,
+
+ test_iter = mx.image.ImageIter(2, (3, 224, 224),
label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype)
# test batch data shape
for _ in range(3):
@@ -181,7 +181,7 @@ class TestImage(unittest.TestCase):
i += 1
assert i == 5
# test last_batch_handle(pad)
- test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist,
+ test_iter = mx.image.ImageIter(3, (3, 224, 224),
label_width=1, imglist=imglist,
path_imglist=path_imglist, path_root='', dtype=dtype,
last_batch_handle='pad')
i = 0
for batch in test_iter:
@@ -265,6 +265,8 @@ class TestImage(unittest.TestCase):
val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list,
path_root='')
det_iter = val_iter.sync_label_shape(det_iter)
+ assert det_iter.data_shape == val_iter.data_shape
+ assert det_iter.label_shape == val_iter.label_shape
# test file list
fname = './data/test_imagedetiter.lst'