szha closed pull request #9645: Fix DataBatch.__str__ for cases where we don't have labels. URL: https://github.com/apache/incubator-mxnet/pull/9645
This is a PR merged from a forked repository. Because GitHub does not show the original diff of a merged pull request from a fork, the diff is reproduced below for the sake of provenance:

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index b07f7c1bea..201414e8f6 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -168,7 +168,10 @@ def __init__(self, data, label=None, pad=None, index=None,
 
     def __str__(self):
         data_shapes = [d.shape for d in self.data]
-        label_shapes = [l.shape for l in self.label]
+        if self.label:
+            label_shapes = [l.shape for l in self.label]
+        else:
+            label_shapes = None
         return "{}: data shapes: {} label shapes: {}".format(
             self.__class__.__name__,
             data_shapes,
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index e8aba38b82..58ca1d74fb 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -252,6 +252,17 @@ def check_libSVMIter_news_data():
     check_libSVMIter_synthetic()
     check_libSVMIter_news_data()
 
+
+def test_DataBatch():
+    from nose.tools import ok_
+    from mxnet.io import DataBatch
+    import re
+    batch = DataBatch(data=[mx.nd.ones((2,3))])
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
+    batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))])
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
+
+
 @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/7826")
 def test_CSVIter():
     def check_CSVIter_synthetic():

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

With regards,
Apache Git Services