This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 84b1626b66 oneDNN FullyConnected weight caching & refactor (#21047)
84b1626b66 is described below

commit 84b1626b66f84a35e87a95d26226c3a3171ea78c
Author: bgawrych <[email protected]>
AuthorDate: Wed Jul 6 12:02:58 2022 +0200

    oneDNN FullyConnected weight caching & refactor (#21047)
    
    * FC weight and bias caching
    
    * prepare output for sum
    
    * check initialization conditions
    
    * create output mem desc
    
    * PrepareQuantization
    
    * remove unused variables
    
    * cleanup
    
    * Enable BRGEMM
    
    * Reorder functions
    
    * make minmax enum anonymous
    
    * node identifier & env flag
    
    * fix sanity
    
    * fix sanity
    
    * apply review
    
    * rename variable
---
 src/operator/nn/dnnl/dnnl_base-inl.h          |   9 +-
 src/operator/operator_common.h                |   7 +-
 src/operator/subgraph/dnnl/dnnl_fc.cc         | 688 +++++++++++++++-----------
 src/operator/subgraph/dnnl/dnnl_fc_property.h |   5 +-
 4 files changed, 401 insertions(+), 308 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h 
b/src/operator/nn/dnnl/dnnl_base-inl.h
index 66c1dc2c99..0b9645b8a3 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -351,13 +351,6 @@ inline static dnnl::memory::desc GetMemDesc(const NDArray& 
arr, int dtype = -1)
   return dnnl::memory::desc{dims, get_dnnl_type(dtype), 
dnnl::memory::format_tag::any};
 }
 
-inline static bool ChooseBRGEMMImpl(const dnnl::memory::dims& weight_dims, 
size_t batch_size) {
-  // Conditions based on measurement results done on CLX8280
-  // https://github.com/apache/incubator-mxnet/pull/20533
-  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 
16384 &&
-         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
-}
-
 inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
                                                  size_t batch_size,
                                                  int dtype = -1) {
@@ -370,7 +363,7 @@ inline static dnnl::memory::desc GetFCWeightDesc(const 
NDArray& arr,
   // for batch 256 alexnet benchmark test
   const bool force_fc_ab_format = 
dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
   if (dims.size() == 2) {
-    if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+    if (force_fc_ab_format || dtype != mshadow::kInt8) {
       format = dnnl::memory::format_tag::ab;
     }
   }
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index a0f158f2f9..ffb9898104 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -547,7 +547,7 @@ class OpSignature {
 
 #if MXNET_USE_ONEDNN == 1
   void AddSign(const dnnl::memory::desc& desc) {
-    hash      = hash * 2 + desc.data.format_kind;
+    hash = hash * 2 + desc.data.format_kind;
     eles.push_back(desc.data.format_kind);
     hash = hash * 2 + desc.data.data_type;
     eles.push_back(desc.data.data_type);
@@ -617,6 +617,11 @@ class OpSignature {
 
 #endif
 
+  // Adds a string to the signature through its std::hash value. Like the other
+  // AddSign overloads, the key is mixed into `hash` as well as appended to `eles`,
+  // so the string contributes to the signature's hash and not only to its element list.
+  void AddSign(const std::string& s) {
+    const uint64_t key = static_cast<uint64_t>(std::hash<std::string>{}(s));
+    hash               = hash * 2 + key;
+    eles.push_back(key);
+  }
+
   void AddSign(const std::vector<NDArray>& arrs) {
     for (auto& arr : arrs) {
       AddSign(arr);
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc 
b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 2371ce8bd9..22971bf487 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -47,7 +47,9 @@ namespace op {
 class SgDNNLFCOp {
  public:
   explicit SgDNNLFCOp(const nnvm::NodeAttrs& attrs)
-      : subgraph_sym_(*attrs.subgraphs[0]), 
full_param_(nnvm::get<DNNLFCFullParam>(attrs.parsed)) {}
+      : subgraph_sym_(*attrs.subgraphs[0]),
+        attrs(attrs),
+        full_param_(nnvm::get<DNNLFCFullParam>(attrs.parsed)) {}
 
   void Forward(const OpContext& ctx,
                const std::vector<NDArray>& inputs,
@@ -63,11 +65,27 @@ class SgDNNLFCOp {
   }
 
  private:
+  enum { kDataMin = 0, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax, 
kSumMin, kSumMax };
+  const size_t MIN_MAX_COUNT = 8;
+
+  NDArray PrepareOutputWithSum(const NDArray& sum_input, const NDArray& 
output);
+  bool CheckInitializationConditions(const std::vector<NDArray>& inputs,
+                                     const std::vector<float>& min_max_vec,
+                                     bool is_channel_wise);
+  bool PrepareQuantization(const OpContext& ctx,
+                           const std::vector<NDArray>& in_data,
+                           const NDArray& output,
+                           const std::vector<float>& min_max_vec);
+  dnnl::memory::desc CreateOutputMemoryDesc(const mxnet::TShape& oshape, int 
out_dtype);
+  void GetCachedWeightsAndBias(const NDArray& weight,
+                               bool support_channelwise_scale,
+                               bool has_bias);
+  nnvm::Symbol subgraph_sym_;
+  nnvm::NodeAttrs attrs;
+  DNNLFCFullParam full_param_;
   bool initialized_{false};
   bool reorder_data_{false};
   bool inplace_{false};
-  nnvm::Symbol subgraph_sym_;
-  DNNLFCFullParam full_param_;
   dnnl_args_map_t args_;
   std::shared_ptr<DNNLFullyConnectedForward> fwd_;
   std::shared_ptr<dnnl::memory> cached_data_mem_;
@@ -94,16 +112,15 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
                          const std::vector<NDArray>& in_data,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& out_data) {
-  auto& dnnl_param         = full_param_.dnnl_param;
-  auto& default_param      = full_param_.default_param;
-  const bool has_bias      = !default_param.no_bias;
-  const bool quantized     = dnnl_param.quantized;
-  const bool out_quantized = dnnl_param.quantized && 
!dnnl_param.enable_float_output;
-  const bool channel_wise  = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
+  const auto& default_param = full_param_.default_param;
+  const auto& dnnl_param    = full_param_.dnnl_param;
+  const bool has_bias       = !default_param.no_bias;
+  const bool quantized      = dnnl_param.quantized;
+  const bool out_quantized  = dnnl_param.quantized && 
!dnnl_param.enable_float_output;
+  const bool channel_wise   = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
                             dnnl_param.channel_wise_quantize.value();
 
   const FCInputIndex idx(full_param_);
-
   CHECK_EQ(in_data.size(), idx.GetTotal());
 
   int index               = 0;
@@ -112,125 +129,54 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   const int out_max_index = out_quantized ? index++ : 0;
   CHECK_EQ(out_data.size(), index);  // index is equal to total number of 
outputs
 
-  float data_min   = 0.0f;
-  float data_max   = 0.0f;
-  float weight_min = 0.0f;
-  float weight_max = 0.0f;
-  float bias_min   = 0.0f;
-  float bias_max   = 0.0f;
+  std::vector<float> min_max_vec(MIN_MAX_COUNT);
+  min_max_vec[kDataMin]   = 0.0f;
+  min_max_vec[kDataMax]   = 0.0f;
+  min_max_vec[kWeightMin] = 0.0f;
+  min_max_vec[kWeightMax] = 0.0f;
+  min_max_vec[kBiasMin]   = 0.0f;
+  min_max_vec[kBiasMax]   = 0.0f;
 
-  const float sum_min   = idx.sum_min ? 
in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
-  const float sum_max   = idx.sum_max ? 
in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+  min_max_vec[kSumMin]  = idx.sum_min ? 
in_data[idx.sum_min].data().dptr<float>()[0] : 0.0f;
+  min_max_vec[kSumMax]  = idx.sum_max ? 
in_data[idx.sum_max].data().dptr<float>()[0] : 0.0f;
   NDArray data          = in_data[idx.data];
   const NDArray& weight = in_data[idx.weight];
   NDArray output;
 
   if (dnnl_param.with_sum) {
-    if (!initialized_) {
-      // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace 
option,
-      // which make check (req[out_index] == kWriteInplace) useless.
-      auto in_dnnl_mem  = static_cast<const 
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
-      auto out_dnnl_mem = static_cast<const 
dnnl::memory*>(out_data[out_index].GetDNNLData());
-      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle() &&
-          in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
-        inplace_ = true;
-      }
-    }
-    if (inplace_) {
-      output = in_data[idx.sum];
-    } else {
-      // Not in place: copy in_data[idx.sum] into outputs[out_index].
-      auto in_dnnl_mem  = static_cast<const 
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
-      auto out_dnnl_mem = static_cast<const 
dnnl::memory*>(out_data[out_index].GetDNNLData());
-      if (out_data[out_index].dtype() == mshadow::kInt32) {
-        auto mem_desc           = in_dnnl_mem->get_desc();
-        auto this_dtype         = get_dnnl_type(mshadow::kInt32);
-        mem_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(
-            mem_desc, CpuEngine::Get()->get_engine(), 
out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        DNNLStream::Get()->RegisterPrimArgs(
-            dnnl::reorder(*in_dnnl_mem, *tmp_mem),
-            {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
-        output = NDArray(tmp_mem);
-      } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
-                 out_data[out_index].dtype() == mshadow::kInt8) {
-        auto sum_mem_desc           = in_dnnl_mem->get_desc();
-        auto out_dtype              = get_dnnl_type(mshadow::kInt8);
-        sum_mem_desc.data.data_type = static_cast<dnnl_data_type_t>(out_dtype);
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(
-            sum_mem_desc, CpuEngine::Get()->get_engine(), 
out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        const float u8_reorder_scale     = 0.5;
-        std::vector<float> reorder_scale = {u8_reorder_scale};
-        dnnl::primitive_attr reorder_attr;
-        reorder_attr.set_output_scales(0, reorder_scale);
-        const auto reorder_pd = 
dnnl::reorder::primitive_desc(CpuEngine::Get()->get_engine(),
-                                                              
in_dnnl_mem->get_desc(),
-                                                              
CpuEngine::Get()->get_engine(),
-                                                              sum_mem_desc,
-                                                              reorder_attr);
-        DNNLStream::Get()->RegisterPrimArgs(
-            dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, *in_dnnl_mem}, 
{DNNL_ARG_TO, *tmp_mem}});
-        output = NDArray(tmp_mem);
-      } else {
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
-                                              CpuEngine::Get()->get_engine(),
-                                              
out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get());
-        output = NDArray(tmp_mem);
-      }
-    }
+    output = PrepareOutputWithSum(in_data[idx.sum], out_data[out_index]);
   } else {
     output = out_data[out_index];
   }
 
   if (dnnl_param.quantized) {
     if (!channel_wise) {
-      weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
-      weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
+      min_max_vec[kWeightMin] = 
in_data[idx.weight_min].data().dptr<float>()[0];
+      min_max_vec[kWeightMax] = 
in_data[idx.weight_max].data().dptr<float>()[0];
       if (has_bias) {
-        bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
-        bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
+        min_max_vec[kBiasMin] = in_data[idx.bias_min].data().dptr<float>()[0];
+        min_max_vec[kBiasMax] = in_data[idx.bias_max].data().dptr<float>()[0];
       }
     }
-    data_min = in_data[idx.data_min].data().dptr<float>()[0];
-    data_max = in_data[idx.data_max].data().dptr<float>()[0];
+    min_max_vec[kDataMin] = in_data[idx.data_min].data().dptr<float>()[0];
+    min_max_vec[kDataMax] = in_data[idx.data_max].data().dptr<float>()[0];
   }
 
-  if (initialized_ && dnnl_param.quantized && 
dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
-    if (channel_wise) {
-      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
-          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
-          weight_ver_ != weight.version() ||
-          (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
-        initialized_ = false;
-      }
-    } else {
-      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
-          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
-          cached_weight_min_ != weight_min || cached_weight_max_ != weight_max 
||
-          (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != 
bias_max))) {
-        initialized_ = false;
-      }
-    }
-  }
+  initialized_ = CheckInitializationConditions(in_data, min_max_vec, 
channel_wise);
 
   if (!initialized_) {
-    const auto nthreads = 
engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
-    const auto engine   = CpuEngine::Get()->get_engine();
-    cached_data_min_    = data_min;
-    cached_data_max_    = data_max;
-    cached_weight_min_  = weight_min;
-    cached_weight_max_  = weight_max;
-    weight_ver_         = weight.version();
-    cached_weight_      = weight;
-    cached_sum_min_     = sum_min;
-    cached_sum_max_     = sum_max;
+    const auto engine  = CpuEngine::Get()->get_engine();
+    cached_data_min_   = min_max_vec[kDataMin];
+    cached_data_max_   = min_max_vec[kDataMax];
+    cached_weight_min_ = min_max_vec[kWeightMin];
+    cached_weight_max_ = min_max_vec[kWeightMax];
+    weight_ver_        = weight.version();
+    cached_weight_     = weight;
+    cached_sum_min_    = min_max_vec[kSumMin];
+    cached_sum_max_    = min_max_vec[kSumMax];
     if (has_bias) {
-      cached_bias_min_ = bias_min;
-      cached_bias_max_ = bias_max;
+      cached_bias_min_ = min_max_vec[kBiasMin];
+      cached_bias_max_ = min_max_vec[kBiasMax];
       bias_ver_        = in_data[idx.bias].version();
       cached_bias_     = in_data[idx.bias];
     } else {
@@ -252,171 +198,14 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
     }
 
     // create cached out_md
-    const mxnet::TShape oshape = output.shape();
-    dnnl::memory::dims out_dims(2);
-    if (oshape.ndim() == 2) {
-      out_dims[0] = static_cast<index_t>(oshape[0]);
-      out_dims[1] = static_cast<index_t>(oshape[1]);
-    } else {
-      if (!default_param.flatten) {
-        out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 
1));
-        out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
-      } else {
-        out_dims[0] = static_cast<index_t>(oshape[0]);
-        out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
-      }
-    }
-    dnnl::memory::desc out_md =
-        dnnl::memory::desc(out_dims,
-                           get_dnnl_type(output.dtype()),
-                           
static_cast<dnnl::memory::format_tag>(GetDefaultFormat(2)));
-    cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine);
+    dnnl::memory::desc out_md = CreateOutputMemoryDesc(output.shape(), 
output.dtype());
+    cached_out_mem_           = std::make_shared<dnnl::memory>(out_md, engine);
 
     bool support_channelwise_scale = false;
-    if (dnnl_param.quantized) {
-      CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, 
cached_data_max_);
-
-      bool fuse_requantize = false;
-      // Channelwise scaling is only supported when fusion is enabled 
(requantize or dequantize).
-      if (dnnl_param.min_calib_range.has_value() && 
dnnl_param.max_calib_range.has_value()) {
-        cached_output_min_        = dnnl_param.min_calib_range.value();
-        cached_output_max_        = dnnl_param.max_calib_range.value();
-        support_channelwise_scale = true;
-        fuse_requantize           = true;
-      }
-      if (dnnl_param.enable_float_output) {
-        support_channelwise_scale = true;
-      }
-      // channel_wise  support_channelwise_scale  result
-      // True          True                       True
-      // True          False                      Error
-      // False         True/False                 False
-      if (channel_wise && !support_channelwise_scale) {
-        LOG(FATAL)
-            << "Currently, channel-wise quantization requires fuse requantize 
or dequantize."
-            << " Please make sure the `min_calib_range` and `max_calib_range` 
are set when only"
-            << " fuse requantize (outputs of FullyConnected are collected 
during calibration "
-               "phase),"
-            << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and 
"
-            << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true 
(default is false)";
-      }
-      support_channelwise_scale = support_channelwise_scale && channel_wise;
-
-      if (support_channelwise_scale) {
-        MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
-          weight_scales_ = GetWeightScales<DType>(cached_weight_,
-                                                  has_bias ? &cached_bias_ : 
nullptr,
-                                                  data_scale_,
-                                                  support_channelwise_scale);
-        });
-      } else {
-        weight_scales_.resize(1);
-        weight_scales_[0] =
-            GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, 
cached_weight_max_);
-        if (has_bias) {
-          if (cached_bias_.dtype() == mshadow::kInt8) {
-            float bias_scale = GetQuantizeScale(mshadow::kInt8, 
cached_bias_min_, cached_bias_max_);
-
-            float bias_int32_rescale = data_scale_ * weight_scales_[0] / 
bias_scale;
-            // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
-            // the maximum value of bias to INT_MAX / 2.
-            float bias_max_rescale =
-                MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, 
cached_bias_max_) / bias_scale;
-            if (bias_int32_rescale > bias_max_rescale) {
-              // avoid overflow on bias
-              bias_int32_rescale = bias_max_rescale;
-              float weight_rescale =
-                  bias_int32_rescale * bias_scale / data_scale_ / 
weight_scales_[0];
-              int8_t* weight_ptr = weight.data().dptr<int8_t>();
-              size_t weight_size = weight.shape().Size();
-#pragma omp parallel for num_threads(nthreads)
-              for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
-                weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
-              }
-              weight_scales_[0] *= weight_rescale;
-            }
-            NDArray bias = in_data[fullc::kBias];
-            cached_bias_ =
-                NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, 
mshadow::kInt32);
-            int8_t* bias_ptr            = bias.data().dptr<int8_t>();
-            int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
-            size_t bias_size            = bias.shape().Size();
 
-#pragma omp parallel for num_threads(nthreads)
-            for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
-              quantized_bias_ptr[i] = std::round(bias_ptr[i] * 
bias_int32_rescale);
-            }
-          }
-        }
-      }
-
-      size_t num_channel = cached_weight_.shape()[0];
-      float out_scale    = 1.0f;
-      if (fuse_requantize || dnnl_param.enable_float_output) {
-        float tmp_scale_ = 1.0f;
-        if (fuse_requantize) {
-          if (dnnl_param.with_eltwise) {
-            tmp_scale_ = 1.0 / data_scale_;
-            full_param_.eltwise_param.scale =
-                GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
-          } else {
-            out_scale  = GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
-            tmp_scale_ = out_scale / data_scale_;
-          }
-        } else {
-          tmp_scale_ = 1.0 / data_scale_;
-        }
-
-        if (support_channelwise_scale) {
-          full_param_.output_scales.resize(num_channel);
-#pragma omp parallel for num_threads(nthreads)
-          for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
-            full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
-          }
-        } else {
-          full_param_.output_scales.resize(1);
-          full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
-        }
-      } else {
-        Stream<cpu>* s = ctx.get_stream<cpu>();
-        if (data.dtype() == mshadow::kInt8) {
-          mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, 
cpu>::Launch(
-              s,
-              1,
-              &cached_output_min_,
-              &cached_output_max_,
-              &data_min,
-              &data_max,
-              &weight_min,
-              &weight_max);
-        } else {
-          mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, 
cpu>::Launch(
-              s,
-              1,
-              &cached_output_min_,
-              &cached_output_max_,
-              &data_min,
-              &data_max,
-              &weight_min,
-              &weight_max);
-        }
-        full_param_.output_scales.resize(0);
-        out_scale = data_scale_ * weight_scales_[0];
-      }
-
-      if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
-        float sum_in_scale =
-            GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, 
cached_sum_max_);
-        full_param_.sum_scale = out_scale / sum_in_scale;
-        if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
-            out_data[out_index].dtype() == mshadow::kInt8) {
-          // In this case, reorder with scale 0.5 is used on in_data[idx.sum] 
to
-          // scale it to s8 range, so sum_scale has to be rescaled as well
-          full_param_.sum_scale *= 2.0;
-        }
-      }
-    }  // if (dnnl_param.quantized)
+    if (dnnl_param.quantized) {
+      support_channelwise_scale = PrepareQuantization(ctx, in_data, output, 
min_max_vec);
+    }
 
     fwd_.reset(new DNNLFullyConnectedForward(full_param_,
                                              ctx.is_train,
@@ -424,33 +213,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
                                              cached_weight_,
                                              (has_bias ? &cached_bias_ : 
nullptr),
                                              out_md));
