This is an automated email from the ASF dual-hosted git repository.

zhreshold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e434251  fix property not updating bug (#13085)
e434251 is described below

commit e43425154419a384d3320d830131e907e837af06
Author: Joshua Z. Zhang <[email protected]>
AuthorDate: Thu Nov 29 13:21:09 2018 -0800

    fix property not updating bug (#13085)
---
 python/mxnet/image/detection.py     |  2 ++
 tests/python/unittest/test_image.py | 10 ++++++----
 2 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index 3b9f64e..b27917c 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -745,9 +745,11 @@ class ImageDetIter(ImageIter):
         if data_shape is not None:
             self.check_data_shape(data_shape)
             self.provide_data = [(self.provide_data[0][0], (self.batch_size,) + data_shape)]
+            self.data_shape = data_shape
         if label_shape is not None:
             self.check_label_shape(label_shape)
             self.provide_label = [(self.provide_label[0][0], (self.batch_size,) + label_shape)]
+            self.label_shape = label_shape
 
     def next(self):
         """Override the function for returning next batch."""
diff --git a/tests/python/unittest/test_image.py b/tests/python/unittest/test_image.py
index 3e35d6d..4f66823 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -159,14 +159,14 @@ class TestImage(unittest.TestCase):
             with open(fname, 'w') as f:
                 for line in file_list:
                     f.write(line + '\n')
-            
+
             test_list = ['imglist', 'path_imglist']
 
             for test in test_list:
                 imglist = im_list if test == 'imglist' else None
                 path_imglist = fname if test == 'path_imglist' else None
-                
-                test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist, 
+
+                test_iter = mx.image.ImageIter(2, (3, 224, 224), label_width=1, imglist=imglist,
                     path_imglist=path_imglist, path_root='', dtype=dtype)
                 # test batch data shape
                 for _ in range(3):
@@ -181,7 +181,7 @@ class TestImage(unittest.TestCase):
                     i += 1
                 assert i == 5
                 # test last_batch_handle(pad)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist, 
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), label_width=1, imglist=imglist,
                     path_imglist=path_imglist, path_root='', dtype=dtype, last_batch_handle='pad')
                 i = 0
                 for batch in test_iter:
@@ -265,6 +265,8 @@ class TestImage(unittest.TestCase):
 
         val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, path_root='')
         det_iter = val_iter.sync_label_shape(det_iter)
+        assert det_iter.data_shape == val_iter.data_shape
+        assert det_iter.label_shape == val_iter.label_shape
 
         # test file list
         fname = './data/test_imagedetiter.lst'

Reply via email to