zhreshold closed pull request #13085: ImageDetIter: fix property not updating bug
URL: https://github.com/apache/incubator-mxnet/pull/13085
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/image/detection.py b/python/mxnet/image/detection.py
index 3b9f64e1220..b27917c8623 100644
--- a/python/mxnet/image/detection.py
+++ b/python/mxnet/image/detection.py
@@ -745,9 +745,11 @@ def reshape(self, data_shape=None, label_shape=None):
         if data_shape is not None:
             self.check_data_shape(data_shape)
             self.provide_data = [(self.provide_data[0][0], (self.batch_size,) 
+ data_shape)]
+            self.data_shape = data_shape
         if label_shape is not None:
             self.check_label_shape(label_shape)
             self.provide_label = [(self.provide_label[0][0], 
(self.batch_size,) + label_shape)]
+            self.label_shape = label_shape
 
     def next(self):
         """Override the function for returning next batch."""
diff --git a/tests/python/unittest/test_image.py 
b/tests/python/unittest/test_image.py
index c8022b67bee..6dccba1892c 100644
--- a/tests/python/unittest/test_image.py
+++ b/tests/python/unittest/test_image.py
@@ -154,14 +154,14 @@ def check_imageiter(dtype='float32'):
             with open(fname, 'w') as f:
                 for line in file_list:
                     f.write(line + '\n')
-            
+
             test_list = ['imglist', 'path_imglist']
 
             for test in test_list:
                 imglist = im_list if test == 'imglist' else None
                 path_imglist = fname if test == 'path_imglist' else None
-                
-                test_iter = mx.image.ImageIter(2, (3, 224, 224), 
label_width=1, imglist=imglist, 
+
+                test_iter = mx.image.ImageIter(2, (3, 224, 224), 
label_width=1, imglist=imglist,
                     path_imglist=path_imglist, path_root='', dtype=dtype)
                 # test batch data shape
                 for _ in range(3):
@@ -176,7 +176,7 @@ def check_imageiter(dtype='float32'):
                     i += 1
                 assert i == 5
                 # test last_batch_handle(pad)
-                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist, 
+                test_iter = mx.image.ImageIter(3, (3, 224, 224), 
label_width=1, imglist=imglist,
                     path_imglist=path_imglist, path_root='', dtype=dtype, 
last_batch_handle='pad')
                 i = 0
                 for batch in test_iter:
@@ -260,6 +260,8 @@ def test_image_detiter(self):
 
         val_iter = mx.image.ImageDetIter(2, (3, 300, 300), imglist=im_list, 
path_root='')
         det_iter = val_iter.sync_label_shape(det_iter)
+        assert det_iter.data_shape == val_iter.data_shape
+        assert det_iter.label_shape == val_iter.label_shape
 
         # test file list
         fname = './data/test_imagedetiter.lst'


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to