This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 494a642  fix symbolblock (#8050)
494a642 is described below

commit 494a642d31ada9a3c52d580778f0b31abce419c0
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Wed Sep 27 10:05:26 2017 -0700

    fix symbolblock (#8050)
---
 python/mxnet/gluon/block.py         |  6 +++---
 tests/python/unittest/test_gluon.py | 22 +++++++++++++++++++++-
 2 files changed, 24 insertions(+), 4 deletions(-)

diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index a16a515..04d23a6 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -487,8 +487,8 @@ class SymbolBlock(HybridBlock):
         self._params = ParameterDict('', params)
         if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 
1:
             inputs = [inputs]
-        if isinstance(outputs, symbol.Symbol) and len(outputs.list_outputs()) 
== 1:
-            outputs = [outputs]
+        if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
+            outputs = outputs[0]
 
         syms, self._in_format = _flatten(inputs)
         out, self._out_format = _flatten(outputs)
@@ -523,7 +523,7 @@ class SymbolBlock(HybridBlock):
         assert in_fmt == self._in_format, "Invalid input format"
         ret = copy.copy(self._cached_graph[1])
         ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], 
args)})
-        return _regroup(ret, self._out_format)[0]
+        return _regroup(list(ret), self._out_format)[0]
 
     def hybrid_forward(self, F, x, *args, **kwargs):
         raise NotImplementedError
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index ed25e71..5432e17 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -130,7 +130,27 @@ def test_symbol_block():
     assert len(smodel(mx.nd.zeros((16, 10)))) == 14
 
     out = smodel(mx.sym.var('in'))
-    assert len(out.get_internals().list_outputs()) == 
len(outputs.list_outputs())
+    assert len(out) == len(outputs.list_outputs())
+
+    class Net(nn.HybridBlock):
+        def __init__(self, model):
+            super(Net, self).__init__()
+            self.model = model
+
+        def hybrid_forward(self, F, x):
+            out = self.model(x)
+            return F.add_n(*[i.sum() for i in out])
+
+    net = Net(smodel)
+    net.hybridize()
+    assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
+
+    inputs = mx.sym.var('data')
+    outputs = model(inputs)
+    smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params())
+    net = Net(smodel)
+    net.hybridize()
+    assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
 
 
 def check_layer_forward(layer, dshape):

-- 
To stop receiving notification emails like this one, please contact
"[email protected]" <[email protected]>.

Reply via email to