This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4623c6d  [MXNET-185] Improved error message (#10424)
4623c6d is described below

commit 4623c6d7c473f3f7a8ac5ca441f2c042576623d5
Author: Ankit Khedia <36249596+ankkhe...@users.noreply.github.com>
AuthorDate: Thu Apr 12 12:52:51 2018 -0700

    [MXNET-185] Improved error message (#10424)
    
    * Improved error message
    
    * fix pylint errors
    
    * addressed PR comments
    
    * fixed lint issues
    
    * lint error fix
    
    * append warning message
    
    * lint errors
    
    * removed redundant message
    
    * fixed warning
    
    * Update block.py
---
 python/mxnet/gluon/block.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index fd95641..6400358 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -458,7 +458,13 @@ class HybridBlock(Block):
                                 for name in out.list_inputs()]
 
     def _finish_deferred_init(self, hybrid, *args):
-        self.infer_shape(*args)
+        try:
+            self.infer_shape(*args)
+        except Exception as e:
+            error_msg = "Deferred initialization failed because shape"\
+                        " cannot be inferred \n {}".format(e)
+            raise ValueError(error_msg)
+
         if hybrid:
             for is_arg, i in self._cached_op_args:
                 if not is_arg:
@@ -509,11 +515,14 @@ class HybridBlock(Block):
         """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
         args, _ = _flatten(args)
-        arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
-            **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+        with warnings.catch_warnings(record=True) as w:
+            arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
+                **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
+            if arg_attrs is None:
+                raise ValueError(w[0].message)
         sdict = {i: j for i, j in zip(out.list_arguments(), arg_attrs)}
         sdict.update({name : attr for name, attr in \
-                      zip(out.list_auxiliary_states(), aux_attrs)})
+             zip(out.list_auxiliary_states(), aux_attrs)})
         for i in self.collect_params().values():
             setattr(i, attr, sdict[i.name])
 

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.

Reply via email to