This is an automated email from the ASF dual-hosted git repository. jxie pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 6d71577 Fix shape inference bug (#7682) 6d71577 is described below commit 6d7157768e216bcb4b505f98555fc72286d160fb Author: reminisce <wujun....@gmail.com> AuthorDate: Thu Aug 31 14:11:14 2017 -0700 Fix shape inference bug (#7682) --- src/executor/infer_graph_attr_pass.cc | 5 ++++- tests/python/unittest/test_symbol.py | 25 +++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/src/executor/infer_graph_attr_pass.cc b/src/executor/infer_graph_attr_pass.cc index 144c371..76b95f3 100644 --- a/src/executor/infer_graph_attr_pass.cc +++ b/src/executor/infer_graph_attr_pass.cc @@ -160,7 +160,10 @@ nnvm::Graph InferAttr(nnvm::Graph &&ret, uint32_t eid = idx.entry_id(nid, igrad[i].index); if (fis_none(rshape[eid])) { rshape[eid] = rshape[idx.entry_id(fnode.inputs[i])]; - } else { + } else if (!fis_none(rshape[idx.entry_id(fnode.inputs[i])])) { + // Need to skip empty forward shape, because it may not be + // available now and it is possible to infer the forward + // shape in one of the next few passes CHECK_EQ(rshape[eid], rshape[idx.entry_id(fnode.inputs[i])]) << "Backward shape inconsistent with the forward shape"; } diff --git a/tests/python/unittest/test_symbol.py b/tests/python/unittest/test_symbol.py index 4d162ec..4a2cdb3 100644 --- a/tests/python/unittest/test_symbol.py +++ b/tests/python/unittest/test_symbol.py @@ -286,6 +286,31 @@ def test_zero_prop2(): assert False +def test_simple_bind_special_case(): + """This is a special case that results in shape inference + failure after moving simple_bind logic from frontend to backend. + Added here for testing against the network similar to the following one. + + Network diagram: + weight --> abs_op --> sum_op -- + |--> add_op + data --> fc_op --> sum_op -- + + Given data's shape, if the shape inference starts from weight node, + then the node entries of abs_op and sum_op are unknown in the + forward pass. 
Therefore, there are several unknown shapes after the + first forward pass is done. Now the backward inference pass starts with + the assumption that there are no unknown-shape node entries in the forward + pass, and consequently leads to a CHECK_EQ failure. + """ + data_shape = (5, 13) + data = mx.sym.Variable('data') + fc = mx.sym.FullyConnected(data=data, num_hidden=1, no_bias=True, name='fc') + modified_weight = mx.sym.abs(fc.get_internals()['fc_weight']) + net = mx.sym.sum(modified_weight) + mx.sym.sum(fc) + net.simple_bind(ctx=mx.cpu(), data=data_shape) + + if __name__ == '__main__': import nose nose.runmodule() -- To stop receiving notification emails like this one, please contact ['"comm...@mxnet.apache.org" <comm...@mxnet.apache.org>'].