szha closed pull request #9645: Fix DataBatch.__str__ for cases where we don't have labels.

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/ b/python/mxnet/
index b07f7c1bea..201414e8f6 100644
--- a/python/mxnet/
+++ b/python/mxnet/
@@ -168,7 +168,10 @@ def __init__(self, data, label=None, pad=None, index=None,
     def __str__(self):
         data_shapes = [d.shape for d in]
-        label_shapes = [l.shape for l in self.label]
+        if self.label:
+            label_shapes = [l.shape for l in self.label]
+        else:
+            label_shapes = None
         return "{}: data shapes: {} label shapes: {}".format(
diff --git a/tests/python/unittest/ b/tests/python/unittest/
index e8aba38b82..58ca1d74fb 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -252,6 +252,17 @@ def check_libSVMIter_news_data():
+def test_DataBatch():
+    from import ok_
+    from import DataBatch
+    import re
+    batch = DataBatch(data=[mx.nd.ones((2,3))])
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
+    batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], 
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
 @unittest.skip("test fails intermittently. temporarily disabled till it gets 
fixed. tracked at";)
 def test_CSVIter():
     def check_CSVIter_synthetic():


This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:

With regards,
Apache Git Services

Reply via email to