TaoLv commented on a change in pull request #16075: Integrate MKL-DNN leakyrelu
URL: https://github.com/apache/incubator-mxnet/pull/16075#discussion_r321307114
 
 

 ##########
 File path: src/operator/leaky_relu.cc
 ##########
 @@ -25,27 +25,123 @@
 */
 
 #include "./leaky_relu-inl.h"
+#if MXNET_USE_MKLDNN == 1
+#include "./nn/mkldnn/mkldnn_base-inl.h"
+#include "./nn/mkldnn/mkldnn_ops-inl.h"
+#endif  // MXNET_USE_MKLDNN == 1
 
 #include <nnvm/op_attr_types.h>
 namespace mxnet {
 namespace op {
-template<>
-Operator *CreateOp<cpu>(LeakyReLUParam param, int dtype) {
-  Operator* op = nullptr;
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    op = new LeakyReLUOp<cpu, DType>(param);
-  });
-  return op;
+
+DMLC_REGISTER_PARAMETER(LeakyReLUParam);
+
+static bool LeakyReLUType(const nnvm::NodeAttrs& attrs,
+                          std::vector<int> *in_type,
+                          std::vector<int> *out_type) {
+  int dtype = -1;
+  for (const int& type : *in_type) {
+    type_assign(&dtype, type);
+  }
+  for (const int& type : *out_type) {
+    type_assign(&dtype, type);
+  }
+  for (size_t i = 0; i < in_type->size(); ++i) {
+    TYPE_ASSIGN_CHECK(*in_type, i, dtype);
+  }
+  for (size_t i = 0; i < out_type->size(); ++i) {
+    TYPE_ASSIGN_CHECK(*out_type, i, dtype);
+  }
+  return dtype != -1;
 }
 
-Operator *LeakyReLUProp::CreateOperatorEx(Context ctx, mxnet::ShapeVector *in_shape,
-                                          std::vector<int> *in_type) const {
-  DO_BIND_DISPATCH(CreateOp, param_, in_type->at(0));
+static bool LeakyReLUShape(const nnvm::NodeAttrs& attrs,
+                           std::vector<TShape> *in_shape,
+                           std::vector<TShape> *out_shape) {
+  using namespace mshadow;
+  const LeakyReLUParam &param_ = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (param_.act_type == leakyrelu::kPReLU) {
+    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, gamma]";
+  } else {
+    CHECK_EQ(in_shape->size(), 1U) << "Input:[data]";
+  }
+  const mxnet::TShape &dshape = in_shape->at(leakyrelu::kData);
+  if (!mxnet::ndim_is_known(dshape)) return false;
+  if (param_.act_type == leakyrelu::kPReLU) {
+    const mxnet::TShape &gshape = in_shape->at(leakyrelu::kGamma);
+    if (!mxnet::ndim_is_known(gshape)) {
+      in_shape->at(leakyrelu::kGamma) = mxnet::TShape(Shape1(dshape[1]));
+    }
+    if (dshape == gshape) {
+      SHAPE_ASSIGN_CHECK(*out_shape, 0, dshape);
+    }
+  }
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  if (param_.act_type == leakyrelu::kRReLU) {
+    out_shape->push_back(dshape);
+  }
+  return true;
 }
 
-DMLC_REGISTER_PARAMETER(LeakyReLUParam);
+#if MXNET_USE_MKLDNN == 1
+static void LeakyReLUComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                   const OpContext& ctx,
+                                   const std::vector<NDArray>& inputs,
+                                   const std::vector<OpReqType>& req,
+                                   const std::vector<NDArray>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  CHECK_EQ(inputs.size(), expected);
+  if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
+    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
+    MKLDNNLeakyReluForward(attrs, ctx, inputs[0], req[0], outputs[0]);
+    MKLDNN_OPCHECK_RUN(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(LeakyReLUCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+void LeakyReLUGradComputeExCPU(const nnvm::NodeAttrs& attrs,
+                                const OpContext& ctx,
+                                const std::vector<NDArray>& inputs,
+                                const std::vector<OpReqType>& req,
+                                const std::vector<NDArray>& outputs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  if (SupportMKLDNNLeakyRelu(param, inputs[0])) {
+    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
+    MKLDNNLeakyReluBackward(attrs, ctx, inputs.at(0), inputs.at(1), req[0],
+                             outputs[0]);
+    MKLDNN_OPCHECK_RUN(LeakyReLUGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+    return;
+  }
+  FallBackCompute(LeakyReLUGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
+}
+
+inline static bool LeakyReLUStorageType(const nnvm::NodeAttrs& attrs,
+                                         const int dev_mask,
+                                         DispatchMode* dispatch_mode,
+                                         std::vector<int> *in_attrs,
+                                         std::vector<int> *out_attrs) {
+  const LeakyReLUParam& param = nnvm::get<LeakyReLUParam>(attrs.parsed);
+  size_t expected = param.act_type == leakyrelu::kPReLU ? 2 : 1;
+  CHECK_EQ(in_attrs->size(), expected);
+  return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNLeakyRelu(param),
+                           dispatch_mode, in_attrs, out_attrs);
+}
 
-MXNET_REGISTER_OP_PROPERTY(LeakyReLU, LeakyReLUProp)
+inline static bool BackwardLeakyReLUStorageType(const nnvm::NodeAttrs& attrs,
+                                          const int dev_mask,
 
 Review comment:
   Indentation: please align the continuation parameters with the first parameter on the line above.
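   A minimal sketch of the alignment this presumably asks for, reusing the
   LeakyReLUStorageType signature from the diff above (shown as a declaration
   only; the same style would apply to BackwardLeakyReLUStorageType and the
   other new functions). Continuation parameters line up directly under the
   first parameter, as LeakyReLUType and LeakyReLUShape in this diff already do:

   inline static bool LeakyReLUStorageType(const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
                                           std::vector<int> *in_attrs,
                                           std::vector<int> *out_attrs);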

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