-
-    // convert weight and bias to the format that DNNL requires
-    if (!dnnl_param.quantized || support_channelwise_scale) {
-      dnnl::memory::desc bias_md;
-      if (has_bias)
-        bias_md = fwd_->fwd_pd.bias_desc();
-      ConvertWeightBias2DNNL(&cached_weight_,
-                             &cached_bias_,
-                             has_bias,
-                             fwd_->fwd_pd.weights_desc(),
-                             has_bias ? &bias_md : nullptr,
-                             1,
-                             data_scale_,
-                             weight_scales_,
-                             false);
-    } else {
-      const auto def_weight_mem = static_cast<const 
dnnl::memory*>(weight.GetDNNLData());
-      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
-        auto weight_desc       = fwd_->fwd_pd.weights_desc();
-        cached_weight_         = NDArray(&weight_desc);
-        auto cached_weight_mem = static_cast<const 
dnnl::memory*>(cached_weight_.GetDNNLData());
-        std::unordered_map<int, dnnl::memory> args(
-            {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, 
*cached_weight_mem}});
-        DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, 
*cached_weight_mem),
-                                            args);
-      }
-    }
+    GetCachedWeightsAndBias(weight, support_channelwise_scale, has_bias);
 
     const auto data_mem = static_cast<const dnnl::memory*>(data.GetDNNLData());
     cached_data_mem_    = std::make_shared<dnnl::memory>(data_mem->get_desc(), 
engine);
@@ -499,6 +262,335 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   }
 }
 
