zheng-da commented on a change in pull request #8302: Refactor operators & MKLDNN
URL: https://github.com/apache/incubator-mxnet/pull/8302#discussion_r162432577
 
 

 ##########
 File path: src/operator/nn/pooling.cc
 ##########
 @@ -32,67 +33,304 @@
 #if MXNET_USE_NNPACK == 1
 #include "./nnpack/nnpack_pooling-inl.h"
 #endif  // MXNET_USE_NNPACK
+#if MXNET_USE_MKLDNN == 1
+#include "./mkldnn/mkldnn_pooling-inl.h"
+#endif  // MXNET_USE_MKLDNN
 
 namespace mxnet {
 namespace op {
 
-template<>
-Operator *CreateOp<cpu>(PoolingParam param, int dtype) {
-  Operator *op = NULL;
-#if MXNET_USE_MKL2017 == 1
-    if (param.kernel.ndim() == 2
-      && ((param.pool_type == pool_enum::kMaxPooling)
-      || (param.pool_type == pool_enum::kAvgPooling))) {
-      switch (dtype) {
-      case mshadow::kFloat32:
-        return new MKLPoolingOp<cpu, float>(param);
-      case mshadow::kFloat64:
-        return new MKLPoolingOp<cpu, double>(param);
-      default:
-        break;
+// Parses the Pooling attributes and fills in a default stride (all ones)
+// and pad (all zeros) of the same rank as the kernel when they were not
+// given. Kernels of rank 1, 2 or 3 are accepted; afterwards stride and pad
+// must have the kernel's rank.
+static void PoolingParamParser(nnvm::NodeAttrs *attrs) {
+  using namespace mshadow;
+  PoolingParam parsed;
+  parsed.Init(attrs->dict);
+  const bool default_stride = parsed.stride.ndim() == 0;
+  const bool default_pad = parsed.pad.ndim() == 0;
+  switch (parsed.kernel.ndim()) {
+    case 1:
+      if (default_stride) parsed.stride = Shape1(1);
+      if (default_pad) parsed.pad = Shape1(0);
+      break;
+    case 2:
+      if (default_stride) parsed.stride = Shape2(1, 1);
+      if (default_pad) parsed.pad = Shape2(0, 0);
+      break;
+    default:
+      // Anything other than rank 1 or 2 must be rank 3.
+      CHECK_EQ(parsed.kernel.ndim(), 3U) << parsed.kernel.ndim()
+                                         << "D pooling not supported";
+      if (default_stride) parsed.stride = Shape3(1, 1, 1);
+      if (default_pad) parsed.pad = Shape3(0, 0, 0);
+      break;
+  }
+  CHECK_EQ(parsed.stride.ndim(), parsed.kernel.ndim())
+      << "stride and kernel should have the same length";
+  CHECK_EQ(parsed.pad.ndim(), parsed.kernel.ndim())
+      << "pad and kernel should have the same length";
+  attrs->parsed = std::move(parsed);
+}
+
+// Number of forward outputs: the pooled data, plus a workspace output when
+// MKLDNN both needs a workspace and supports this configuration.
+int GetNumOutputs(const PoolingParam &param) {
+  int num_outputs = 1;
+#if MXNET_USE_MKLDNN == 1
+  if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param))
+    num_outputs = 2;
+#endif
+  return num_outputs;
+}
+
+// Number of backward inputs: 5 when the forward pass carries the extra
+// MKLDNN workspace output (same condition as GetNumOutputs), otherwise 3.
+int GetNumBackInputs(const PoolingParam &param) {
+  int num_inputs = 3;
+#if MXNET_USE_MKLDNN == 1
+  if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param))
+    num_inputs = 5;
+#endif
+  return num_inputs;
+}
+
+// Type inference: the pooled output takes the input's dtype. With MKLDNN,
+// the optional workspace output (index 1) is typed as int32.
+static bool PoolingType(const nnvm::NodeAttrs& attrs,
+                        std::vector<int> *in_attrs,
+                        std::vector<int> *out_attrs) {
+  const int data_dtype = in_attrs->at(0);
+  out_attrs->at(0) = data_dtype;
+#if MXNET_USE_MKLDNN == 1
+  const PoolingParam &pool_param = nnvm::get<PoolingParam>(attrs.parsed);
+  const bool has_workspace =
+      MKLDNNRequireWorkspace(pool_param) && SupportMKLDNNPooling(pool_param);
+  if (has_workspace) {
+    CHECK_GT(out_attrs->size(), 1U);
+    out_attrs->at(1) = mshadow::kInt32;
+  }
+#endif
+  return true;
+}
+
+// Shape inference for Pooling.
+//
+// A 1-D, 2-D or 3-D kernel requires a 3-D, 4-D or 5-D input respectively;
+// the output copies the input shape and only the spatial axes (index 2
+// onward) are recomputed. Returns false while the input shape has
+// ndim() == 0 (not yet inferred). When MKLDNN needs a workspace, a second
+// output with the same shape is appended.
+static bool PoolingShape(const nnvm::NodeAttrs &attrs,
+                         std::vector<TShape> *in_shape,
+                         std::vector<TShape> *out_shape) {
+  const PoolingParam &param_ = nnvm::get<PoolingParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U);
+  const TShape &dshape = (*in_shape)[0];
+  // Must precede the rank CHECK below: a 0-dim (not yet inferred) shape
+  // would otherwise abort instead of deferring inference.
+  if (dshape.ndim() == 0) return false;
+  CHECK_GE(dshape.ndim(), 3U)
+      << "Pooling: Input data should be  3D in (batch, channel, x)"
+      << " Or 4D in (batch, channel, y, x) "
+      << " Or 5D in (batch, channel, d, y, x)";
+  TShape oshape = dshape;
+  if (param_.kernel.ndim() == 1) {
+    CHECK_EQ(dshape.ndim(), 3U)
+        << "Pooling: Input data should be 3D in (batch, channel, x)";
+    if (param_.global_pool) {
+      oshape[2] = 1;
+    } else {
+      CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0])
+          << "kernel size (" << param_.kernel[0] << ") exceeds input ("
+          << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0])
+          << ")";
+      if (param_.pooling_convention == pool_enum::kValid) {
+        oshape[2] = 1 +
+                    (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) /
+                        param_.stride[0];
+      } else {
+        oshape[2] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[2] + 2 * param_.pad[0] -
+                                               param_.kernel[0]) /
+                            param_.stride[0]));
       }
     }
