This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 6b568fd Backporting backward inference from 2.x #18348 and #18378 (#18895)
6b568fd is described below
commit 6b568fd2c10e8cc14c21d6150ba71933a2af2a41
Author: Serge Panev <[email protected]>
AuthorDate: Thu Aug 13 22:22:10 2020 -0700
Backporting backward inference from 2.x #18348 and #18378 (#18895)
Signed-off-by: Serge Panev <[email protected]>
---
src/operator/contrib/batch_norm_relu.cc | 41 ++++++++++++++++++++++-----------
src/operator/nn/batch_norm.cc | 40 +++++++++++++++++++++-----------
src/operator/nn/convolution.cc | 13 ++++++++---
src/operator/nn/deconvolution.cc | 18 ++++++++++++---
src/operator/nn/group_norm.cc | 8 +++----
src/operator/nn/layer_norm.cc | 7 +++---
src/operator/nn/pooling.cc | 7 ++++--
src/operator/softmax_output.cc | 18 ++++++++++++---
src/operator/tensor/matrix_op-inl.h | 22 ++++++++++++++----
9 files changed, 125 insertions(+), 49 deletions(-)
diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 14452cc..51aa4c5 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 4U);
const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+ if (!mxnet::ndim_is_known(dshape)) {
+ return false;
+ }
const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
const int channelCount = dshape[channelAxis];
- if (!mxnet::ndim_is_known(dshape)) {
- return false;
- }
-
in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
@@ -84,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
- const int dtype = (*in_type)[0];
- CHECK_NE(dtype, -1) << "First input must have specified type";
+ const size_t n_out = 4;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
- MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+ int dtype = (*in_type)[0];
+
+ if (type_is_none(dtype)) {
+ // Input type is undefined, we try backward inference
+ if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+ // Neither the input nor the output are defined,
+ // types cannot be inferred for this op
+ return false;
+ } else {
+ // Input type is undefined but output type is: backward inference
+ dtype = (*out_type)[0];
+ (*in_type)[0] = dtype;
+ MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+ dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+ }
+ } else {
+ // Input type is defined but output type is not: forward inference
+ MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+ out_type->clear();
+ out_type->push_back(dtype);
+ for (size_t i = 1; i < n_out; ++i) {
+ out_type->push_back(dtype_param);
+ }
+ }
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -101,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
- const size_t n_out = 4;
- out_type->clear();
- out_type->push_back(dtype);
- for (size_t i = 1; i < n_out; ++i) {
- out_type->push_back(dtype_param);
- }
return true;
}
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index a59f8ba..60f9553 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -365,6 +365,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
CHECK_EQ(out_shape->size(), 3U);
const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+ if (!mxnet::ndim_is_known(dshape)) {
+ return false;
+ }
const size_t channelAxis = static_cast<size_t>(param.axis < 0
? static_cast<int>(dshape.ndim()) + param.axis
@@ -373,10 +376,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
const int channelCount = dshape[channelAxis];
- if (!mxnet::ndim_is_known(dshape)) {
- return false;
- }
-
in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount)); // kMovingMean
@@ -394,14 +393,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_type, std::vector<int> *out_type) {
using namespace mshadow;
CHECK_GE(in_type->size(), 1U);
- const int dtype = (*in_type)[0];
- CHECK_NE(dtype, -1) << "First input must have specified type";
+ const size_t n_out = 3;
// For float16 input type beta, gamma, mean, and average are stored in float32.
// For other input types, these parameters have the same type as input
// NOTE: This requirement is from cuDNN (v. 4 and 5)
int dtype_param;
- MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+ int dtype = (*in_type)[0];
+ if (type_is_none(dtype)) {
+ // Input type is undefined, we try backward inference
+ if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+ // Neither the input nor the output are defined,
+ // types cannot be inferred for this op
+ return false;
+ } else {
+ // Input type is undefined but output type is: backward inference
+ dtype = (*out_type)[0];
+ (*in_type)[0] = dtype;
+ MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+ dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+ }
+ } else {
+ // Input type is defined but output type is not: forward inference
+ MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+ out_type->clear();
+ out_type->push_back(dtype);
+ for (size_t i = 1; i < n_out; ++i) {
+ out_type->push_back(dtype_param);
+ }
+ }
std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
CHECK_LE(in_type->size(), args.size());
for (size_t i = 1; i < in_type->size(); ++i) {
@@ -411,12 +431,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
}
}
- const size_t n_out = 3;
- out_type->clear();
- out_type->push_back(dtype);
- for (size_t i = 1; i < n_out; ++i) {
- out_type->push_back(dtype_param);
- }
return true;
}
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 8ff5ea7..3ebb67a 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
- CHECK_NE(dtype, -1) << "First input must have specified type";
+ if (type_is_none(dtype)) {
+ if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+ return false;
+ } else {
+ dtype = (*out_type)[0];
+ }
+ } else {
+ out_type->clear();
+ out_type->push_back(dtype);
+ }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
@@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
- out_type->clear();
- out_type->push_back(dtype);
return true;
}
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index cd22ace..08d6306 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -332,7 +332,21 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
CHECK_GE(in_type->size(), 1U);
int dtype = (*in_type)[0];
- CHECK_NE(dtype, -1) << "First input must have specified type";
+ if (type_is_none(dtype)) {
+ // Input type is undefined, we try backward inference
+ if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+ // Neither the input nor the output are defined,
+ // types cannot be inferred for this op
+ return false;
+ } else {
+ // Input type is undefined but output type is: backward inference
+ dtype = (*out_type)[0];
+ }
+ } else {
+ // Input type is defined but output type is not: forward inference
+ out_type->clear();
+ out_type->push_back(dtype);
+ }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
@@ -340,8 +354,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
}
}
- out_type->clear();
- out_type->push_back(dtype);
return true;
}
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index 6b8fe9b..a92ac31 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
- CHECK_GE(dshape.ndim(), 3U);
- const int num_groups = param.num_groups;
- CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
-
if (!mxnet::ndim_is_known(dshape)) {
return false;
}
+ CHECK_GE(dshape.ndim(), 3U);
+ const int num_groups = param.num_groups;
+ CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index d385b93..c3ccd0d 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
using namespace mshadow;
CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+ if (!mxnet::ndim_is_known(dshape)) {
+ return false;
+ }
+
int axis = GetRealAxis(param.axis, dshape.ndim());
CHECK(axis >= 0 && axis < dshape.ndim())
<< "Channel axis out of range: axis=" << param.axis;
const int channelCount = dshape[axis];
- if (!mxnet::ndim_is_known(dshape)) {
- return false;
- }
SHAPE_ASSIGN_CHECK(*in_shape,
layernorm::kGamma,
mxnet::TShape(Shape1(channelCount)));
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 03787f4..c81cae3 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -95,10 +95,14 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
mxnet::ShapeVector *out_shape) {
const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
CHECK_EQ(in_shape->size(), 1U);
+ const mxnet::TShape &dshape = (*in_shape)[0];
+ if (!mxnet::ndim_is_known(dshape)) {
+ return false;
+ }
if (param.pool_type == pool_enum::kLpPooling) {
CHECK(param.p_value.has_value());
}
- const mxnet::TShape &dshape = (*in_shape)[0];
+
if (param.pooling_convention == pool_enum::kSame) {
CHECK_EQ(dshape.ndim(), 3U)
<< "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +118,6 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
<< "Pooling: Input data should be 3D in (batch, channel, x)"
<< " Or 4D in (batch, channel, y, x) "
<< " Or 5D in (batch, channel, d, y, x)";
- if (!mxnet::ndim_is_known(dshape)) return false;
int layout = param.GetLayout(dshape.ndim());
if (param.global_pool) {
mxnet::TShape oshape = dshape;
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 13bb647..d87b781 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -66,7 +66,21 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
std::vector<int> *out_type) {
CHECK_EQ(in_type->size(), 2U);
int dtype = (*in_type)[0];
- CHECK_NE(dtype, -1) << "First input must have specified type";
+ if (type_is_none(dtype)) {
+ // Input type is undefined, we try backward inference
+ if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+ // Neither the input nor the output are defined,
+ // types cannot be inferred for this op
+ return false;
+ } else {
+ // Input type is undefined but output type is: backward inference
+ dtype = (*out_type)[0];
+ }
+ } else {
+ // Input type is defined but output type is not: forward inference
+ out_type->clear();
+ out_type->push_back(dtype);
+ }
for (size_t i = 0; i < in_type->size(); ++i) {
if ((*in_type)[i] == -1) {
(*in_type)[i] = dtype;
@@ -74,8 +88,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
}
}
- out_type->clear();
- out_type->push_back(dtype);
return true;
}
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index fa7b8a1..217bf10 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
mxnet::TShape& out_shp = (*out_attrs)[0];
- CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
- if (shp.ndim() == -1 && out_shp.ndim() == -1)
+ if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
return false; // none of the shapes is known
+ CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
CHECK_EQ(out_shp.ndim(), shp.ndim());
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
@@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
- if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+ mxnet::TShape& ishape = (*in_attrs)[0];
+ mxnet::TShape& oshape = (*out_attrs)[0];
+ if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
return false;
}
- mxnet::TShape& ishape = (*in_attrs)[0];
- mxnet::TShape& oshape = (*out_attrs)[0];
int indim = ishape.ndim();
bool unknown_ishape = false;
if (-1 == indim) {
@@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& ishape = (*in_attrs)[0];
mxnet::TShape& from_shape = (*in_attrs)[1];
+ if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+ return false;
+ }
if (param.axes.ndim() == 0) {
CHECK_EQ(ishape.ndim(), from_shape.ndim())
<< "By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
const mxnet::TShape& ishape = (*in_attrs)[0];
+ if (!mxnet::ndim_is_known(ishape)) {
+ return false;
+ }
int repeats = 0;
dmlc::optional<int> axisOpt;
GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(4, -1);
mxnet::TShape& in_shape = in_attrs->at(0);
+ if (!mxnet::ndim_is_known(in_shape)) {
+ return false;
+ }
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
mxnet::TShape& in_shape = in_attrs->at(0);
+ if (!mxnet::ndim_is_known(in_shape)) {
+ return false;
+ }
int block = param.block_size;
CHECK_NE(block, 0) << "block_size must be a positive integer value";
CHECK_NE(in_shape[0], 0)