This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 6b568fd  Backporting backward inference from 2.x #18348 and #18378 (#18895)
6b568fd is described below

commit 6b568fd2c10e8cc14c21d6150ba71933a2af2a41
Author: Serge Panev <[email protected]>
AuthorDate: Thu Aug 13 22:22:10 2020 -0700

    Backporting backward inference from 2.x #18348 and #18378 (#18895)
    
    Signed-off-by: Serge Panev <[email protected]>
---
 src/operator/contrib/batch_norm_relu.cc | 41 ++++++++++++++++++++++-----------
 src/operator/nn/batch_norm.cc           | 40 +++++++++++++++++++++-----------
 src/operator/nn/convolution.cc          | 13 ++++++++---
 src/operator/nn/deconvolution.cc        | 18 ++++++++++++---
 src/operator/nn/group_norm.cc           |  8 +++----
 src/operator/nn/layer_norm.cc           |  7 +++---
 src/operator/nn/pooling.cc              |  7 ++++--
 src/operator/softmax_output.cc          | 18 ++++++++++++---
 src/operator/tensor/matrix_op-inl.h     | 22 ++++++++++++++----
 9 files changed, 125 insertions(+), 49 deletions(-)
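
Context for reviewers (not part of the patch): type and shape inference
over the graph is iterative. An operator's inference function returns
true once everything it is responsible for is resolved, and false to ask
the pass to retry after neighboring nodes have filled in more
information. That is why this commit replaces the hard
CHECK_NE(dtype, -1) failures with "return false". A minimal standalone
sketch of that contract, using simplified stand-ins rather than MXNet's
actual graph machinery:

    #include <cstddef>
    #include <functional>
    #include <vector>

    // An inference function resolves what it can; returning false means
    // "retry me later", not "error" -- the contract this commit adopts.
    using InferFn = std::function<bool(std::vector<int>*, std::vector<int>*)>;

    bool RunInferencePass(std::vector<InferFn>* nodes,
                          std::vector<std::vector<int>>* ins,
                          std::vector<std::vector<int>>* outs) {
      std::vector<bool> done(nodes->size(), false);
      bool progress = true;
      while (progress) {
        progress = false;
        for (std::size_t i = 0; i < nodes->size(); ++i) {
          if (!done[i] && (*nodes)[i](&(*ins)[i], &(*outs)[i])) {
            done[i] = true;
            progress = true;  // a resolved node may unblock its neighbors
          }
        }
      }
      for (bool d : done)
        if (!d) return false;  // some node could never be resolved
      return true;
    }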

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 14452cc..51aa4c5 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -55,6 +55,9 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 4U);
   const mxnet::TShape &dshape = in_shape->at(batchnormrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
 
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
@@ -63,10 +66,6 @@ static bool BatchNormWithReLUShape(const nnvm::NodeAttrs& attrs,
 
   const int channelCount = dshape[channelAxis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnormrelu::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnormrelu::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
@@ -84,14 +83,36 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
                                    std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 4;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];
+
+  if (type_is_none(dtype)) {
+    // Input type is undefined; try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output is defined;
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
       dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -101,12 +122,6 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 4;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }
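
A note on dtype_param in the hunk above: when the data type is float16,
gamma, beta and the moving statistics are stored in float32 (a cuDNN
requirement, per the code comment); for other real types they share the
data's type. Conceptually, the MSHADOW_REAL_TYPE_SWITCH_EX macro computes
the mapping below (standalone sketch; the flag values follow mshadow's
conventions but are restated here purely as an illustration):

    // Illustration of the data-type -> parameter-type mapping performed
    // by MSHADOW_REAL_TYPE_SWITCH_EX for BatchNorm-style operators.
    enum DTypeFlag { kFloat32 = 0, kFloat64 = 1, kFloat16 = 2 };

    int ParamDType(int data_dtype) {
      switch (data_dtype) {
        case kFloat16:
          return kFloat32;    // fp16 data keeps its statistics in fp32
        case kFloat32:
        case kFloat64:
          return data_dtype;  // other real types keep their own precision
        default:
          return -1;          // undefined / unsupported
      }
    }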
 
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index a59f8ba..60f9553 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -365,6 +365,9 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_shape->size(), 5U) << "Input:[data, gamma, beta, MovingMean, MovingVar]";
   CHECK_EQ(out_shape->size(), 3U);
   const mxnet::TShape &dshape = in_shape->at(batchnorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
 
   const size_t channelAxis = static_cast<size_t>(param.axis < 0
       ? static_cast<int>(dshape.ndim()) + param.axis
@@ -373,10 +376,6 @@ static bool BatchNormShape(const nnvm::NodeAttrs& attrs,
 
   const int channelCount = dshape[channelAxis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
-
   in_shape->at(batchnorm::kGamma) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kBeta) = mxnet::TShape(Shape1(channelCount));
   in_shape->at(batchnorm::kInMovingMean) = mxnet::TShape(Shape1(channelCount));  // kMovingMean
@@ -394,14 +393,35 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_type, std::vector<int> *out_type) {
   using namespace mshadow;
   CHECK_GE(in_type->size(), 1U);
-  const int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  const size_t n_out = 3;
   // For float16 input type beta, gamma, mean, and average are stored in float32.
   // For other input types, these parameters have the same type as input
   // NOTE: This requirement is from cuDNN (v. 4 and 5)
   int dtype_param;
-  MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+  int dtype = (*in_type)[0];
+  if (type_is_none(dtype)) {
+    // Input type is undefined; try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output is defined;
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+      (*in_type)[0] = dtype;
+      MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
+        dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    MSHADOW_REAL_TYPE_SWITCH_EX(dtype, DTypeX, AccRealX, {
       dtype_param = mshadow::DataType<AccRealX>::kFlag; });
+    out_type->clear();
+    out_type->push_back(dtype);
+    for (size_t i = 1; i < n_out; ++i) {
+      out_type->push_back(dtype_param);
+    }
+  }
   std::vector<std::string> args{"data", "gamma", "beta", "mean", "var"};
   CHECK_LE(in_type->size(), args.size());
   for (size_t i = 1; i < in_type->size(); ++i) {
@@ -411,12 +431,6 @@ static bool BatchNormType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype_param, args[i]);
     }
   }
-  const size_t n_out = 3;
-  out_type->clear();
-  out_type->push_back(dtype);
-  for (size_t i = 1; i < n_out; ++i) {
-    out_type->push_back(dtype_param);
-  }
   return true;
 }
 
diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 8ff5ea7..3ebb67a 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -285,7 +285,16 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
   const ConvolutionParam& param_ = nnvm::get<ConvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      return false;
+    } else {
+      dtype = (*out_type)[0];
+    }
+  } else {
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -293,8 +302,6 @@ static bool ConvolutionType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
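
The undefined-input handling above recurs almost verbatim in
DeconvolutionType and SoftmaxOutputType below. Distilled, the shared
control flow looks like this helper (no such helper exists in the patch;
this is only an illustration of the pattern):

    #include <vector>

    inline bool type_is_none(int t) { return t == -1; }

    // Sketch of the forward/backward dtype resolution shared by the
    // operators touched in this commit: defer (return false) when
    // neither side is known, otherwise resolve and mirror the type.
    bool ResolveDType(std::vector<int>* in_type, std::vector<int>* out_type,
                      int* dtype) {
      *dtype = (*in_type)[0];
      if (type_is_none(*dtype)) {
        if (out_type->empty() || type_is_none((*out_type)[0]))
          return false;            // nothing known yet: retry later
        *dtype = (*out_type)[0];   // backward: adopt the output's type
      } else {
        out_type->clear();         // forward: propagate input type out
        out_type->push_back(*dtype);
      }
      return true;
    }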
 
diff --git a/src/operator/nn/deconvolution.cc b/src/operator/nn/deconvolution.cc
index cd22ace..08d6306 100644
--- a/src/operator/nn/deconvolution.cc
+++ b/src/operator/nn/deconvolution.cc
@@ -332,7 +332,21 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
   const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed);
   CHECK_GE(in_type->size(), 1U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined; try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output is defined;
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -340,8 +354,6 @@ static bool DeconvolutionType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments(param_)[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
 
diff --git a/src/operator/nn/group_norm.cc b/src/operator/nn/group_norm.cc
index 6b8fe9b..a92ac31 100644
--- a/src/operator/nn/group_norm.cc
+++ b/src/operator/nn/group_norm.cc
@@ -39,14 +39,14 @@ static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
-  CHECK_GE(dshape.ndim(), 3U);
-  const int num_groups = param.num_groups;
-  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
-
   if (!mxnet::ndim_is_known(dshape)) {
     return false;
   }
 
+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # of groups";
+
   in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
   in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
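
The group_norm hunk above, like the batch-norm and layer-norm shape
fixes, is a pure reordering: the ndim_is_known early return used to sit
after dshape had already been indexed, so a shape of unknown rank was
dereferenced before the guard could fire. A standalone sketch of the
hazard, with a simplified TShape stand-in:

    #include <vector>

    // Simplified stand-in for mxnet::TShape, for illustration only.
    struct TShapeLite {
      int ndim = -1;                 // -1 means the rank is not known yet
      std::vector<int> dims;
      int operator[](int i) const { return dims[i]; }  // unsafe if ndim < 0
    };

    inline bool ndim_is_known(const TShapeLite& s) { return s.ndim >= 0; }

    bool InferShapeFixed(const TShapeLite& dshape) {
      if (!ndim_is_known(dshape))
        return false;                  // guard first: defer until known
      const int channels = dshape[1];  // only now is indexing safe
      return channels > 0;
    }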
 
diff --git a/src/operator/nn/layer_norm.cc b/src/operator/nn/layer_norm.cc
index d385b93..c3ccd0d 100644
--- a/src/operator/nn/layer_norm.cc
+++ b/src/operator/nn/layer_norm.cc
@@ -43,15 +43,16 @@ static bool LayerNormShape(const nnvm::NodeAttrs& attrs,
   using namespace mshadow;
   CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
   const mxnet::TShape &dshape = in_shape->at(layernorm::kData);
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+
   int axis = GetRealAxis(param.axis, dshape.ndim());
   CHECK(axis >= 0 && axis < dshape.ndim())
     << "Channel axis out of range: axis=" << param.axis;
 
   const int channelCount = dshape[axis];
 
-  if (!mxnet::ndim_is_known(dshape)) {
-    return false;
-  }
   SHAPE_ASSIGN_CHECK(*in_shape,
                      layernorm::kGamma,
                      mxnet::TShape(Shape1(channelCount)));
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 03787f4..c81cae3 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -95,10 +95,14 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
                          mxnet::ShapeVector *out_shape) {
   const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
   CHECK_EQ(in_shape->size(), 1U);
+  const mxnet::TShape &dshape = (*in_shape)[0];
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
   if (param.pool_type == pool_enum::kLpPooling) {
     CHECK(param.p_value.has_value());
   }
-  const mxnet::TShape &dshape = (*in_shape)[0];
+
   if (param.pooling_convention == pool_enum::kSame) {
     CHECK_EQ(dshape.ndim(), 3U)
       << "Pooling: Input data should be 3D in (batch, channel, x)"
@@ -114,7 +118,6 @@ static bool PoolingShape(const nnvm::NodeAttrs &attrs,
       << "Pooling: Input data should be  3D in (batch, channel, x)"
       << " Or 4D in (batch, channel, y, x) "
       << " Or 5D in (batch, channel, d, y, x)";
-  if (!mxnet::ndim_is_known(dshape)) return false;
   int layout = param.GetLayout(dshape.ndim());
   if (param.global_pool) {
     mxnet::TShape oshape = dshape;
diff --git a/src/operator/softmax_output.cc b/src/operator/softmax_output.cc
index 13bb647..d87b781 100644
--- a/src/operator/softmax_output.cc
+++ b/src/operator/softmax_output.cc
@@ -66,7 +66,21 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *out_type) {
   CHECK_EQ(in_type->size(), 2U);
   int dtype = (*in_type)[0];
-  CHECK_NE(dtype, -1) << "First input must have specified type";
+  if (type_is_none(dtype)) {
+    // Input type is undefined; try backward inference
+    if (out_type->size() == 0 || type_is_none((*out_type)[0])) {
+      // Neither the input nor the output is defined;
+      // types cannot be inferred for this op
+      return false;
+    } else {
+      // Input type is undefined but output type is: backward inference
+      dtype = (*out_type)[0];
+    }
+  } else {
+    // Input type is defined but output type is not: forward inference
+    out_type->clear();
+    out_type->push_back(dtype);
+  }
   for (size_t i = 0; i < in_type->size(); ++i) {
     if ((*in_type)[i] == -1) {
       (*in_type)[i] = dtype;
@@ -74,8 +88,6 @@ static bool SoftmaxOutputType(const nnvm::NodeAttrs& attrs,
       UNIFORM_TYPE_CHECK((*in_type)[i], dtype, ListArguments()[i]);
     }
   }
-  out_type->clear();
-  out_type->push_back(dtype);
   return true;
 }
 
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index fa7b8a1..217bf10 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -455,9 +455,9 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& shp = (*in_attrs)[0];
   mxnet::TShape& out_shp = (*out_attrs)[0];
-  CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
-  if (shp.ndim() == -1 && out_shp.ndim() == -1)
+  if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
     return false;  // none of the shapes is known
+  CHECK_LE(shp.ndim(), 6) << "Transpose supports at most 6 dimensions";
   if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
     CHECK_EQ(out_shp.ndim(), shp.ndim());
   mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
@@ -506,12 +506,12 @@ inline bool ExpandDimShape(const nnvm::NodeAttrs& attrs,
   const ExpandDimParam& param = nnvm::get<ExpandDimParam>(attrs.parsed);
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
-  if (!mxnet::ndim_is_known(in_attrs->at(0)) && !mxnet::ndim_is_known(out_attrs->at(0))) {
+  mxnet::TShape& ishape = (*in_attrs)[0];
+  mxnet::TShape& oshape = (*out_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape) && !mxnet::ndim_is_known(oshape)) {
     return false;
   }
 
-  mxnet::TShape& ishape = (*in_attrs)[0];
-  mxnet::TShape& oshape = (*out_attrs)[0];
   int indim = ishape.ndim();
   bool unknown_ishape = false;
   if (-1 == indim) {
@@ -1434,6 +1434,9 @@ inline bool SliceLikeShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape& ishape = (*in_attrs)[0];
   mxnet::TShape& from_shape = (*in_attrs)[1];
+  if (!mxnet::ndim_is_known(ishape) || !mxnet::ndim_is_known(from_shape)) {
+    return false;
+  }
   if (param.axes.ndim() == 0) {
     CHECK_EQ(ishape.ndim(), from_shape.ndim())
       << "By default slice_axis performs slice on all axes, but ndim mismatch "
@@ -1727,6 +1730,9 @@ inline bool RepeatOpShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 1U);
   const mxnet::TShape& ishape = (*in_attrs)[0];
+  if (!mxnet::ndim_is_known(ishape)) {
+    return false;
+  }
   int repeats = 0;
   dmlc::optional<int> axisOpt;
   GetRepeatParams(param, ishape, &repeats, &axisOpt);
@@ -2395,6 +2401,9 @@ inline bool DepthToSpaceOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(4, -1);
 
   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[1], 0) << "Depth dimension:1 cannot be 0";
@@ -2559,6 +2568,9 @@ inline bool SpaceToDepthOpShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape expected_out(in_attrs->at(0).ndim(), -1);
 
   mxnet::TShape& in_shape = in_attrs->at(0);
+  if (!mxnet::ndim_is_known(in_shape)) {
+    return false;
+  }
   int block = param.block_size;
   CHECK_NE(block, 0) << "block_size must be a positive integer value";
   CHECK_NE(in_shape[0], 0)
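
Taken together, the new guards and backward branches let a graph whose
only known dtype sits at the very end still resolve fully after a few
sweeps. A standalone end-to-end sketch (illustrative types and a
hand-rolled two-node "graph", not MXNet's actual pass):

    #include <cassert>
    #include <vector>

    namespace {
    // Re-statement of the per-node resolution step for this sketch.
    bool Resolve(std::vector<int>* in, std::vector<int>* out) {
      int dtype = (*in)[0];
      if (dtype == -1) {
        if (out->empty() || (*out)[0] == -1) return false;  // defer
        dtype = (*out)[0];             // backward inference from output
      }
      (*in)[0] = dtype;
      out->assign(1, dtype);
      return true;
    }
    }  // namespace

    int main() {
      std::vector<int> a_in{-1}, a_out{-1};  // node A: nothing known yet
      std::vector<int> b_in{-1}, b_out{0};   // node B: output pinned to 0
      // Sweep 1: A must defer; B resolves backward from its output.
      assert(!Resolve(&a_in, &a_out));
      assert(Resolve(&b_in, &b_out));
      a_out[0] = b_in[0];  // the pass copies B's input type to A's output
      // Sweep 2: A now resolves backward as well.
      assert(Resolve(&a_in, &a_out));
      assert(a_in[0] == 0);
      return 0;
    }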
