piiswrong closed pull request #10573: Fix gluon error message.
URL: https://github.com/apache/incubator-mxnet/pull/10573
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below, because GitHub does not otherwise show it once the request is merged:

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index 2f8cdd80fc7..0f415436116 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -84,7 +84,7 @@ def __exit__(self, ptype, value, trace):
         _BlockScope._current = self._old_scope
 
 
-def _flatten(args):
+def _flatten(args, inout_str):
     if isinstance(args, NDArray):
         return [args], int(0)
     if isinstance(args, Symbol):
@@ -93,12 +93,12 @@ def _flatten(args):
         return [args], int(length)
 
     assert isinstance(args, (list, tuple)), \
-        "HybridBlock input must be (nested) list of Symbol or NDArray, " \
-        "but got %s of type %s"%(str(args), str(type(args)))
+        "HybridBlock %s must be (nested) list of Symbol or NDArray, " \
+        "but got %s of type %s"%(inout_str, str(args), str(type(args)))
     flat = []
     fmts = []
     for i in args:
-        arg, fmt = _flatten(i)
+        arg, fmt = _flatten(i, inout_str)
         flat.extend(arg)
         fmts.append(fmt)
     return flat, fmts
@@ -461,7 +461,7 @@ def __setattr__(self, name, value):
 
     def _get_graph(self, *args):
         if not self._cached_graph:
-            args, self._in_format = _flatten(args)
+            args, self._in_format = _flatten(args, "input")
             if len(args) > 1:
                 inputs = [symbol.var('data%d'%i) for i in range(len(args))]
             else:
@@ -471,7 +471,7 @@ def _get_graph(self, *args):
             params = {i: j.var() for i, j in self._reg_params.items()}
             with self.name_scope():
                 out = self.hybrid_forward(symbol, *grouped_inputs, **params)  # pylint: disable=no-value-for-parameter
-            out, self._out_format = _flatten(out)
+            out, self._out_format = _flatten(out, "output")
 
             self._cached_graph = inputs, symbol.Group(out)
 
@@ -521,7 +521,7 @@ def _call_cached_op(self, *args):
         if self._cached_op is None:
             self._build_cache(*args)
 
-        args, fmt = _flatten(args)
+        args, fmt = _flatten(args, "input")
         assert fmt == self._in_format, "Invalid input format"
         cargs = [args[i] if is_arg else i.data()
                  for is_arg, i in self._cached_op_args]
@@ -558,7 +558,7 @@ def cast(self, dtype):
     def _infer_attrs(self, infer_fn, attr, *args):
         """Generic infer attributes."""
         inputs, out = self._get_graph(*args)
-        args, _ = _flatten(args)
+        args, _ = _flatten(args, "input")
         with warnings.catch_warnings(record=True) as w:
             arg_attrs, _, aux_attrs = getattr(out, infer_fn)(
                 **{i.name: getattr(j, attr) for i, j in zip(inputs, args)})
@@ -690,8 +690,8 @@ def __init__(self, outputs, inputs, params=None):
         if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
             outputs = outputs[0]
 
-        syms, self._in_format = _flatten(inputs)
-        out, self._out_format = _flatten(outputs)
+        syms, self._in_format = _flatten(inputs, "input")
+        out, self._out_format = _flatten(outputs, "output")
         out = symbol.Group(out)
 
         input_names = set()
@@ -724,7 +724,7 @@ def forward(self, x, *args):
         assert isinstance(x, Symbol), \
             "HybridBlock requires the first argument to forward be either " \
             "Symbol or NDArray, but got %s"%type(x)
-        args, in_fmt = _flatten([x] + list(args))
+        args, in_fmt = _flatten([x] + list(args), "input")
         assert in_fmt == self._in_format, "Invalid input format"
         ret = copy.copy(self._cached_graph[1])
         ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to