eric-haibin-lin closed pull request #11956: extend reshape op to allow reverse shape inference
URL: https://github.com/apache/incubator-mxnet/pull/11956
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 118af679315..ed513c0d778 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -443,6 +443,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
   API_BEGIN();
   NDArray *arr = static_cast<NDArray*>(handle);
   nnvm::Tuple<dim_t> shape(dims, dims+ndim);
+  CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
+    << arr->shape();
   TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
   *ptr = arr->ReshapeWithRecord(new_shape);
   *out = ptr;
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index eec920555ed..78e1fa1d9c6 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -122,7 +122,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
       CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
       if (d1 == -1) d1 = d0 / d2;
       if (d2 == -1) d2 = d0 / d1;
-      CHECK_EQ(d1 * d2, static_cast<IType>(d0)) <<
+      CHECK(d1 * d2 == static_cast<IType>(d0) || static_cast<IType>(d0) == IType(0)) <<
         "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
       tmp.push_back(d1);
       tmp.push_back(d2);
@@ -151,13 +151,36 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
   return oshape;
 }
 
+inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) {
+  if (in->Size() && out.Size()) {
+    return true;
+  } else if (!out.Size()) {
+    return false;
+  } else {
+    int zero_axis = -1;
+    int non_zero_prod = 1;
+    for (index_t i = 0; i < in->ndim(); i++) {
+      if ((*in)[i] == 0) {
+        if (zero_axis != -1)
+          return false;  // more than 1 zero found.
+        else
+          zero_axis = i;
+      } else {
+        non_zero_prod *= (*in)[i];
+      }
+    }
+    (*in)[zero_axis] = out.Size() / non_zero_prod;
+    return true;
+  }
+}
+
 inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
-                             std::vector<TShape> *in_attrs,
-                             std::vector<TShape> *out_attrs) {
+                         std::vector<TShape> *in_attrs,
+                         std::vector<TShape> *out_attrs) {
   const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
   CHECK_EQ(out_attrs->size(), 1U);
-  const TShape &dshape = (*in_attrs)[0];
+  TShape &dshape = (*in_attrs)[0];
   if (dshape.ndim() == 0) return false;
   TShape oshape;
   if (param_.shape.ndim() != 0) {
@@ -182,14 +205,15 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
       oshape[inf_idx] = dshape.Size() / oshape.Size();
     }
   } else {
-    return (*out_attrs)[0].ndim();
+    return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
   }
+  ReverseReshapeInferShape(&dshape, oshape);
   CHECK_EQ(oshape.Size(), dshape.Size())
     << "Target shape size is different to source. "
     << "Target: " << oshape
     << "\nSource: " << dshape;
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
-  return true;
+  return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
 }
 
 inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index fa5de0c68c7..938cffefab5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1943,11 +1943,11 @@ def test_bxor(a, b):
     test_bmul(a, b)
     test_bdiv(a, b)
     '''
-    Flaky Test Disabled due to master build failure: 
-    http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
 
+    Flaky Test Disabled due to master build failure:
+    http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
     Github Issue: https://github.com/apache/incubator-mxnet/issues/11838
-    
-    test_bmod(a, b) 
+
+    test_bmod(a, b)
     '''
     test_bmod_int(a, b)
     test_bpow(a, b)
@@ -2065,6 +2065,23 @@ def test_reshape_new(src_shape, shape_args, reverse, dst_shape):
         assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7, \
             'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s'\
             %(str(src_shape), str(shape_args), str(reverse), str(dst_shape))
+
+        for i in range(len(src_shape)):
+            holdout_src_shape = list(src_shape)
+            holdout_src_shape[i] = 0
+            holdout_src_shape = tuple(holdout_src_shape)
+            net = mx.sym.Variable('data')
+            net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
+            input_shape, output_shape, __ = net.infer_shape(data=holdout_src_shape)
+            assert output_shape[0] == dst_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+            assert input_shape[0] == src_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+
     # Test new api (Using shape)
     test_cases = [
         [(2, 3, 5, 5),  (0, -1),          False, (2, 75)],
@@ -6614,7 +6631,7 @@ def test_diag():
     w = np.random.randint(2,9)
     a_np = np.random.random((h, w)).astype(np.float32)
     a = mx.nd.array(a_np).astype('float32')
-    
+
     # k == 0
     r = mx.nd.diag(a)
     assert_almost_equal(r.asnumpy(), np.diag(a_np))
@@ -6657,7 +6674,7 @@ def test_diag():
     d = np.random.randint(2,9)
     a_np = np.random.random((d))
     a = mx.nd.array(a_np)
-    
+
     # k is random
     k = np.random.randint(-d,d)
     r = mx.nd.diag(a, k=k)
@@ -6724,7 +6741,7 @@ def test_invalid_block_size():
         invalid_shape_inp = (n , c, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.depth_to_space, data, block)
-        
+
     test_invalid_depth_dim()
     test_invalid_space_dim()
     test_invalid_block_size()
@@ -6770,12 +6787,12 @@ def test_invalid_block_size():
         invalid_shape_inp = (n, c, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-    
+
     def test_invalid_depth_dim():
         invalid_shape_inp = (n, 0, h, w)
         data = rand_ndarray(invalid_shape_inp, 'default')
         assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-    
+
     test_invalid_space_dim()
     test_invalid_block_size()
     test_invalid_depth_dim()
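
For illustration only (not part of the diff above), a minimal Python sketch of the behaviour the new unit test covers, using the (2, 3, 5, 5) -> (2, 75) case from test_reshape_new: a dimension left as 0 (unknown) in the data shape can now be recovered from the reshape target during symbolic shape inference.

    import mxnet as mx

    # Data shape has one unknown dimension (0); the target reshape and the
    # elemwise_add against a fixed-shape ones() pin the output shape, and
    # reverse shape inference fills the unknown input dimension back in.
    net = mx.sym.Variable('data')
    net = mx.sym.elemwise_add(net.reshape((2, 75)), mx.sym.ones(shape=(2, 75)))
    input_shape, output_shape, _ = net.infer_shape(data=(0, 3, 5, 5))
    assert input_shape[0] == (2, 3, 5, 5)   # unknown dim recovered as 2
    assert output_shape[0] == (2, 75)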


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services