This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new fe5b56e  Fix DataBatch.__str__ for cases where we don't have labels. (#9645)
fe5b56e is described below

commit fe5b56e419d454dc8f42f0307f53ced133804ca7
Author: Pedro Larroy <928489+lar...@users.noreply.github.com>
AuthorDate: Sat Feb 3 08:25:20 2018 +0100

    Fix DataBatch.__str__ for cases where we don't have labels. (#9645)
---
 python/mxnet/io.py               |  5 ++++-
 tests/python/unittest/test_io.py | 11 +++++++++++
 2 files changed, 15 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/io.py b/python/mxnet/io.py
index b07f7c1..201414e 100644
--- a/python/mxnet/io.py
+++ b/python/mxnet/io.py
@@ -168,7 +168,10 @@ class DataBatch(object):
 
     def __str__(self):
         data_shapes = [d.shape for d in self.data]
-        label_shapes = [l.shape for l in self.label]
+        if self.label:
+            label_shapes = [l.shape for l in self.label]
+        else:
+            label_shapes = None
         return "{}: data shapes: {} label shapes: {}".format(
             self.__class__.__name__,
             data_shapes,
diff --git a/tests/python/unittest/test_io.py b/tests/python/unittest/test_io.py
index e8aba38..58ca1d7 100644
--- a/tests/python/unittest/test_io.py
+++ b/tests/python/unittest/test_io.py
@@ -252,6 +252,17 @@ def test_LibSVMIter():
     check_libSVMIter_synthetic()
     check_libSVMIter_news_data()
 
+
+def test_DataBatch():
+    from nose.tools import ok_
+    from mxnet.io import DataBatch
+    import re
+    batch = DataBatch(data=[mx.nd.ones((2,3))])
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\)\] label shapes: None', str(batch)))
+    batch = DataBatch(data=[mx.nd.ones((2,3)), mx.nd.ones((7,8))], label=[mx.nd.ones((4,5))])
+    ok_(re.match('DataBatch: data shapes: \[\(2L?, 3L?\), \(7L?, 8L?\)\] label shapes: \[\(4L?, 5L?\)\]', str(batch)))
+
+
 @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/7826")
 def test_CSVIter():
     def check_CSVIter_synthetic():

-- 
To stop receiving notification emails like this one, please contact
zhash...@apache.org.

Reply via email to