+    out_shape->clear();
+    out_shape->push_back(oshape);  // save output shape
+#if MXNET_USE_MKLDNN == 1
+    if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_))
+      out_shape->push_back(oshape);   // for workspace
 #endif
-#if MXNET_USE_NNPACK == 1
-  // NNPACK only support max-pooling with kernel = 2, stride = 2, pooling_convention
-  // = kFull(note that the default value is kValid in MXNet)
-  if ((param.pool_type == pool_enum::kMaxPooling)
-    && (param.pooling_convention == pool_enum::kFull)
-    && (param.kernel.ndim() == 2) && (param.stride.ndim() == 2)
-    && (param.kernel[0] == 2) && (param.kernel[1] == 2)
-    && (param.stride[0] == 2) && (param.stride[1] == 2)) {
-    switch (dtype) {
-    case mshadow::kFloat32:
-      return new NNPACKPoolingOp<cpu, float>(param);
-    default:
-      break;
+  } else if (param_.kernel.ndim() == 2) {
+    CHECK_EQ(dshape.ndim(), 4U)
+        << "Pooling: Input data should be 4D in (batch, channel, y, x)";
+    if (param_.global_pool) {
+      oshape[2] = 1;
+      oshape[3] = 1;
+    } else {
+      CHECK(param_.kernel[0] <= dshape[2] + 2 * param_.pad[0])
+          << "kernel size (" << param_.kernel[0] << ") exceeds input ("
+          << dshape[2] << " padded to " << (dshape[2] + 2 * param_.pad[0])
+          << ")";
+      CHECK(param_.kernel[1] <= dshape[3] + 2 * param_.pad[1])
+          << "kernel size (" << param_.kernel[1] << ") exceeds input ("
+          << dshape[3] << " padded to " << (dshape[3] + 2 * param_.pad[1])
+          << ")";
+      if (param_.pooling_convention == pool_enum::kValid) {
+        oshape[2] = 1 +
+                    (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) /
+                        param_.stride[0];
+        oshape[3] = 1 +
+                    (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) /
+                        param_.stride[1];
+      } else {
+        oshape[2] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[2] + 2 * param_.pad[0] -
+                                               param_.kernel[0]) /
+                            param_.stride[0]));
+        oshape[3] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[3] + 2 * param_.pad[1] -
+                                               param_.kernel[1]) /
+                            param_.stride[1]));
+      }
     }
