This is an automated email from the ASF dual-hosted git repository.

taolv pushed a commit to branch mkldnn-v1.0
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/mkldnn-v1.0 by this push:
     new f930baa  [mkldnn-v1.0] Add MKL-DNN activation (#16195)
f930baa is described below

commit f930baa533d24497188e33c163dbb1f36707c336
Author: rongzha1 <[email protected]>
AuthorDate: Fri Sep 20 22:09:19 2019 +0800

    [mkldnn-v1.0] Add MKL-DNN activation (#16195)
    
    * add mkldnn act; pass lint; pass mnist training
    
    * make bwd a private member
---
 src/operator/nn/activation.cc           |  18 ++---
 src/operator/nn/mkldnn/mkldnn_act-inl.h |  40 +++++++---
 src/operator/nn/mkldnn/mkldnn_act.cc    | 130 ++++++++------------------------
 src/operator/nn/mkldnn/mkldnn_ops-inl.h |  15 ++--
 4 files changed, 75 insertions(+), 128 deletions(-)
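
A note for readers tracking the 0.x -> 1.0 API migration: MKL-DNN v1.0
primitives no longer bind memory objects at construction time; every memory
is handed over at execution time through an argument map keyed by
MKLDNN_ARG_* constants, which is why the SetNewMem() machinery disappears
below. A minimal, self-contained sketch of the v1.0 forward pattern (plain
MKL-DNN, outside MXNet; the names eng, s, md, src, dst, and relu are
illustrative, not from this patch):

    #include <mkldnn.hpp>
    using namespace mkldnn;

    int main() {
      engine eng(engine::kind::cpu, 0);
      stream s(eng);
      // Small f32 NCHW tensor; the shape is arbitrary for the example.
      memory::desc md({1, 16, 8, 8}, memory::data_type::f32,
                      memory::format_tag::nchw);
      memory src(md, eng), dst(md, eng);
      // v1.0 flow: op desc -> primitive_desc -> primitive, then supply
      // all memories via an argument map at execute() time.
      eltwise_forward::desc d(prop_kind::forward_training,
                              algorithm::eltwise_relu, md, 0.f);
      eltwise_forward::primitive_desc pd(d, eng);
      eltwise_forward relu(pd);
      relu.execute(s, {{MKLDNN_ARG_SRC, src}, {MKLDNN_ARG_DST, dst}});
      s.wait();
      return 0;
    }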

diff --git a/src/operator/nn/activation.cc b/src/operator/nn/activation.cc
index 5abb667..f238e8f 100644
--- a/src/operator/nn/activation.cc
+++ b/src/operator/nn/activation.cc
@@ -27,10 +27,10 @@
 #include "./activation-inl.h"
 #include "../mshadow_op.h"
 #include "../tensor/elemwise_unary_op.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_base-inl.h"
 #include "./mkldnn/mkldnn_ops-inl.h"
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #include "../operator_common.h"
 #include "../../common/utils.h"
 
@@ -91,7 +91,7 @@ struct ActivationGrad {
   }
 };
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 static void ActivationComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext& ctx,
                                    const std::vector<NDArray>& inputs,
@@ -150,7 +150,7 @@ inline static bool BackwardActStorageType(const nnvm::NodeAttrs& attrs,
   return MKLDNNStorageType(attrs, dev_mask, SupportMKLDNNAct(param),
                            dispatch_mode, in_attrs, out_attrs);
 }
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 
 
 MXNET_OPERATOR_REGISTER_UNARY(Activation)
@@ -167,7 +167,7 @@ The following activation functions are supported:
 
 )code" ADD_FILELINE)
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", ActivationStorageType)
 #endif
 .set_attr<nnvm::FListOutputNames>("FListOutputNames",
@@ -175,7 +175,7 @@ The following activation functions are supported:
     return std::vector<std::string>{"output"};
 })
 .set_attr<FCompute>("FCompute<cpu>", ActivationCompute<cpu>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationComputeExCPU)
 #endif
@@ -189,7 +189,7 @@ NNVM_REGISTER_OP(_backward_Activation)
 })
 .set_num_outputs(1)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FInferStorageType>("FInferStorageType", BackwardActStorageType)
 #endif
 .set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<-1, 1>)
@@ -197,13 +197,13 @@ NNVM_REGISTER_OP(_backward_Activation)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
   return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
 })
 #endif
 .set_attr_parser(ParamParser<ActivationParam>)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", ActivationGradComputeExCPU)
 #endif
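
An aside on the guard value changed throughout activation.cc: on this branch
the build system is expected to define MXNET_USE_MKLDNN as a version code
rather than a boolean, so 100 selects the MKL-DNN v1.0 path while the legacy
0.x code stays disabled. A sketch of the assumed convention (not part of
this patch):

    #if MXNET_USE_MKLDNN == 100
      // compiled only against the MKL-DNN v1.0 API
    #elif MXNET_USE_MKLDNN == 1
      // legacy 0.x path, inactive on the mkldnn-v1.0 branch
    #endif
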
diff --git a/src/operator/nn/mkldnn/mkldnn_act-inl.h b/src/operator/nn/mkldnn/mkldnn_act-inl.h
index 6bf30e3..57507a5 100644
--- a/src/operator/nn/mkldnn/mkldnn_act-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_act-inl.h
@@ -20,7 +20,7 @@
 /*!
  * Copyright (c) 2019 by Contributors
  * \file mkldnn_act-inl.h
- * \brief MKLDNN(Quantized) Activation operator based on subgraph
+ * \brief MKLDNN Activation operator
 * \author Zhiyuan Huang
 */
 
@@ -28,20 +28,17 @@
 #define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
 
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <vector>
 #include <utility>
 #include "../activation-inl.h"
-#include "./mkldnn_ops-inl.h"
-#include "./mkldnn_base-inl.h"
 
 namespace mxnet {
 namespace op {
 
 mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param);
 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
-    const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype);
+    const ActivationParam& param, bool is_train, const mkldnn::memory &input_mem);
 
 class MKLDNNActForward {
  public:
@@ -49,14 +46,13 @@ class MKLDNNActForward {
 
   MKLDNNActForward(const ActivationParam& param, bool is_train,
                    const NDArray &data, const mkldnn::memory &mem): fwd_pd(
-                       GetActFwdDescImpl(param, is_train, mem, data.dtype())) {}
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output);
-  const mkldnn::eltwise_forward &GetFwd() const;
+                       GetActFwdDescImpl(param, is_train, mem)) {
+    fwd_ = std::make_shared<mkldnn::eltwise_forward>(fwd_pd);
+  }
+  const inline mkldnn::eltwise_forward &GetFwd() const;
 
  private:
   std::shared_ptr<mkldnn::eltwise_forward> fwd_;
-  std::shared_ptr<mkldnn::memory> data_;
-  std::shared_ptr<mkldnn::memory> out_;
 };
 
 typedef ParamOpSign<ActivationParam> MKLDNNActSignature;
@@ -67,8 +63,28 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
 void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                              const NDArray &in_data, const OpReqType &req,
                              const NDArray &out_data);
+
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+    const ActivationParam &param, const mkldnn::memory &input_mem,
+    const mkldnn::memory &diff_dst_memory);
+
+class MKLDNNActBackward {
+ public:
+  const mkldnn::eltwise_backward::primitive_desc pd;
+
+  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
+                             const mkldnn::memory &mem,
+                             const mkldnn::memory &diff_dst_memory): pd(
+                                 GetActBwdDescImpl(param, mem, diff_dst_memory)) {
+    bwd = std::make_shared<mkldnn::eltwise_backward>(pd);
+  }
+  const inline mkldnn::eltwise_backward &GetBwd() const;
+
+ private:
+  std::shared_ptr<mkldnn::eltwise_backward> bwd;
+};
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
 #endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_ACT_INL_H_
diff --git a/src/operator/nn/mkldnn/mkldnn_act.cc b/src/operator/nn/mkldnn/mkldnn_act.cc
index e4c8296..e2ffd0b 100644
--- a/src/operator/nn/mkldnn/mkldnn_act.cc
+++ b/src/operator/nn/mkldnn/mkldnn_act.cc
@@ -23,6 +23,8 @@
  * \author Da Zheng
 */
 
+#if MXNET_USE_MKLDNN == 100
+
 #include <dmlc/logging.h>
 #include <dmlc/parameter.h>
 #include <mxnet/operator.h>
@@ -33,10 +35,7 @@
 #include <utility>
 #include "../../operator_common.h"
 #include "mkldnn_act-inl.h"
-
-#if MXNET_USE_MKLDNN == 1
-
-#include <mkldnn.hpp>
+#include "./mkldnn_base-inl.h"
 
 namespace mxnet {
 namespace op {
@@ -81,41 +80,19 @@ mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
 
 mkldnn::eltwise_forward::primitive_desc GetActFwdDescImpl(
     const ActivationParam& param, bool is_train,
-    const mkldnn::memory &input_mem, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  auto cpu_engine = data_mpd.get_engine();
-
+    const mkldnn::memory &input_mem) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);
 
   auto prop = is_train ? mkldnn::prop_kind::forward_training :
                          mkldnn::prop_kind::forward_scoring;
   auto desc = mkldnn::eltwise_forward::desc(prop, alg, data_md, 0.0f);
-  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
-}
-
-void MKLDNNActForward::SetNewMem(const mkldnn::memory &data, const mkldnn::memory &output) {
-  if (this->data_ == nullptr)
-    this->data_ = std::make_shared<mkldnn::memory>(data.get_primitive_desc(),
-            data.get_data_handle());
-  else
-    this->data_->set_data_handle(data.get_data_handle());
 
-  CHECK(fwd_pd.dst_primitive_desc() == output.get_primitive_desc());
-  if (this->out_ == nullptr)
-    this->out_ = std::make_shared<mkldnn::memory>(fwd_pd.dst_primitive_desc(),
-            output.get_data_handle());
-  else
-    this->out_->set_data_handle(output.get_data_handle());
-
-  if (this->fwd_ == nullptr) {
-    this->fwd_ = std::shared_ptr<mkldnn::eltwise_forward>(
-        new mkldnn::eltwise_forward(fwd_pd, mkldnn::primitive::at(*this->data_),
-                                    *this->out_));
-  }
+  return mkldnn::eltwise_forward::primitive_desc(desc, cpu_engine);
 }
 
-const mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
+const inline mkldnn::eltwise_forward &MKLDNNActForward::GetFwd() const {
   return *fwd_;
 }
 
@@ -131,7 +108,6 @@ MKLDNNActForward &GetActForward(const ActivationParam& param,
   key.AddSign(ctx.is_train);
   key.AddSign(param.act_type);
   key.AddSign(in_data);
-
   auto it = fwds.find(key);
   if (it == fwds.end()) {
     MKLDNNActForward fwd(param, ctx.is_train, in_data, in_mem);
@@ -153,81 +129,34 @@ void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
 
   auto input_mem = in_buffer.GetMKLDNNData();
   MKLDNNActForward &fwd = GetActForward(param, ctx, in_buffer, *input_mem);
-  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_primitive_desc(), req, &in_buffer);
-  fwd.SetNewMem(*input_mem, *out_mem_t.second);
-  stream->RegisterPrim(fwd.GetFwd());
+  auto out_mem_t = CreateMKLDNNMem(out_data, fwd.fwd_pd.dst_desc(), req, &in_buffer);
+  stream->RegisterPrimArgs(fwd.GetFwd(),
+                           {{ MKLDNN_ARG_SRC, *input_mem}, { MKLDNN_ARG_DST, *out_mem_t.second}});
   CommitOutput(out_data, out_mem_t);
   stream->Submit();
 }
 
-static mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
+mkldnn::eltwise_backward::primitive_desc GetActBwdDescImpl(
     const ActivationParam &param, const mkldnn::memory &input_mem,
-    const mkldnn::memory &diff_dst_memory, int dtype) {
-  mkldnn::memory::primitive_desc data_mpd = input_mem.get_primitive_desc();
-  mkldnn::memory::desc data_md = data_mpd.desc();
-  mkldnn::memory::desc diff_md = diff_dst_memory.get_primitive_desc().desc();
-  auto cpu_engine = data_mpd.get_engine();
+    const mkldnn::memory &diff_dst_memory) {
+  mkldnn::memory::desc data_md = input_mem.get_desc();
+  mkldnn::memory::desc diff_md = diff_dst_memory.get_desc();
+  auto cpu_engine = CpuEngine::Get()->get_engine();
   auto alg = GetMKLDNNActAlgo(param);
 
-  MSHADOW_REAL_TYPE_SWITCH(dtype, DType, {
-    DType alpha = 0;
-    mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                          alg, data_md, alpha);
-    mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-    mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
-    mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
-                                                      fw_pdesc);
-    return bw_pdesc;
-  });
-  LOG(FATAL) << "Unsupported data type for MKLDNN activation";
+  float alpha = 0;
   mkldnn::eltwise_forward::desc fw_desc(mkldnn::prop_kind::forward_training,
-                                        alg, data_md, 0.0);
+                                        alg, data_md, alpha);
   mkldnn::eltwise_forward::primitive_desc fw_pdesc(fw_desc, cpu_engine);
-  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, 0.0);
+  mkldnn::eltwise_backward::desc bw_desc(alg, diff_md, data_md, alpha);
   mkldnn::eltwise_backward::primitive_desc bw_pdesc(bw_desc, cpu_engine,
                                                     fw_pdesc);
   return bw_pdesc;
 }
 
-class MKLDNNActBackward {
-  std::shared_ptr<mkldnn::eltwise_backward> bwd;
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> diff_dst_memory;
-  std::shared_ptr<mkldnn::memory> diff_src_memory;
-
- public:
-  const mkldnn::eltwise_backward::primitive_desc pd;
-
-  explicit MKLDNNActBackward(const ActivationParam &param, const NDArray &data,
-                             const mkldnn::memory &mem,
-                             const mkldnn::memory &diff_dst_memory)
-      : pd(GetActBwdDescImpl(param, mem, diff_dst_memory, data.dtype())) {}
-
-  void SetNewMem(const mkldnn::memory &data,
-                 const mkldnn::memory &diff_dst_memory,
-                 const mkldnn::memory &diff_src_memory) {
-    if (this->bwd != nullptr) {
-      this->data->set_data_handle(data.get_data_handle());
-      this->diff_dst_memory->set_data_handle(diff_dst_memory.get_data_handle());
-      this->diff_src_memory->set_data_handle(diff_src_memory.get_data_handle());
-    } else {
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-          data.get_primitive_desc(), data.get_data_handle()));
-      this->diff_dst_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_dst_memory.get_primitive_desc(),
-                             diff_dst_memory.get_data_handle()));
-      this->diff_src_memory = std::shared_ptr<mkldnn::memory>(
-          new mkldnn::memory(diff_src_memory.get_primitive_desc(),
-                             diff_src_memory.get_data_handle()));
-      this->bwd = std::shared_ptr<mkldnn::eltwise_backward>(
-          new mkldnn::eltwise_backward(
-              this->pd, mkldnn::primitive::at(*this->data),
-              *this->diff_dst_memory, *this->diff_src_memory));
-    }
-  }
-
-  const inline mkldnn::eltwise_backward &GetBwd() const { return *bwd; }
-};
+const inline mkldnn::eltwise_backward &MKLDNNActBackward::GetBwd() const {
+  return *bwd;
+}
 
 static inline MKLDNNActBackward &GetActBackward(const ActivationParam &param,
                                                 const OpContext &ctx,
@@ -274,20 +203,23 @@ void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx
   auto input_mem = in_buffer.GetMKLDNNData();
   // We need to make sure the two inputs to eltwise_backward has the same memory
   // descriptor. Otherwise, the perf will suffer.
-  if (input_mem->get_primitive_desc() != diff_dst_memory->get_primitive_desc())
-    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_primitive_desc());
+  if (input_mem->get_desc() != diff_dst_memory->get_desc())
+    input_mem = in_buffer.GetMKLDNNDataReorder(diff_dst_memory->get_desc());
   MKLDNNActBackward &bwd =
       GetActBackward(param, ctx, in_buffer, out_buffer, *input_mem);
   MKLDNNStream *stream = MKLDNNStream::Get();
   mkldnn_output_t diff_src_memory =
-      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_primitive_desc(), req);
-  bwd.SetNewMem(*input_mem, *diff_dst_memory, *diff_src_memory.second);
-  stream->RegisterPrim(bwd.GetBwd());
+      CreateMKLDNNMem(in_grad, bwd.pd.diff_src_desc(), req);
+  mkldnn_args_map_t args = {
+    { MKLDNN_ARG_SRC, *input_mem },
+    { MKLDNN_ARG_DIFF_DST, *diff_dst_memory },
+    { MKLDNN_ARG_DIFF_SRC, *diff_src_memory.second },
+  };
+  stream->RegisterPrimArgs(bwd.GetBwd(), args);
   CommitOutput(in_grad, diff_src_memory);
   stream->Submit();
 }
 
 }  // namespace op
 }  // namespace mxnet
-
 #endif
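
The backward path rewritten above follows the same v1.0 convention, with one
extra requirement: an eltwise_backward::primitive_desc must be built with the
matching forward primitive_desc as a hint. A hedged sketch reusing eng, s,
md, and src from the forward example earlier (again illustrative, not code
from this patch):

    // The forward pd serves as the mandatory hint for the backward pd.
    eltwise_forward::desc fwd_d(prop_kind::forward_training,
                                algorithm::eltwise_relu, md, 0.f);
    eltwise_forward::primitive_desc fwd_pd(fwd_d, eng);
    eltwise_backward::desc bwd_d(algorithm::eltwise_relu, md, md, 0.f);
    eltwise_backward::primitive_desc bwd_pd(bwd_d, eng, fwd_pd);
    memory diff_dst(md, eng), diff_src(md, eng);
    // All three memories are passed at execution time, mirroring the
    // RegisterPrimArgs() call in MKLDNNActivationBackward above.
    eltwise_backward(bwd_pd).execute(
        s, {{MKLDNN_ARG_SRC, src},
            {MKLDNN_ARG_DIFF_DST, diff_dst},
            {MKLDNN_ARG_DIFF_SRC, diff_src}});
    s.wait();
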
diff --git a/src/operator/nn/mkldnn/mkldnn_ops-inl.h b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
index ddfcecc..3c83f6b 100644
--- a/src/operator/nn/mkldnn/mkldnn_ops-inl.h
+++ b/src/operator/nn/mkldnn/mkldnn_ops-inl.h
@@ -95,14 +95,6 @@ void MKLDNNConcatBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
                           const std::vector<OpReqType>& req,
                           const std::vector<NDArray>& outputs);
 
-/* For activation */
-void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                             const NDArray &in_data, const OpReqType &req,
-                             const NDArray &out_data);
-void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                              const NDArray &out_grad, const NDArray &in_data,
-                              const OpReqType &req, const NDArray &in_grad);
-
 void MKLDNNTransposeForward(const nnvm::NodeAttrs& attrs,
                             const OpContext &ctx,
                             const NDArray &data,
@@ -133,6 +125,13 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
                                const std::vector<OpReqType>& req,
                                const std::vector<NDArray>& outputs);
 
+/* For activation */
+void MKLDNNActivationForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                             const NDArray &in_data, const OpReqType &req,
+                             const NDArray &out_data);
+void MKLDNNActivationBackward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
+                              const NDArray &out_grad, const NDArray &in_data,
+                              const OpReqType &req, const NDArray &in_grad);
 
 void MKLDNNSum(const mkldnn::memory &arr1, const mkldnn::memory &arr2,
          const mkldnn::memory &out);
