This is an automated email from the ASF dual-hosted git repository. ptrendx pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 90954ec make TransposeShape infer shape form both sides (#15713) 90954ec is described below commit 90954ec64c284d0e3a0757c992e7be0c30a0423e Author: dtracz <41399548+dtr...@users.noreply.github.com> AuthorDate: Thu Aug 1 14:41:13 2019 -0700 make TransposeShape infer shape form both sides (#15713) * make TransposeShape infer shape form both sides * small fixes * remove redundant lines * unit tests --- src/operator/tensor/matrix_op-inl.h | 19 +++++++++++++++++-- tests/python/unittest/test_operator.py | 20 ++++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h index 5cd7bf6..cd98cb0 100644 --- a/src/operator/tensor/matrix_op-inl.h +++ b/src/operator/tensor/matrix_op-inl.h @@ -344,19 +344,34 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs, CHECK_EQ(in_attrs->size(), 1U); CHECK_EQ(out_attrs->size(), 1U); mxnet::TShape& shp = (*in_attrs)[0]; + mxnet::TShape& out_shp = (*out_attrs)[0]; CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions"; - mxnet::TShape ret(shp.ndim(), -1); + CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0"; + CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0"; + if (shp.ndim() == -1 && out_shp.ndim() == -1) + return false; // none of the shapes is known + if (out_shp.ndim() > 0 && shp.ndim() > 0) + CHECK_EQ(out_shp.ndim(), shp.ndim()); + mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1); + mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1); if (param.axes.ndim() == 0) { for (int i = 0; i < shp.ndim(); ++i) { ret[i] = shp[shp.ndim()-1-i]; } + for (int i = 0; i < out_shp.ndim(); ++i) { + get[shp.ndim()-1-i] = out_shp[i]; + } } else { - CHECK_EQ(shp.ndim(), param.axes.ndim()); + CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim()); for (int i = 0; i < shp.ndim(); ++i) { CHECK(param.axes[i] < static_cast<int64_t>(shp.ndim())); ret[i] = shp[param.axes[i]]; } + for (int i = 0; i < out_shp.ndim(); ++i) { + get[param.axes[i]] = out_shp[i]; + } } + SHAPE_ASSIGN_CHECK(*in_attrs, 0, get); SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret); return shape_is_known(ret); } diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e8c9d6c..8f1c253 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -8970,6 +8970,26 @@ def test_get_operator_arguments(): ok_(operator_arguments.narg == 2) +def test_transpose_infer_shape_back(): + o1 = mx.sym.ones(shape=[2,3]) + o2 = mx.sym.ones(shape=[-1,-1]) + t = mx.sym.transpose(o2) + b = o1 + t + x = b.bind(mx.cpu(), args={}) + y = x.forward() + assert(y[0].shape == (2,3)) + + +def test_transpose_infer_shape_mixed(): + o1 = mx.sym.ones(shape=[2,-1]) + o2 = mx.sym.ones(shape=[3,-1]) + t = mx.sym.transpose(o2) + b = o1 + t + x = b.bind(mx.cpu(), args={}) + y = x.forward() + assert(y[0].shape == (2,3)) + + if __name__ == '__main__': import nose nose.runmodule()