-  }
+    out_shape->clear();
+    out_shape->push_back(oshape);  // save output shape
+#if MXNET_USE_MKLDNN == 1
+    if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_))
+      out_shape->push_back(oshape);   // for workspace
 #endif
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    if (pool_enum::kMaxPooling == param.pool_type
-        || pool_enum::kAvgPooling == param.pool_type
-        || pool_enum::kSumPooling == param.pool_type) {
-      op = new PoolingOp<cpu, DType>(param);
+  } else if (param_.kernel.ndim() == 3) {
+    CHECK_EQ(dshape.ndim(), 5U)
+        << "Pooling: Input data should be 5D in (batch, channel, d, y, x)";
+    if (param_.global_pool) {
+      oshape[2] = 1;
+      oshape[3] = 1;
+      oshape[4] = 1;
     } else {
+      // As in the 1-D/2-D branches, the kernel extent is validated only when
+      // it is used; the global-pool branch above never reads the kernel.
+      CHECK_LE(param_.kernel[0], dshape[2] + 2 * param_.pad[0])
+          << "kernel size exceeds input";
+      CHECK_LE(param_.kernel[1], dshape[3] + 2 * param_.pad[1])
+          << "kernel size exceeds input";
+      CHECK_LE(param_.kernel[2], dshape[4] + 2 * param_.pad[2])
+          << "kernel size exceeds input";
-      LOG(FATAL) << "unknown pooling type";
-      return NULL;
+      if (param_.pooling_convention == pool_enum::kValid) {
+        oshape[2] = 1 +
+                    (dshape[2] + 2 * param_.pad[0] - param_.kernel[0]) /
+                        param_.stride[0];
+        oshape[3] = 1 +
+                    (dshape[3] + 2 * param_.pad[1] - param_.kernel[1]) /
+                        param_.stride[1];
+        oshape[4] = 1 +
+                    (dshape[4] + 2 * param_.pad[2] - param_.kernel[2]) /
+                        param_.stride[2];
+      } else {
+        oshape[2] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[2] + 2 * param_.pad[0] -
+                                               param_.kernel[0]) /
+                            param_.stride[0]));
+        oshape[3] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[3] + 2 * param_.pad[1] -
+                                               param_.kernel[1]) /
+                            param_.stride[1]));
+        oshape[4] = 1 + static_cast<int>(ceil(
+                            static_cast<float>(dshape[4] + 2 * param_.pad[2] -
+                                               param_.kernel[2]) /
+                            param_.stride[2]));
+      }
     }
-  });

-  return op;
+    out_shape->clear();
+    out_shape->push_back(oshape);  // save output shape
+#if MXNET_USE_MKLDNN == 1
+    if (MKLDNNRequireWorkspace(param_) && SupportMKLDNNPooling(param_))
+      out_shape->push_back(oshape);   // for workspace
+#endif
+  }
+  return true;
 }
 
-// DO_BIND_DISPATCH comes from operator_common.h
-Operator* PoolingProp::CreateOperatorEx(Context ctx, std::vector<TShape> 
*in_shape,
-                                     std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, (*in_type)[0]);
+#if MXNET_USE_MKLDNN == 1
+// CPU forward entry point used when MKLDNN is enabled. Runs the MKLDNN
+// kernel when both the input array and the pooling configuration are
+// supported; otherwise falls back to PoolingCompute<cpu>.
+void PoolingComputeExCPU(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
+                         const std::vector<NDArray> &inputs,
+                         const std::vector<OpReqType> &req,
+                         const std::vector<NDArray> &outputs) {
+  const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
+  if (SupportMKLDNN(inputs[0])
+      && SupportMKLDNNPooling(param, inputs[0].shape())) {
+    // GetNumOutputs() only adds the workspace output when MKLDNN also
+    // supports the configuration, so look it up only on the MKLDNN path.
+    // Otherwise a config that needs a workspace but is unsupported would
+    // fail the size check instead of falling back.
+    const NDArray *workspace = nullptr;
+    if (MKLDNNRequireWorkspace(param)) {
+      CHECK_GT(outputs.size(), 1U);
+      workspace = &outputs[1];
+    }
+    MKLDNNPoolingCompute(ctx, param, inputs[0], req[0], outputs[0],
+                         workspace);
+    return;
+  }
+  FallBackCompute(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+// CPU backward entry point used when MKLDNN is enabled. The input layout
+// (5 entries with a workspace, 3 without) is decided by GetNumBackInputs().
+// Runs the MKLDNN kernel when supported; otherwise falls back to
+// PoolingGradCompute<cpu>.
+void PoolingGradComputeExCPU(const nnvm::NodeAttrs &attrs,
+                             const OpContext &ctx,
+                             const std::vector<NDArray> &inputs,
+                             const std::vector<OpReqType> &req,
+                             const std::vector<NDArray> &outputs) {
+  const PoolingParam &param = nnvm::get<PoolingParam>(attrs.parsed);
+  const NDArray &out_grad = inputs[0];
+  const NDArray *workspace = nullptr;
+  const NDArray *in_data = nullptr;
+  // Same condition as GetNumBackInputs(). Testing MKLDNNRequireWorkspace()
+  // alone would expect 5 inputs for configurations that were given only 3.
+  if (MKLDNNRequireWorkspace(param) && SupportMKLDNNPooling(param)) {
+    // The first two elements are the gradient of the outputs in forward.
+    // The third is the input of forward.
+    // The fourth and the fifth are the outputs of forward.
+    CHECK_EQ(inputs.size(), 5U);
+    in_data = &inputs[2];
+    workspace = &inputs[4];
+  } else {
+    CHECK_EQ(inputs.size(), 3U);
+    in_data = &inputs[1];
+  }
+  const NDArray &in_grad = outputs[0];
+  if (SupportMKLDNN(inputs[0])
+      && SupportMKLDNNPooling(param, inputs[0].shape())) {
+    MKLDNNPoolingGradCompute(ctx, param, out_grad, *in_data, workspace,
+                             req[0], in_grad);
+    return;
+  }
+  FallBackCompute(PoolingGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+#endif
+
+struct PoolingGrad {
 
 Review comment:
  It seems `PoolingGrad` is not used; I'll remove it.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to