This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 90954ec  make TransposeShape infer shape form both sides (#15713)
90954ec is described below

commit 90954ec64c284d0e3a0757c992e7be0c30a0423e
Author: dtracz <41399548+dtr...@users.noreply.github.com>
AuthorDate: Thu Aug 1 14:41:13 2019 -0700

    make TransposeShape infer shape form both sides (#15713)
    
    * make TransposeShape infer shape form both sides
    
    * small fixes
    
    * remove redundant lines
    
    * unit tests
---
 src/operator/tensor/matrix_op-inl.h    | 19 +++++++++++++++++--
 tests/python/unittest/test_operator.py | 20 ++++++++++++++++++++
 2 files changed, 37 insertions(+), 2 deletions(-)

diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 5cd7bf6..cd98cb0 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -344,19 +344,34 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
+  mxnet::TShape& out_shp = (*out_attrs)[0];
   CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  mxnet::TShape ret(shp.ndim(), -1);
+  CHECK_NE(shp.ndim(), 0) << "Number of dimensions cannot be 0";
+  CHECK_NE(out_shp.ndim(), 0) << "Number of dimensions cannot be 0";
+  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+    return false;  // none of the shapes is known
+  if (out_shp.ndim() > 0 && shp.ndim() > 0)
+    CHECK_EQ(out_shp.ndim(), shp.ndim());
+  mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
+  mxnet::TShape ret(std::max(shp.ndim(), out_shp.ndim()), -1);
   if (param.axes.ndim() == 0) {
     for (int i = 0; i < shp.ndim(); ++i) {
       ret[i] = shp[shp.ndim()-1-i];
     }
+    for (int i = 0; i < out_shp.ndim(); ++i) {
+      get[shp.ndim()-1-i] = out_shp[i];
+    }
   } else {
-    CHECK_EQ(shp.ndim(), param.axes.ndim());
+    CHECK_EQ(std::max(shp.ndim(), out_shp.ndim()), param.axes.ndim());
     for (int i = 0; i < shp.ndim(); ++i) {
       CHECK(param.axes[i] < static_cast<int64_t>(shp.ndim()));
       ret[i] = shp[param.axes[i]];
     }
+    for (int i = 0; i < out_shp.ndim(); ++i) {
+      get[param.axes[i]] = out_shp[i];
+    }
   }
+  SHAPE_ASSIGN_CHECK(*in_attrs, 0, get);
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, ret);
   return shape_is_known(ret);
 }
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index e8c9d6c..8f1c253 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -8970,6 +8970,26 @@ def test_get_operator_arguments():
     ok_(operator_arguments.narg == 2)
 
 
+def test_transpose_infer_shape_back():
+    o1 = mx.sym.ones(shape=[2,3])
+    o2 = mx.sym.ones(shape=[-1,-1])
+    t = mx.sym.transpose(o2)
+    b = o1 + t
+    x = b.bind(mx.cpu(), args={})
+    y = x.forward()
+    assert(y[0].shape == (2,3))
+
+
+def test_transpose_infer_shape_mixed():
+    o1 = mx.sym.ones(shape=[2,-1])
+    o2 = mx.sym.ones(shape=[3,-1])
+    t = mx.sym.transpose(o2)
+    b = o1 + t
+    x = b.bind(mx.cpu(), args={})
+    y = x.forward()
+    assert(y[0].shape == (2,3))
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to