This is an automated email from the ASF dual-hosted git repository.

haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 51f650e  extend reshape op to allow reverse shape inference (#11956)
51f650e is described below

commit 51f650e0bf3b905fec4aebfc873c1c56eac61536
Author: Sheng Zha <[email protected]>
AuthorDate: Tue Jul 31 16:58:21 2018 -0700

    extend reshape op to allow reverse shape inference (#11956)
---
 src/c_api/c_api.cc                     |  2 ++
 src/operator/tensor/matrix_op-inl.h    | 36 ++++++++++++++++++++++++++++------
 tests/python/unittest/test_operator.py | 35 ++++++++++++++++++++++++---------
 3 files changed, 58 insertions(+), 15 deletions(-)

diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 118af67..ed513c0 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -443,6 +443,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   nnvm::Tuple<dim_t> shape(dims, dims+ndim);
+  CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
+    << arr->shape();
   TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
   *ptr = arr->ReshapeWithRecord(new_shape);
   *out = ptr;
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index eec9205..78e1fa1 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -122,7 +122,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
       CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
       if (d1 == -1) d1 = d0 / d2;
       if (d2 == -1) d2 = d0 / d1;
-      CHECK_EQ(d1 * d2, static_cast<IType>(d0)) <<
+      CHECK(d1 * d2 == static_cast<IType>(d0) || static_cast<IType>(d0) == IType(0)) <<
         "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
       tmp.push_back(d1);
       tmp.push_back(d2);
@@ -151,13 +151,36 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
   return oshape;
 }
 
+inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) {
+  if (in->Size() && out.Size()) {
+    return true;
+  } else if (!out.Size()) {
+    return false;
+  } else {
+    int zero_axis = -1;
+    int non_zero_prod = 1;
+    for (index_t i = 0; i < in->ndim(); i++) {
+      if ((*in)[i] == 0) {
+        if (zero_axis != -1)
+          return false;  // more than 1 zero found.
+        else
+          zero_axis = i;
+      } else {
+        non_zero_prod *= (*in)[i];
+      }
+    }
+    (*in)[zero_axis] = out.Size() / non_zero_prod;
+    return true;
+  }
+}
+
 inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
-                             std::vector<TShape> *in_attrs,
-                             std::vector<TShape> *out_attrs) {
+                         std::vector<TShape> *in_attrs,
+                         std::vector<TShape> *out_attrs) {
   const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
   CHECK_EQ(out_attrs->size(), 1U);
-  const TShape &dshape = (*in_attrs)[0];
+  TShape &dshape = (*in_attrs)[0];
   if (dshape.ndim() == 0) return false;
   TShape oshape;
   if (param_.shape.ndim() != 0) {
@@ -182,14 +205,15 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
       oshape[inf_idx] = dshape.Size() / oshape.Size();
     }
   } else {
-    return (*out_attrs)[0].ndim();
+    return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
   }
+  ReverseReshapeInferShape(&dshape, oshape);
   CHECK_EQ(oshape.Size(), dshape.Size())
     << "Target shape size is different to source. "
     << "Target: " << oshape
     << "\nSource: " << dshape;
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
-  return true;
+  return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
 }
 
 inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 99d635e..12d0bd1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1943,11 +1943,11 @@ def test_broadcast_binary_op():
     test_bmul(a, b)
     test_bdiv(a, b)
     '''
-    Flaky Test Disabled due to master build failure: 
-    
http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
 
+    Flaky Test Disabled due to master build failure:
+    http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
     Github Issue: https://github.com/apache/incubator-mxnet/issues/11838
-    
-    test_bmod(a, b) 
+
+    test_bmod(a, b)
     '''
     test_bmod_int(a, b)
     test_bpow(a, b)
@@ -2065,6 +2065,23 @@ def test_reshape():
         assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7, \
             'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s'\
             %(str(src_shape), str(shape_args), str(reverse), str(dst_shape))
+
+        for i in range(len(src_shape)):
+            holdout_src_shape = list(src_shape)
+            holdout_src_shape[i] = 0
+            holdout_src_shape = tuple(holdout_src_shape)
+            net = mx.sym.Variable('data')
+            net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
+            input_shape, output_shape, __ = net.infer_shape(data=holdout_src_shape)
+            assert output_shape[0] == dst_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+            assert input_shape[0] == src_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+
     # Test new api (Using shape)
     test_cases = [
         [(2, 3, 5, 5),  (0, -1),          False, (2, 75)],
@@ -6615,7 +6632,7 @@ def test_diag():
     w = np.random.randint(2,9)
     a_np = np.random.random((h, w)).astype(np.float32)
     a = mx.nd.array(a_np).astype('float32')
-    
+
     # k == 0
     r = mx.nd.diag(a)
     assert_almost_equal(r.asnumpy(), np.diag(a_np))
@@ -6658,7 +6675,7 @@ def test_diag():
     d = np.random.randint(2,9)
     a_np = np.random.random((d))
     a = mx.nd.array(a_np)
-    
+
     # k is random
     k = np.random.randint(-d,d)
     r = mx.nd.diag(a, k=k)
@@ -6725,7 +6742,7 @@ def test_depthtospace():
         invalid_shape_inp = (n , c, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.depth_to_space, data, block)
-        
+
     test_invalid_depth_dim()
     test_invalid_space_dim()
     test_invalid_block_size()
@@ -6771,12 +6788,12 @@ def test_spacetodepth():
         invalid_shape_inp = (n, c, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-    
+
     def test_invalid_depth_dim():
         invalid_shape_inp = (n, 0, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-    
+
     test_invalid_space_dim()
     test_invalid_block_size()
     test_invalid_depth_dim()

Reply via email to