+// Returns the buffer the FC primitive should write into when a sum post-op is fused.
+// Before initialization, checks whether `sum_input` and `output` share the same data
+// handle and dtype; if so the op runs in place and `sum_input` is returned unchanged.
+// Otherwise `sum_input` is copied into `output`'s buffer through work registered on the
+// DNNL stream: a reorder to s32 when the output is int32, a 0.5-scaled reorder when u8
+// data feeds an int8 output, or a plain memory copy in every other case.
+NDArray SgDNNLFCOp::PrepareOutputWithSum(const NDArray& sum_input, const NDArray& output) {
+  if (!initialized_) {
+    // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
+    // which make check (req[out_index] == kWriteInplace) useless.
+    const auto* sum_mem   = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
+    const auto* out_mem   = static_cast<const dnnl::memory*>(output.GetDNNLData());
+    const bool same_alloc = sum_mem->get_data_handle() == out_mem->get_data_handle();
+    if (same_alloc && sum_input.dtype() == output.dtype()) {
+      inplace_ = true;
+    }
+  }
+  if (inplace_) {
+    return sum_input;
+  }
+
+  // Not in place: copy sum_input into output.
+  const auto* sum_mem = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
+  const auto* out_mem = static_cast<const dnnl::memory*>(output.GetDNNLData());
+  const auto engine   = CpuEngine::Get()->get_engine();
+  void* const out_buf = out_mem->get_data_handle();
+
+  if (output.dtype() == mshadow::kInt32) {
+    auto s32_desc           = sum_mem->get_desc();
+    s32_desc.data.data_type = static_cast<dnnl_data_type_t>(get_dnnl_type(mshadow::kInt32));
+    dnnl_mem_ptr dst(new dnnl::memory(s32_desc, engine, out_buf));
+    DNNLStream::Get()->RegisterMem(dst);
+    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*sum_mem, *dst),
+                                        {{DNNL_ARG_FROM, *sum_mem}, {DNNL_ARG_TO, *dst}});
+    return NDArray(dst);
+  }
+
+  if (sum_input.dtype() == mshadow::kUint8 && output.dtype() == mshadow::kInt8) {
+    // Halve the u8 values so they land in the s8 range.
+    auto s8_desc           = sum_mem->get_desc();
+    s8_desc.data.data_type = static_cast<dnnl_data_type_t>(get_dnnl_type(mshadow::kInt8));
+    dnnl_mem_ptr dst(new dnnl::memory(s8_desc, engine, out_buf));
+    DNNLStream::Get()->RegisterMem(dst);
+    dnnl::primitive_attr halve_attr;
+    halve_attr.set_output_scales(0, std::vector<float>{0.5f});
+    const dnnl::reorder::primitive_desc halve_pd(
+        engine, sum_mem->get_desc(), engine, s8_desc, halve_attr);
+    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(halve_pd),
+                                        {{DNNL_ARG_FROM, *sum_mem}, {DNNL_ARG_TO, *dst}});
+    return NDArray(dst);
+  }
+
+  dnnl_mem_ptr dst(new dnnl::memory(sum_mem->get_desc(), engine, out_buf));
+  DNNLStream::Get()->RegisterMem(dst);
+  DNNLMemoryCopy(*sum_mem, dst.get());
+  return NDArray(dst);
+}
+
+// Decides whether the primitive built on an earlier Forward call can be reused.
+// Unless the op is quantized, already initialized and MXNET_ONEDNN_QFC_DYNAMIC_PARAMS is
+// enabled, the current `initialized_` state is returned unchanged. Returning false here
+// would rebuild the primitive on every call and re-run PrepareQuantization, which may
+// rescale the int8 weight buffer in place each time.
+// In the dynamic case, returns false (forcing re-initialization) when the data/sum
+// min/max changed, or, for channel-wise quantization, when the weight/bias version
+// changed, or otherwise when the weight/bias min/max changed.
+bool SgDNNLFCOp::CheckInitializationConditions(const std::vector<NDArray>& inputs,
+                                               const std::vector<float>& min_max_vec,
+                                               bool is_channel_wise) {
+  if (!initialized_ || !full_param_.dnnl_param.quantized ||
+      !dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
+    // Nothing to re-validate: keep whatever state the op is already in.
+    return initialized_;
+  }
+
+  const bool has_bias = !full_param_.default_param.no_bias;
+  // Use the same input indexing as Forward rather than the plain fullc:: constants.
+  const FCInputIndex idx(full_param_);
+  if (cached_data_min_ != min_max_vec[kDataMin] || cached_data_max_ != min_max_vec[kDataMax] ||
+      cached_sum_min_ != min_max_vec[kSumMin] || cached_sum_max_ != min_max_vec[kSumMax]) {
+    return false;
+  }
+
+  if (is_channel_wise) {
+    // Channel-wise scales are computed from the weight/bias arrays themselves
+    // (GetWeightScales), so changes are tracked through the array versions.
+    if (weight_ver_ != inputs[idx.weight].version() ||
+        (has_bias && bias_ver_ != inputs[idx.bias].version())) {
+      return false;
+    }
+  } else {
+    if (cached_weight_min_ != min_max_vec[kWeightMin] ||
+        cached_weight_max_ != min_max_vec[kWeightMax] ||
+        (has_bias && (cached_bias_min_ != min_max_vec[kBiasMin] ||
+                      cached_bias_max_ != min_max_vec[kBiasMax]))) {
+      return false;
+    }
+  }
+  return true;
+}
+
+// Builds the 2-D output memory descriptor for the FC primitive, with data type
+// `out_dtype` and the GetDefaultFormat(2) layout. A 2-D `oshape` is used directly;
+// higher-rank shapes are collapsed to two dims depending on `flatten`:
+//   flatten == false: rows = ProdShape(0, ndim - 1), cols = last axis;
+//   flatten == true:  rows = axis 0,                 cols = ProdShape(1, ndim).
+dnnl::memory::desc SgDNNLFCOp::CreateOutputMemoryDesc(const mxnet::TShape& oshape, int out_dtype) {
+  // Bind by reference: only `flatten` is read, so the param struct need not be copied.
+  const auto& default_param = full_param_.default_param;
+  dnnl::memory::dims out_dims(2);
+  if (oshape.ndim() == 2) {
+    out_dims[0] = static_cast<index_t>(oshape[0]);
+    out_dims[1] = static_cast<index_t>(oshape[1]);
+  } else if (!default_param.flatten) {
+    out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
+    out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
+  } else {
+    out_dims[0] = static_cast<index_t>(oshape[0]);
+    out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
+  }
+  return dnnl::memory::desc(out_dims,
+                            get_dnnl_type(out_dtype),
+                            static_cast<dnnl::memory::format_tag>(GetDefaultFormat(2)));
+}
+
+// Derives the quantization scales used by the quantized FullyConnected primitive.
+// Sets data_scale_ and weight_scales_, fills full_param_.output_scales (and, when
+// requantize is fused together with an eltwise op, full_param_.eltwise_param.scale),
+// updates cached_output_min_ / cached_output_max_, and sets full_param_.sum_scale
+// when a sum is fused and the output is not float.
+// In the per-tensor path with an int8 bias, cached_bias_ is replaced by a new int32
+// array; if the rescaled bias would exceed INT_MAX / 2, the int8 data of `weight`
+// (in_data[fullc::kWeight]) is rescaled IN PLACE through weight.data(), despite the
+// const reference.
+// Returns true iff channel-wise weight scales are used. Aborts via LOG(FATAL) when
+// channel-wise quantization is requested but neither requantize (calibration range)
+// nor float output is fused.
+bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
+                                     const std::vector<NDArray>& in_data,
+                                     const NDArray& output,
+                                     const std::vector<float>& min_max_vec) {
+  // Thread count for the OpenMP loops below.
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  const FCInputIndex idx(full_param_);
+  bool support_channelwise_scale = false;
+  auto dnnl_param                = full_param_.dnnl_param;
+  bool has_bias                  = !full_param_.default_param.no_bias;
+  const NDArray& data            = in_data[fullc::kData];
+  const NDArray& weight          = in_data[fullc::kWeight];
+  const bool channel_wise = dnnl_param.quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
+                            dnnl_param.channel_wise_quantize.value();
+
+  CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
+  data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, 
cached_data_max_);
+
+  bool fuse_requantize = false;
+  // Channelwise scaling is only supported when fusion is enabled (requantize 
or dequantize).
+  if (dnnl_param.min_calib_range.has_value() && 
dnnl_param.max_calib_range.has_value()) {
+    cached_output_min_        = dnnl_param.min_calib_range.value();
+    cached_output_max_        = dnnl_param.max_calib_range.value();
+    support_channelwise_scale = true;
+    fuse_requantize           = true;
+  }
+  if (dnnl_param.enable_float_output) {
+    support_channelwise_scale = true;
+  }
+  // channel_wise  support_channelwise_scale  result
+  // True          True                       True
+  // True          False                      Error
+  // False         True/False                 False
+  if (channel_wise && !support_channelwise_scale) {
+    LOG(FATAL) << "Currently, channel-wise quantization requires fuse 
requantize or dequantize."
+               << " Please make sure the `min_calib_range` and 
`max_calib_range` are set when only"
+               << " fuse requantize (outputs of FullyConnected are collected 
during calibration "
+                  "phase),"
+               << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` 
and "
+               << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true 
(default is false)";
+  }
+  // From here on the flag means "channel-wise requested AND supported".
+  support_channelwise_scale = support_channelwise_scale && channel_wise;
+
+  // Channel-wise: per-channel weight scales from GetWeightScales (which also receives
+  // the bias, if any). Per-tensor: a single weight scale, plus an int8 -> int32 bias
+  // conversion when the bias is int8.
+  if (support_channelwise_scale) {
+    MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
+      weight_scales_ = GetWeightScales<DType>(cached_weight_,
+                                              has_bias ? &cached_bias_ : 
nullptr,
+                                              data_scale_,
+                                              support_channelwise_scale);
+    });
+  } else {
+    weight_scales_.resize(1);
+    weight_scales_[0] =
+        GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, 
cached_weight_max_);
+    if (has_bias) {
+      if (cached_bias_.dtype() == mshadow::kInt8) {
+        float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, 
cached_bias_max_);
+
+        float bias_int32_rescale = data_scale_ * weight_scales_[0] / 
bias_scale;
+        // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
+        // the maximum value of bias to INT_MAX / 2.
+        float bias_max_rescale =
+            MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, 
cached_bias_max_) / bias_scale;
+        if (bias_int32_rescale > bias_max_rescale) {
+          // avoid overflow on bias
+          bias_int32_rescale   = bias_max_rescale;
+          float weight_rescale = bias_int32_rescale * bias_scale / data_scale_ 
/ weight_scales_[0];
+          // NOTE: writes through weight.data(): the input int8 weight buffer is
+          // rescaled in place, and weight_scales_[0] is adjusted to match.
+          int8_t* weight_ptr   = weight.data().dptr<int8_t>();
+          size_t weight_size   = weight.shape().Size();
+#pragma omp parallel for num_threads(nthreads)
+          for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
+            weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+          }
+          weight_scales_[0] *= weight_rescale;
+        }
+        // Re-quantize the int8 bias into a new int32 array (replacing cached_bias_)
+        // using bias_int32_rescale.
+        NDArray bias = in_data[fullc::kBias];
+        cached_bias_ =
+            NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, 
mshadow::kInt32);
+        int8_t* bias_ptr            = bias.data().dptr<int8_t>();
+        int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+        size_t bias_size            = bias.shape().Size();
+
+#pragma omp parallel for num_threads(nthreads)
+        for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
+          quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
+        }
+      }
+    }
+  }
+
+  // Output scaling. With fused requantize or float output, the data and weight scales
+  // are folded into full_param_.output_scales (one entry per weight_scales_ entry when
+  // channel-wise). Otherwise, the int32 output range is computed from the input
+  // min/max values and output_scales is left empty.
+  // out_scale is reused below for the fused-sum scale.
+  size_t num_channel = cached_weight_.shape()[0];
+  float out_scale    = 1.0f;
+  if (fuse_requantize || dnnl_param.enable_float_output) {
+    float tmp_scale_ = 1.0f;
+    if (fuse_requantize) {
+      if (dnnl_param.with_eltwise) {
+        // With a fused eltwise, the output quantization scale goes into
+        // eltwise_param.scale instead of output_scales (presumably so it is applied
+        // after the eltwise post-op - confirm against the primitive setup).
+        tmp_scale_ = 1.0 / data_scale_;
+        full_param_.eltwise_param.scale =
+            GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
+      } else {
+        out_scale  = GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
+        tmp_scale_ = out_scale / data_scale_;
+      }
+    } else {
+      tmp_scale_ = 1.0 / data_scale_;
+    }
+
+    if (support_channelwise_scale) {
+      full_param_.output_scales.resize(num_channel);
+#pragma omp parallel for num_threads(nthreads)
+      for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
+        full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
+      }
+    } else {
+      full_param_.output_scales.resize(1);
+      full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
+    }
+  } else {
+    // No fused output quantization: compute cached_output_min_/max_ from the data and
+    // weight ranges; the kernel depends on whether the data is signed (s8) or unsigned (u8).
+    Stream<cpu>* s = ctx.get_stream<cpu>();
+    if (data.dtype() == mshadow::kInt8) {
+      mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, 
cpu>::Launch(
+          s,
+          1,
+          &cached_output_min_,
+          &cached_output_max_,
+          &(min_max_vec[kDataMin]),
+          &(min_max_vec[kDataMax]),
+          &(min_max_vec[kWeightMin]),
+          &(min_max_vec[kWeightMax]));
+    } else {
+      mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, 
cpu>::Launch(
+          s,
+          1,
+          &cached_output_min_,
+          &cached_output_max_,
+          &(min_max_vec[kDataMin]),
+          &(min_max_vec[kDataMax]),
+          &(min_max_vec[kWeightMin]),
+          &(min_max_vec[kWeightMax]));
+    }
+    full_param_.output_scales.resize(0);
+    out_scale = data_scale_ * weight_scales_[0];
+  }
+
+  // Fused sum with non-float output: sum_scale is the ratio of the output scale to the
+  // quantization scale of the sum operand.
+  if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+    float sum_in_scale =
+        GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, 
cached_sum_max_);
+    full_param_.sum_scale = out_scale / sum_in_scale;
+    if (in_data[idx.sum].dtype() == mshadow::kUint8 && output.dtype() == 
mshadow::kInt8) {
+      // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+      // scale it to s8 range, so sum_scale has to be rescaled as well
+      full_param_.sum_scale *= 2.0;
+    }
+  }
+  return support_channelwise_scale;
+}
+
+// Populates cached_weight_ / cached_bias_ with the weight and bias laid out as
+// fwd_->fwd_pd expects.
+// The cache is used only when caching is enabled (MXNET_ONEDNN_DISABLE_FC_CACHE unset
+// or 0) and the node carries an "__identifier__" attribute. In that case a
+// per-thread cache is consulted first, and a converted pair is stored after a miss.
+// On a hit the stored arrays are reused as-is, so later changes to the weights are
+// not picked up.
+// Note: in the per-tensor quantized path, the layout reorder is only registered on
+// the DNNLStream here; this function does not submit it.
+void SgDNNLFCOp::GetCachedWeightsAndBias(const NDArray& weight,
+                                         bool support_channelwise_scale,
+                                         bool has_bias) {
+  using WeightBiasCache =
+      std::unordered_map<DNNLFullyconSignature, std::pair<NDArray, NDArray>, OpHash>;
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local WeightBiasCache weight_bias_cache;
+#else
+  static MX_THREAD_LOCAL WeightBiasCache weight_bias_cache;
+#endif
+  static const bool use_cache = !(dmlc::GetEnv("MXNET_ONEDNN_DISABLE_FC_CACHE", 0));
+
+  // The per-node identifier keeps identical FC nodes from different networks
+  // from sharing a cache entry.
+  const auto id_it     = attrs.dict.find("__identifier__");
+  const bool cacheable = use_cache && id_it != attrs.dict.end();
+
+  DNNLFullyconSignature key(full_param_.default_param);
+  if (cacheable) {
+    key.AddSign(fwd_->fwd_pd.weights_desc());
+    key.AddSign(id_it->second);
+    key.AddSign(attrs.name);
+
+    const auto hit = weight_bias_cache.find(key);
+    if (hit != weight_bias_cache.end()) {
+      cached_weight_ = hit->second.first;
+      cached_bias_   = hit->second.second;
+      common::LogOnce(
+          "oneDNN optimized version of FullyConnected for inference is being used. Weights and "
+          "bias are cached and can not be dynamically changed during runtime. To disable caching "
+          "mechanism use MXNET_ONEDNN_DISABLE_FC_CACHE=1.");
+      return;
+    }
+  }
+
+  if (full_param_.dnnl_param.quantized && !support_channelwise_scale) {
+    // Per-tensor quantized path: cached_weight_ is left untouched unless the input
+    // weight's layout differs from the primitive's. In that case a reorder into a
+    // freshly allocated cached_weight_ is registered.
+    const auto src_weight_mem = weight.GetDNNLData();
+    auto expected_weight_md   = fwd_->fwd_pd.weights_desc();
+    if (src_weight_mem->get_desc() != expected_weight_md) {
+      cached_weight_            = NDArray(&expected_weight_md);
+      const auto dst_weight_mem = cached_weight_.GetDNNLData();
+      std::unordered_map<int, dnnl::memory> reorder_args{{DNNL_ARG_FROM, *src_weight_mem},
+                                                         {DNNL_ARG_TO, *dst_weight_mem}};
+      DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*src_weight_mem, *dst_weight_mem),
+                                          reorder_args);
+    }
+  } else {
+    // Float or channel-wise quantized path: convert the weight (and the bias, if any)
+    // into the formats oneDNN requires. The scales computed earlier are passed along.
+    dnnl::memory::desc bias_md = has_bias ? fwd_->fwd_pd.bias_desc() : dnnl::memory::desc();
+    ConvertWeightBias2DNNL(&cached_weight_,
+                           &cached_bias_,
+                           has_bias,
+                           fwd_->fwd_pd.weights_desc(),
+                           has_bias ? &bias_md : nullptr,
+                           1,
+                           data_scale_,
+                           weight_scales_,
+                           false);
+  }
+
+  if (cacheable)
+    AddToCache(&weight_bias_cache, key, {cached_weight_, cached_bias_});
+}
+
 static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) {
   // For backward compatible, with_relu->with_eltwise
   auto legacy = attrs->dict.find("with_relu");
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h 
b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index fd5272ef5a..ccf5401705 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -167,7 +167,9 @@ class SgDNNLFCProperty : public SubgraphProperty {
 
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
                                      const int subgraph_id = 0) const override 
{
-    nnvm::ObjectPtr n = nnvm::Node::Create();
+    // distingush between exactly same node in different networks - for 
caching weights
+    static unsigned int node_identifier = 0;
+    nnvm::ObjectPtr n                   = nnvm::Node::Create();
     // This op has single output, remove duplicated.
     auto last_node = sym.outputs[0].node;
     nnvm::Symbol new_sym;
@@ -189,6 +191,7 @@ class SgDNNLFCProperty : public SubgraphProperty {
     n->attrs.name = node_name.str();
     n->attrs.op   = Op::Get("_sg_onednn_fully_connected");
     CHECK(n->attrs.op);
+    n->attrs.dict["__identifier__"] = std::to_string(node_identifier++);
     n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
     n->op()->attr_parser(&(n->attrs));
     return n;

Reply via email to