TaoLv commented on a change in pull request #14128: MKLDNN based Quantized FullyConnected Operator and its fusion
URL: https://github.com/apache/incubator-mxnet/pull/14128#discussion_r256436903
 
 

 ##########
 File path: src/operator/nn/mkldnn/mkldnn_fully_connected.cc
 ##########
 @@ -23,215 +23,269 @@
  * \author Da Zheng
 */
 
-#include "../fully_connected-inl.h"
-#include "./mkldnn_base-inl.h"
-
 #if MXNET_USE_MKLDNN == 1
+#include "mkldnn_fully_connected-inl.h"
+
 namespace mxnet {
 namespace op {
 
-inline static mkldnn::inner_product_forward::primitive_desc GetIPFwd(
+DMLC_REGISTER_PARAMETER(MKLDNNFCParam);
+
+mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
+    const MKLDNNFCFullParam &full_param, const bool is_train,
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const mkldnn::memory::desc &out_md, const bool is_train) {
+    const mkldnn::memory::desc &out_md) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto engine = CpuEngine::Get()->get_engine();
   auto propagation =
    is_train ? mkldnn::prop_kind::forward_training : mkldnn::prop_kind::forward_scoring;
+
+  mkldnn::primitive_attr attr;
+  mkldnn::post_ops ops;
+  if (full_param.mkldnn_param.with_relu) {
+    float scale = 1.0f;
+    float alpha = 0.0f;
+    float beta = 1.0f;
+    ops.append_eltwise(scale, eltwise_relu, alpha, beta);
+  }
+  attr.set_post_ops(ops);
+
+  if (full_param.mkldnn_param.quantized) {
+    if (full_param.mkldnn_param.fuse_requantize ||
+        full_param.mkldnn_param.fuse_dequantize) {
+      int mask = 0;
+      std::vector<float> scales = {0.0};
+      if (full_param.requantize_scales.size()) {
+        scales[0] = full_param.requantize_scales[0];
+      } else if (full_param.output_scales.size()) {
+        scales[0] = full_param.output_scales[0];
+      } else {
+        LOG(FATAL) << "Must specify either output_scales or requantize_scales!";
+      }
+
+      attr.set_output_scales(mask, scales);
+      attr.set_int_output_round_mode(round_nearest);
+    }
+  }
+
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, bias_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
   } else {
-    mkldnn::inner_product_forward::desc ipFwd_desc(propagation,
+    mkldnn::inner_product_forward::desc desc(propagation,
         data_md, weight_md, out_md);
-    return mkldnn::inner_product_forward::primitive_desc(ipFwd_desc, engine);
+    return mkldnn::inner_product_forward::primitive_desc(desc, attr, engine);
   }
 }
 
-inline static mkldnn::inner_product_backward_data::primitive_desc GetIpBwdData(
+inline static mkldnn::inner_product_backward_data::primitive_desc GetFCBwdData(
     const NDArray &data, const NDArray &weight, const NDArray &output,
-    mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   mkldnn::inner_product_backward_data::desc desc(data_md, weight_md, out_md);
-  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, ipFwd_pd);
+  return mkldnn::inner_product_backward_data::primitive_desc(desc, engine, fwd_pd);
 }
 
-inline static mkldnn::inner_product_backward_weights::primitive_desc GetIPBwdWeights(
+inline static mkldnn::inner_product_backward_weights::primitive_desc GetFCBwdWeights(
     const NDArray &data, const NDArray &weight, const NDArray *bias,
-    const NDArray &output, mkldnn::inner_product_forward::primitive_desc ipFwd_pd) {
+    const NDArray &output, mkldnn::inner_product_forward::primitive_desc fwd_pd) {
   auto data_md = GetMemDesc(data);
   auto weight_md = GetMemDesc(weight);
   auto out_md = GetMemDesc(output);
   auto engine = CpuEngine::Get()->get_engine();
   if (bias) {
     auto bias_md = GetMemDesc(*bias);
-    mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+    mkldnn::inner_product_backward_weights::desc desc(data_md,
         weight_md, bias_md, out_md);
     return mkldnn::inner_product_backward_weights::primitive_desc(
-        ipBwdWeights_desc, engine, ipFwd_pd);
+        desc, engine, fwd_pd);
   } else {
-    mkldnn::inner_product_backward_weights::desc ipBwdWeights_desc(data_md,
+    mkldnn::inner_product_backward_weights::desc desc(data_md,
         weight_md, out_md);
     return mkldnn::inner_product_backward_weights::primitive_desc(
-        ipBwdWeights_desc, engine, ipFwd_pd);
+        desc, engine, fwd_pd);
   }
 }
 
-class MKLDNNFullyConnectForward {
-  std::shared_ptr<mkldnn::memory> data;
-  std::shared_ptr<mkldnn::memory> weight;
-  std::shared_ptr<mkldnn::memory> out;
-  std::shared_ptr<mkldnn::memory> bias;
-  std::shared_ptr<mkldnn::inner_product_forward> ipFwd;
+void MKLDNNFullyConnectedForward::SetNewMem(const mkldnn::memory &data,
+                                            const mkldnn::memory &weight,
+                                            const mkldnn::memory *bias,
+                                            const mkldnn::memory &output) {
+  if (this->data == nullptr)
+    this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.src_primitive_desc(), data.get_data_handle()));
+  else
+    this->data->set_data_handle(data.get_data_handle());
 
- public:
-  mkldnn::inner_product_forward::primitive_desc ipFwd_pd;
+  if (this->weight == nullptr)
+    this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.weights_primitive_desc(), weight.get_data_handle()));
+  else
+    this->weight->set_data_handle(weight.get_data_handle());
 
-  MKLDNNFullyConnectForward(const FullyConnectedParam &param, bool is_train,
-                            const NDArray &data, const NDArray &weight,
-                            const NDArray *bias,
-                            const mkldnn::memory::desc &output)
-      : ipFwd_pd(GetIPFwd(data, weight, bias, output, is_train)) {}
+  if (this->out == nullptr)
+    this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+            fwd_pd.dst_primitive_desc(), output.get_data_handle()));
+  else
+    this->out->set_data_handle(output.get_data_handle());
 
-  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
-                 const mkldnn::memory *bias, const mkldnn::memory &output) {
-    if (this->data == nullptr)
-      this->data = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.src_primitive_desc(), data.get_data_handle()));
+  if (bias != nullptr) {
+    if (this->bias == nullptr)
+      this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
+      fwd_pd.bias_primitive_desc(), bias->get_data_handle()));
     else
-      this->data->set_data_handle(data.get_data_handle());
-
-    if (this->weight == nullptr)
-      this->weight = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.weights_primitive_desc(), weight.get_data_handle()));
-    else
-      this->weight->set_data_handle(weight.get_data_handle());
-
-    if (this->out == nullptr)
-      this->out = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-              ipFwd_pd.dst_primitive_desc(), output.get_data_handle()));
-    else
-      this->out->set_data_handle(output.get_data_handle());
-
-    if (bias != nullptr) {
-      if (this->bias == nullptr)
-        this->bias = std::shared_ptr<mkldnn::memory>(new mkldnn::memory(
-        ipFwd_pd.bias_primitive_desc(), bias->get_data_handle()));
-      else
-        this->bias->set_data_handle(bias->get_data_handle());
-      if (this->ipFwd == nullptr)
-        this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
-            new mkldnn::inner_product_forward(
-                ipFwd_pd, mkldnn::primitive::at(*this->data),
-                mkldnn::primitive::at(*this->weight),
-                mkldnn::primitive::at(*this->bias), *this->out));
-    } else if (this->ipFwd == nullptr) {
-      this->ipFwd = std::shared_ptr<mkldnn::inner_product_forward>(
+      this->bias->set_data_handle(bias->get_data_handle());
+    if (this->fwd == nullptr)
+      this->fwd = std::shared_ptr<mkldnn::inner_product_forward>(
           new mkldnn::inner_product_forward(
-              ipFwd_pd, mkldnn::primitive::at(*this->data),
-              mkldnn::primitive::at(*this->weight), *this->out));
-    }
+              fwd_pd, mkldnn::primitive::at(*this->data),
+              mkldnn::primitive::at(*this->weight),
+              mkldnn::primitive::at(*this->bias), *this->out));
+  } else if (this->fwd == nullptr) {
+    this->fwd = std::shared_ptr<mkldnn::inner_product_forward>(
+        new mkldnn::inner_product_forward(
+            fwd_pd, mkldnn::primitive::at(*this->data),
+            mkldnn::primitive::at(*this->weight), *this->out));
   }
-  const mkldnn::inner_product_forward &GetIpFwd() const {
-    return *ipFwd;
-  }
-};
-
-typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;
+}
 
-static inline MKLDNNFullyConnectForward &GetFCFwd(
-    const nnvm::NodeAttrs &attrs, const NDArray &data, const NDArray &weight,
-    const NDArray *bias, const mkldnn::memory::desc &output,
-    const bool is_train) {
+MKLDNNFullyConnectedForward &GetFCFwd(
+    const FullyConnectedParam &param, const bool is_train,
+    const NDArray &data, const NDArray &weight,
+    const NDArray *bias, const mkldnn::memory::desc &out_md) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<MKLDNNFullyconSignature,
-              MKLDNNFullyConnectForward, OpHash> fcFwds;
+              MKLDNNFullyConnectedForward, OpHash> fcFwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<MKLDNNFullyconSignature,
-              MKLDNNFullyConnectForward, OpHash> fcFwds;
+              MKLDNNFullyConnectedForward, OpHash> fcFwds;
 #endif
-  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
   MKLDNNFullyconSignature key(param);
+  key.AddSign(is_train);
   key.AddSign(data);
   key.AddSign(weight);
-  key.AddSign(is_train);
-
   if (bias)
     key.AddSign(*bias);
 
   auto it = fcFwds.find(key);
   if (it == fcFwds.end()) {
-    MKLDNNFullyConnectForward fcFwd(param, is_train, data, weight, bias,
-                                    output);
-    auto ins_ret = fcFwds.insert(
-        std::pair<MKLDNNFullyconSignature, MKLDNNFullyConnectForward>(key, fcFwd));
-    CHECK(ins_ret.second);
-    it = ins_ret.first;
+    MKLDNNFCFullParam full_param;
+    full_param.default_param = param;
+    full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+    MKLDNNFullyConnectedForward fcFwd(full_param, is_train, data, weight, bias, out_md);
+    it = AddToCache(&fcFwds, key, fcFwd);
   }
   return it->second;
 }
 
-void MKLDNNFCForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx,
-                     const std::vector<NDArray> &in_data,
-                     const std::vector<OpReqType> &req,
-                     const std::vector<NDArray> &out_data) {
-  TmpMemMgr::Get()->Init(ctx.requested[fullc::kTempSpace]);
-  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
-  const TShape& ishape = in_data[fullc::kData].shape();
-  const TShape& oshape = out_data[fullc::kOut].shape();
-  NDArray weight = in_data[fullc::kWeight];
-  NDArray data = in_data[fullc::kData];
+void MKLDNNFCFlattenData(const FullyConnectedParam &param,
+                         const NDArray &out_data,
+                         NDArray *in_data,
+                         mkldnn::memory::desc *out_md) {
+  const TShape& ishape = in_data->shape();
+  const TShape& oshape = out_data.shape();
+
   // If the input data is a view of an MKLDNN array, we should create a new
   // NDArray with reordered data.
-  if (data.IsMKLDNNData() && data.IsView())
-    data = in_data[fullc::kData].Reorder2Default();
+  if (in_data->IsMKLDNNData() && in_data->IsView())
+    *in_data = in_data->Reorder2Default();
 
-  auto out_md = GetMemDesc(out_data[fullc::kOut]);
-  if (data.shape().ndim() != 2 && !param.flatten) {
-    data = data.MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1),
-                                     ishape[ishape.ndim()-1]));
+  if (ishape.ndim() != 2 && !param.flatten) {
+    *in_data = in_data->MKLDNNDataReshape(Shape2(ishape.ProdShape(0, ishape.ndim()-1),
+                                                  ishape[ishape.ndim()-1]));
    mkldnn::memory::dims out_dims{static_cast<int>(oshape.ProdShape(0, oshape.ndim()-1)),
       static_cast<int>(oshape[ishape.ndim()-1])};
-    out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
+    *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
       mkldnn::memory::format::any);
-  } else if (data.shape().ndim() != 2) {
-    data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
+  } else if (ishape.ndim() != 2) {
+    *in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())));
     mkldnn::memory::dims out_dims{static_cast<int>(oshape[0]),
       static_cast<int>(oshape.ProdShape(1, oshape.ndim()))};
-    out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data[fullc::kOut].dtype()),
+    *out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(out_data.dtype()),
       mkldnn::memory::format::any);
   }
 
 Review comment:
   need `else`. As written, when `ishape.ndim() == 2` neither branch runs, so `*out_md` is never assigned and every caller is forced to pre-initialize it.
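   A minimal sketch of what I mean, assuming the plain `GetMemDesc(out_data)` descriptor (the default the old `MKLDNNFCForward` computed before these branches) is still the right fallback here:

   ```cpp
   } else {
     // Hypothetical fallback for the 2-D case: assign the output descriptor
     // directly so *out_md is always set, regardless of what the caller passed in.
     *out_md = GetMemDesc(out_data);
   }
   ```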

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
