TaoLv commented on a change in pull request #16075: Integrate MKL-DNN leakyrelu
URL: https://github.com/apache/incubator-mxnet/pull/16075#discussion_r321306668
 
 

 ##########
 File path: src/operator/leaky_relu.cc
 ##########
 @@ -25,27 +25,123 @@
 */
 
 #include "./leaky_relu-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./nn/mkldnn/mkldnn_base-inl.h"
+#include "./nn/mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_MKLDNN == 1
 
 #include <nnvm/op_attr_types.h>
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(LeakyReLUParam param, int dtype) {
-  Operator* op = nullptr;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new LeakyReLUOp<cpu, DType>(param);
-  });
-  return op;
+
+DMLC_REGISTER_PARAMETER(LeakyReLUParam);
+
+// Type inference for LeakyReLU.
+// Folds every dtype from the inputs, then the outputs, into a single `dtype`
+// via type_assign (its return value is ignored here). It then forces every
+// input and output slot to that dtype with TYPE_ASSIGN_CHECK.
+// -1 means "unknown": the function returns true only once a dtype is known.
+// NOTE(review): presumably type_assign keeps the first known dtype, and
+// TYPE_ASSIGN_CHECK is what rejects mismatching dtypes -- confirm.
+// `attrs` is unused.
+static bool LeakyReLUType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int> *in_type,
+                          std::vector<int> *out_type) {
+  int dtype = -1;
+  for (const int& type : *in_type) {
+    type_assign(&dtype, type);
+  }
+  for (const int& type : *out_type) {
+    type_assign(&dtype, type);
+  }
+  // Propagate the agreed dtype back into every slot.
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+  }
+  for (size_t i = 0; i < out_type->size(); ++i) {
+    TYPE_ASSIGN_CHECK(*out_type, i, dtype);
+  }
+  return dtype != -1;
 }
 
-Operator *LeakyReLUProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector 
*in_shape,
-                                          std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
+// Shape inference for LeakyReLU.
+// Expects inputs [data, gamma] when act_type is kPReLU and [data] otherwise
+// (enforced by CHECK_EQ). Returns false while the data shape's ndim is still
+// unknown; otherwise returns true.
+// For kPReLU, an unknown gamma shape is filled in as the 1-D shape
+// (dshape[1]), i.e. one value per entry of data axis 1 (presumably the
+// channel axis -- confirm).
+// The output list is rebuilt from scratch as [dshape]. A second dshape entry
+// is appended for kRReLU (presumably its mask output -- confirm).
+static bool LeakyReLUShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_shape,
+                           std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const LeakyReLUParam &param_ = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (param_.act_type == leakyrelu::kPReLU) {
+    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]";
+  } else {
+    CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
+  }
+  const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) return false;
+  if (param_.act_type == leakyrelu::kPReLU) {
+    // gshape is a reference into in_shape, so it observes the fill-in below.
+    const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
+    if (!mxnet::ndim_is_known(gshape)) {
+      // NOTE(review): dshape[1] is read without checking that data has at
+      // least 2 dims; a known gamma shape is never compared to dshape[1].
+      in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+    }
+    // NOTE(review): this branch runs only when gamma has the full data shape.
+    // Whatever it writes to (*out_shape)[0] is discarded by the clear() below,
+    // so at most it acts as a consistency check (presumably via
+    // SHAPE_ASSIGN_CHECK) against a caller-supplied out_shape[0]. Confirm
+    // this is intended.
+    if (dshape == gshape) {
+      SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+    }
+  }
+  // Previously inferred output shapes are overwritten here, not checked.
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  if (param_.act_type == leakyrelu::kRReLU) {
+    out_shape->push_back(dshape);
+  }
+  return true;
 }
 
-DMLC_REGISTER_PARAMETER(LeakyReLUParam);
+#if MXNET_USE_MKLDNN == 1
+// Forward FComputeEx entry for LeakyReLU on CPU, built only with MKL-DNN.
+// Expects two inputs ([data, gamma]) for kPReLU and one otherwise (CHECK_EQ).
+// When SupportMKLDNNLeakyRelu(param, inputs[0]) holds, it runs the MKL-DNN
+// kernel on inputs[0] -> outputs[0]. Otherwise it falls back to
+// LeakyReLUCompute<cpu> through FallBackCompute.
+static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);
+  // Only inputs[0], req[0] and outputs[0] reach the MKL-DNN kernel.
+  // SupportMKLDNNLeakyRelu presumably rejects act_types that need gamma or a
+  // second output (e.g. kPReLU, kRReLU) -- confirm.
+  if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
+    // MKLDNN_OPCHECK_INIT/RUN presumably re-run LeakyReLUCompute<cpu> to
+    // cross-check the MKL-DNN result when op checking is enabled -- confirm.
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    MKLDNN_OPCHECK_RUN(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, 
outputs);
+    return;
+  }
+  FallBackCompute(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
+    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+    MKLDNNLeakyReluBackward(attrs, ctx, inputs.at(0), inputs.at(1), req[0],
 
 Review comment:
   Are only two inputs needed here? Could a vector be used instead, so that we
   have a more unified interface?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to