This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 51f650e extend reshape op to allow reverse shape inference (#11956)
51f650e is described below
commit 51f650e0bf3b905fec4aebfc873c1c56eac61536
Author: Sheng Zha <[email protected]>
AuthorDate: Tue Jul 31 16:58:21 2018 -0700
extend reshape op to allow reverse shape inference (#11956)
---
src/c_api/c_api.cc | 2 ++
src/operator/tensor/matrix_op-inl.h | 36 ++++++++++++++++++++++++++++------
tests/python/unittest/test_operator.py | 35 ++++++++++++++++++++++++---------
3 files changed, 58 insertions(+), 15 deletions(-)
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 118af67..ed513c0 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -443,6 +443,8 @@ MXNET_DLL int MXNDArrayReshape64(NDArrayHandle handle,
API_BEGIN();
NDArray *arr = static_cast<NDArray*>(handle);
nnvm::Tuple<dim_t> shape(dims, dims+ndim);
+  CHECK_GT(arr->shape().Size(), 0) << "Source ndarray's shape is undefined. Input shape: "
+                                   << arr->shape();
   TShape new_shape = mxnet::op::InferReshapeShape(shape, arr->shape(), reverse);
*ptr = arr->ReshapeWithRecord(new_shape);
*out = ptr;
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index eec9205..78e1fa1 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -122,7 +122,7 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
CHECK(d1 != -1 || d2 != -1) << "Split dims cannot both be -1.";
if (d1 == -1) d1 = d0 / d2;
if (d2 == -1) d2 = d0 / d1;
-  CHECK_EQ(d1 * d2, static_cast<IType>(d0)) <<
+  CHECK(d1 * d2 == static_cast<IType>(d0) || static_cast<IType>(d0) == IType(0)) <<
     "Split dims " << d1 << ", " << d2 << " do not divide original dim " << d0;
tmp.push_back(d1);
tmp.push_back(d2);
@@ -151,13 +151,36 @@ inline TShape InferReshapeShape(const nnvm::Tuple<IType>& shape,
return oshape;
}
+inline bool ReverseReshapeInferShape(TShape *in, const TShape& out) {
+ if (in->Size() && out.Size()) {
+ return true;
+ } else if (!out.Size()) {
+ return false;
+ } else {
+ int zero_axis = -1;
+ int non_zero_prod = 1;
+ for (index_t i = 0; i < in->ndim(); i++) {
+ if ((*in)[i] == 0) {
+ if (zero_axis != -1)
+ return false; // more than 1 zero found.
+ else
+ zero_axis = i;
+ } else {
+ non_zero_prod *= (*in)[i];
+ }
+ }
+ (*in)[zero_axis] = out.Size() / non_zero_prod;
+ return true;
+ }
+}
+
inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
- std::vector<TShape> *in_attrs,
- std::vector<TShape> *out_attrs) {
+ std::vector<TShape> *in_attrs,
+ std::vector<TShape> *out_attrs) {
const ReshapeParam& param_ = nnvm::get<ReshapeParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U) << "Input: [data]";
CHECK_EQ(out_attrs->size(), 1U);
- const TShape &dshape = (*in_attrs)[0];
+ TShape &dshape = (*in_attrs)[0];
if (dshape.ndim() == 0) return false;
TShape oshape;
if (param_.shape.ndim() != 0) {
@@ -182,14 +205,15 @@ inline bool ReshapeShape(const nnvm::NodeAttrs& attrs,
oshape[inf_idx] = dshape.Size() / oshape.Size();
}
} else {
-    return (*out_attrs)[0].ndim();
+    return (*out_attrs)[0].ndim() && ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}
+ ReverseReshapeInferShape(&dshape, oshape);
CHECK_EQ(oshape.Size(), dshape.Size())
<< "Target shape size is different to source. "
<< "Target: " << oshape
<< "\nSource: " << dshape;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
- return true;
+ return ReverseReshapeInferShape(&(*in_attrs)[0], (*out_attrs)[0]);
}
inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 99d635e..12d0bd1 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1943,11 +1943,11 @@ def test_broadcast_binary_op():
test_bmul(a, b)
test_bdiv(a, b)
'''
- Flaky Test Disabled due to master build failure:
-    http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
+ Flaky Test Disabled due to master build failure:
+    http://jenkins.mxnet-ci.amazon-ml.com/blue/organizations/jenkins/incubator-mxnet/detail/master/1248/pipeline
Github Issue: https://github.com/apache/incubator-mxnet/issues/11838
-
- test_bmod(a, b)
+
+ test_bmod(a, b)
'''
test_bmod_int(a, b)
test_bpow(a, b)
@@ -2065,6 +2065,23 @@ def test_reshape():
         assert np.square(exe.grad_dict['data'].asnumpy() - grad_npy.reshape(src_shape)).mean() < 1E-7, \
             'Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s'\
             %(str(src_shape), str(shape_args), str(reverse), str(dst_shape))
+
+ for i in range(len(src_shape)):
+ holdout_src_shape = list(src_shape)
+ holdout_src_shape[i] = 0
+ holdout_src_shape = tuple(holdout_src_shape)
+ net = mx.sym.Variable('data')
+            net = mx.sym.elemwise_add(net.reshape(shape_args, reverse=reverse), mx.sym.ones(shape=dst_shape))
+            input_shape, output_shape, __ = net.infer_shape(data=holdout_src_shape)
+            assert output_shape[0] == dst_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+            assert input_shape[0] == src_shape, \
+                'Holdout Src Shape = %s, Shape Arguments = %s, Reverse = %s, Dst Shape = %s, ' \
+                'Output Shape = %s' %(str(holdout_src_shape), str(shape_args), str(reverse),
+                                      str(dst_shape), str(output_shape[0]))
+
# Test new api (Using shape)
test_cases = [
[(2, 3, 5, 5), (0, -1), False, (2, 75)],
@@ -6615,7 +6632,7 @@ def test_diag():
w = np.random.randint(2,9)
a_np = np.random.random((h, w)).astype(np.float32)
a = mx.nd.array(a_np).astype('float32')
-
+
# k == 0
r = mx.nd.diag(a)
assert_almost_equal(r.asnumpy(), np.diag(a_np))
@@ -6658,7 +6675,7 @@ def test_diag():
d = np.random.randint(2,9)
a_np = np.random.random((d))
a = mx.nd.array(a_np)
-
+
# k is random
k = np.random.randint(-d,d)
r = mx.nd.diag(a, k=k)
@@ -6725,7 +6742,7 @@ def test_depthtospace():
invalid_shape_inp = (n , c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.depth_to_space, data, block)
-
+
test_invalid_depth_dim()
test_invalid_space_dim()
test_invalid_block_size()
@@ -6771,12 +6788,12 @@ def test_spacetodepth():
invalid_shape_inp = (n, c, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-
+
def test_invalid_depth_dim():
invalid_shape_inp = (n, 0, h, w)
data = rand_ndarray(invalid_shape_inp, 'default')
assertRaises(MXNetError, mx.nd.space_to_depth, data, block)
-
+
test_invalid_space_dim()
test_invalid_block_size()
test_invalid_depth_dim()