This is an automated email from the ASF dual-hosted git repository.
jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 494a642 fix symbolblock (#8050)
494a642 is described below
commit 494a642d31ada9a3c52d580778f0b31abce419c0
Author: Eric Junyuan Xie <[email protected]>
AuthorDate: Wed Sep 27 10:05:26 2017 -0700
fix symbolblock (#8050)
---
python/mxnet/gluon/block.py | 6 +++---
tests/python/unittest/test_gluon.py | 22 +++++++++++++++++++++-
2 files changed, 24 insertions(+), 4 deletions(-)
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index a16a515..04d23a6 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -487,8 +487,8 @@ class SymbolBlock(HybridBlock):
self._params = ParameterDict('', params)
if isinstance(inputs, symbol.Symbol) and len(inputs.list_outputs()) == 1:
inputs = [inputs]
- if isinstance(outputs, symbol.Symbol) and len(outputs.list_outputs()) == 1:
- outputs = [outputs]
+ if isinstance(outputs, (list, tuple)) and len(outputs) == 1:
+ outputs = outputs[0]
syms, self._in_format = _flatten(inputs)
out, self._out_format = _flatten(outputs)
@@ -523,7 +523,7 @@ class SymbolBlock(HybridBlock):
assert in_fmt == self._in_format, "Invalid input format"
ret = copy.copy(self._cached_graph[1])
ret._compose(**{k.name: v for k, v in zip(self._cached_graph[0], args)})
- return _regroup(ret, self._out_format)[0]
+ return _regroup(list(ret), self._out_format)[0]
def hybrid_forward(self, F, x, *args, **kwargs):
raise NotImplementedError
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index ed25e71..5432e17 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -130,7 +130,27 @@ def test_symbol_block():
assert len(smodel(mx.nd.zeros((16, 10)))) == 14
out = smodel(mx.sym.var('in'))
- assert len(out.get_internals().list_outputs()) == len(outputs.list_outputs())
+ assert len(out) == len(outputs.list_outputs())
+
+ class Net(nn.HybridBlock):
+ def __init__(self, model):
+ super(Net, self).__init__()
+ self.model = model
+
+ def hybrid_forward(self, F, x):
+ out = self.model(x)
+ return F.add_n(*[i.sum() for i in out])
+
+ net = Net(smodel)
+ net.hybridize()
+ assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
+
+ inputs = mx.sym.var('data')
+ outputs = model(inputs)
+ smodel = gluon.SymbolBlock(outputs, inputs, params=model.collect_params())
+ net = Net(smodel)
+ net.hybridize()
+ assert isinstance(net(mx.nd.zeros((16, 10))), mx.nd.NDArray)
def check_layer_forward(layer, dshape):
--
To stop receiving notification emails like this one, please contact
['"[email protected]" <[email protected]>'].