This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 84b1626b66 oneDNN FullyConnected weight caching & refactor (#21047)
84b1626b66 is described below
commit 84b1626b66f84a35e87a95d26226c3a3171ea78c
Author: bgawrych <[email protected]>
AuthorDate: Wed Jul 6 12:02:58 2022 +0200
oneDNN FullyConnected weight caching & refactor (#21047)
* FC weight and bias caching
* prepare output for sum
* check initialization conditions
* create output mem desc
* PrepareQuantization
* remove unused variables
* cleanup
* Enable BRGEMM
* Reorder functions
* make minmax enum anonymous
* node identifier & env flag
* fix sanity
* fix sanity
* apply review
* rename variable
---
src/operator/nn/dnnl/dnnl_base-inl.h | 9 +-
src/operator/operator_common.h | 7 +-
src/operator/subgraph/dnnl/dnnl_fc.cc | 688 +++++++++++++++-----------
src/operator/subgraph/dnnl/dnnl_fc_property.h | 5 +-
4 files changed, 401 insertions(+), 308 deletions(-)
diff --git a/src/operator/nn/dnnl/dnnl_base-inl.h b/src/operator/nn/dnnl/dnnl_base-inl.h
index 66c1dc2c99..0b9645b8a3 100644
--- a/src/operator/nn/dnnl/dnnl_base-inl.h
+++ b/src/operator/nn/dnnl/dnnl_base-inl.h
@@ -351,13 +351,6 @@ inline static dnnl::memory::desc GetMemDesc(const NDArray& arr, int dtype = -1)
   return dnnl::memory::desc{dims, get_dnnl_type(dtype), dnnl::memory::format_tag::any};
 }
 
-inline static bool ChooseBRGEMMImpl(const dnnl::memory::dims& weight_dims, size_t batch_size) {
-  // Conditions based on measurement results done on CLX8280
-  // https://github.com/apache/incubator-mxnet/pull/20533
-  return weight_dims[0] >= 1024 && weight_dims[1] >= 1024 && batch_size >= 16384 &&
-         weight_dims[0] % 64 == 0 && weight_dims[1] % 64 == 0;
-}
-
 inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
                                                  size_t batch_size,
                                                  int dtype = -1) {
@@ -370,7 +363,7 @@ inline static dnnl::memory::desc GetFCWeightDesc(const NDArray& arr,
   // for batch 256 alexnet benchmark test
   const bool force_fc_ab_format = dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
   if (dims.size() == 2) {
-    if (force_fc_ab_format || !ChooseBRGEMMImpl(dims, batch_size)) {
+    if (force_fc_ab_format || dtype != mshadow::kInt8) {
       format = dnnl::memory::format_tag::ab;
     }
   }
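
Context for the hunk above: the batch-size/weight-shape heuristic (ChooseBRGEMMImpl, tuned on CLX8280) is replaced by a dtype check, so int8 FC weights keep format_tag::any and let oneDNN pick a blocked, BRGEMM-friendly layout, while other dtypes (or MXNET_ONEDNN_FORCE_FC_AB_FORMAT) fall back to the plain ab layout. A minimal standalone sketch of that selection, not the MXNet source (the helper name and the simplified env-var handling are illustrative only):

  #include <cstdlib>
  #include "dnnl.hpp"

  // Sketch: mirrors the new condition in GetFCWeightDesc for 2D weights.
  dnnl::memory::format_tag ChooseFCWeightFormat(dnnl::memory::data_type weight_dtype) {
    // The real code uses dmlc::GetEnv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT", false);
    // here "variable is set at all" stands in for that check.
    const bool force_ab = std::getenv("MXNET_ONEDNN_FORCE_FC_AB_FORMAT") != nullptr;
    if (force_ab || weight_dtype != dnnl::memory::data_type::s8)
      return dnnl::memory::format_tag::ab;  // plain row-major weight layout
    return dnnl::memory::format_tag::any;   // let oneDNN choose a blocked layout
  }
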
diff --git a/src/operator/operator_common.h b/src/operator/operator_common.h
index a0f158f2f9..ffb9898104 100644
--- a/src/operator/operator_common.h
+++ b/src/operator/operator_common.h
@@ -547,7 +547,7 @@ class OpSignature {
#if MXNET_USE_ONEDNN == 1
void AddSign(const dnnl::memory::desc& desc) {
- hash = hash * 2 + desc.data.format_kind;
+ hash = hash * 2 + desc.data.format_kind;
eles.push_back(desc.data.format_kind);
hash = hash * 2 + desc.data.data_type;
eles.push_back(desc.data.data_type);
@@ -617,6 +617,11 @@ class OpSignature {
#endif
+ void AddSign(const std::string& s) {
+ uint64_t key = static_cast<uint64_t>(std::hash<std::string>{}(s));
+ eles.push_back(key);
+ }
+
void AddSign(const std::vector<NDArray>& arrs) {
for (auto& arr : arrs) {
AddSign(arr);
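
The AddSign(const std::string&) overload added above lets strings (the node name and the new __identifier__ attribute) contribute to an OpSignature cache key by pushing their std::hash value into the element list. A minimal, self-contained sketch of the same idea; MiniSignature is a hypothetical stand-in, not the MXNet class:

  #include <cstdint>
  #include <functional>
  #include <string>
  #include <vector>

  struct MiniSignature {
    std::vector<uint64_t> eles;  // elements the cache map hashes and compares

    void AddSign(const std::string& s) {
      // std::hash<std::string> is stable within one process run, which is
      // sufficient for an in-process weight cache.
      eles.push_back(static_cast<uint64_t>(std::hash<std::string>{}(s)));
    }
  };

  int main() {
    MiniSignature sig;
    sig.AddSign("sg_onednn_fully_connected_0");  // node name
    sig.AddSign("3");                            // __identifier__ attribute
    return sig.eles.size() == 2 ? 0 : 1;
  }
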
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 2371ce8bd9..22971bf487 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -47,7 +47,9 @@ namespace op {
class SgDNNLFCOp {
public:
explicit SgDNNLFCOp(const nnvm::NodeAttrs& attrs)
-      : subgraph_sym_(*attrs.subgraphs[0]), full_param_(nnvm::get<DNNLFCFullParam>(attrs.parsed)) {}
+ : subgraph_sym_(*attrs.subgraphs[0]),
+ attrs(attrs),
+ full_param_(nnvm::get<DNNLFCFullParam>(attrs.parsed)) {}
void Forward(const OpContext& ctx,
const std::vector<NDArray>& inputs,
@@ -63,11 +65,27 @@ class SgDNNLFCOp {
}
private:
+  enum { kDataMin = 0, kDataMax, kWeightMin, kWeightMax, kBiasMin, kBiasMax, kSumMin, kSumMax };
+  const size_t MIN_MAX_COUNT = 8;
+
+  NDArray PrepareOutputWithSum(const NDArray& sum_input, const NDArray& output);
+  bool CheckInitializationConditions(const std::vector<NDArray>& inputs,
+                                     const std::vector<float>& min_max_vec,
+                                     bool is_channel_wise);
+  bool PrepareQuantization(const OpContext& ctx,
+                           const std::vector<NDArray>& in_data,
+                           const NDArray& output,
+                           const std::vector<float>& min_max_vec);
+  dnnl::memory::desc CreateOutputMemoryDesc(const mxnet::TShape& oshape, int out_dtype);
+  void GetCachedWeightsAndBias(const NDArray& weight,
+                               bool support_channelwise_scale,
+                               bool has_bias);
+ nnvm::Symbol subgraph_sym_;
+ nnvm::NodeAttrs attrs;
+ DNNLFCFullParam full_param_;
bool initialized_{false};
bool reorder_data_{false};
bool inplace_{false};
- nnvm::Symbol subgraph_sym_;
- DNNLFCFullParam full_param_;
dnnl_args_map_t args_;
std::shared_ptr<DNNLFullyConnectedForward> fwd_;
std::shared_ptr<dnnl::memory> cached_data_mem_;
@@ -94,16 +112,15 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
const std::vector<NDArray>& in_data,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& out_data) {
- auto& dnnl_param = full_param_.dnnl_param;
- auto& default_param = full_param_.default_param;
- const bool has_bias = !default_param.no_bias;
- const bool quantized = dnnl_param.quantized;
-  const bool out_quantized = dnnl_param.quantized && !dnnl_param.enable_float_output;
-  const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
+  const auto& default_param = full_param_.default_param;
+  const auto& dnnl_param = full_param_.dnnl_param;
+  const bool has_bias = !default_param.no_bias;
+  const bool quantized = dnnl_param.quantized;
+  const bool out_quantized = dnnl_param.quantized && !dnnl_param.enable_float_output;
+  const bool channel_wise = quantized && dnnl_param.channel_wise_quantize.has_value() &&
                             dnnl_param.channel_wise_quantize.value();
const FCInputIndex idx(full_param_);
-
CHECK_EQ(in_data.size(), idx.GetTotal());
int index = 0;
@@ -112,125 +129,54 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
const int out_max_index = out_quantized ? index++ : 0;
   CHECK_EQ(out_data.size(), index);  // index is equal to total number of outputs
- float data_min = 0.0f;
- float data_max = 0.0f;
- float weight_min = 0.0f;
- float weight_max = 0.0f;
- float bias_min = 0.0f;
- float bias_max = 0.0f;
+ std::vector<float> min_max_vec(MIN_MAX_COUNT);
+ min_max_vec[kDataMin] = 0.0f;
+ min_max_vec[kDataMax] = 0.0f;
+ min_max_vec[kWeightMin] = 0.0f;
+ min_max_vec[kWeightMax] = 0.0f;
+ min_max_vec[kBiasMin] = 0.0f;
+ min_max_vec[kBiasMax] = 0.0f;
-  const float sum_min = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
-  const float sum_max = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+  min_max_vec[kSumMin] = idx.sum_min ? in_data[idx.sum_min].data().dptr<float>()[0] : 0.0f;
+  min_max_vec[kSumMax] = idx.sum_max ? in_data[idx.sum_max].data().dptr<float>()[0] : 0.0f;
NDArray data = in_data[idx.data];
const NDArray& weight = in_data[idx.weight];
NDArray output;
if (dnnl_param.with_sum) {
-    if (!initialized_) {
-      // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
-      // which make check (req[out_index] == kWriteInplace) useless.
-      auto in_dnnl_mem = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
-      auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
-      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle() &&
-          in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
-        inplace_ = true;
-      }
-    }
-    if (inplace_) {
-      output = in_data[idx.sum];
-    } else {
-      // Not in place: copy in_data[idx.sum] into outputs[out_index].
-      auto in_dnnl_mem = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
-      auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
-      if (out_data[out_index].dtype() == mshadow::kInt32) {
-        auto mem_desc = in_dnnl_mem->get_desc();
-        auto this_dtype = get_dnnl_type(mshadow::kInt32);
-        mem_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(
-            mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        DNNLStream::Get()->RegisterPrimArgs(
-            dnnl::reorder(*in_dnnl_mem, *tmp_mem),
-            {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
-        output = NDArray(tmp_mem);
-      } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
-                 out_data[out_index].dtype() == mshadow::kInt8) {
-        auto sum_mem_desc = in_dnnl_mem->get_desc();
-        auto out_dtype = get_dnnl_type(mshadow::kInt8);
-        sum_mem_desc.data.data_type = static_cast<dnnl_data_type_t>(out_dtype);
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(
-            sum_mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        const float u8_reorder_scale = 0.5;
-        std::vector<float> reorder_scale = {u8_reorder_scale};
-        dnnl::primitive_attr reorder_attr;
-        reorder_attr.set_output_scales(0, reorder_scale);
-        const auto reorder_pd = dnnl::reorder::primitive_desc(CpuEngine::Get()->get_engine(),
-                                                              in_dnnl_mem->get_desc(),
-                                                              CpuEngine::Get()->get_engine(),
-                                                              sum_mem_desc,
-                                                              reorder_attr);
-        DNNLStream::Get()->RegisterPrimArgs(
-            dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
-        output = NDArray(tmp_mem);
-      } else {
-        dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
-                                              CpuEngine::Get()->get_engine(),
-                                              out_dnnl_mem->get_data_handle()));
-        DNNLStream::Get()->RegisterMem(tmp_mem);
-        DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get());
-        output = NDArray(tmp_mem);
-      }
-    }
+    output = PrepareOutputWithSum(in_data[idx.sum], out_data[out_index]);
} else {
output = out_data[out_index];
}
if (dnnl_param.quantized) {
if (!channel_wise) {
- weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
- weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
+      min_max_vec[kWeightMin] = in_data[idx.weight_min].data().dptr<float>()[0];
+      min_max_vec[kWeightMax] = in_data[idx.weight_max].data().dptr<float>()[0];
if (has_bias) {
- bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
- bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
+ min_max_vec[kBiasMin] = in_data[idx.bias_min].data().dptr<float>()[0];
+ min_max_vec[kBiasMax] = in_data[idx.bias_max].data().dptr<float>()[0];
}
}
- data_min = in_data[idx.data_min].data().dptr<float>()[0];
- data_max = in_data[idx.data_max].data().dptr<float>()[0];
+ min_max_vec[kDataMin] = in_data[idx.data_min].data().dptr<float>()[0];
+ min_max_vec[kDataMax] = in_data[idx.data_max].data().dptr<float>()[0];
}
-  if (initialized_ && dnnl_param.quantized && dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
-    if (channel_wise) {
-      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
-          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
-          weight_ver_ != weight.version() ||
-          (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
-        initialized_ = false;
-      }
-    } else {
-      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
-          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
-          cached_weight_min_ != weight_min || cached_weight_max_ != weight_max ||
-          (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != bias_max))) {
-        initialized_ = false;
-      }
-    }
-  }
+  initialized_ = CheckInitializationConditions(in_data, min_max_vec, channel_wise);
if (!initialized_) {
-    const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
- const auto engine = CpuEngine::Get()->get_engine();
- cached_data_min_ = data_min;
- cached_data_max_ = data_max;
- cached_weight_min_ = weight_min;
- cached_weight_max_ = weight_max;
- weight_ver_ = weight.version();
- cached_weight_ = weight;
- cached_sum_min_ = sum_min;
- cached_sum_max_ = sum_max;
+ const auto engine = CpuEngine::Get()->get_engine();
+ cached_data_min_ = min_max_vec[kDataMin];
+ cached_data_max_ = min_max_vec[kDataMax];
+ cached_weight_min_ = min_max_vec[kWeightMin];
+ cached_weight_max_ = min_max_vec[kWeightMax];
+ weight_ver_ = weight.version();
+ cached_weight_ = weight;
+ cached_sum_min_ = min_max_vec[kSumMin];
+ cached_sum_max_ = min_max_vec[kSumMax];
if (has_bias) {
- cached_bias_min_ = bias_min;
- cached_bias_max_ = bias_max;
+ cached_bias_min_ = min_max_vec[kBiasMin];
+ cached_bias_max_ = min_max_vec[kBiasMax];
bias_ver_ = in_data[idx.bias].version();
cached_bias_ = in_data[idx.bias];
} else {
@@ -252,171 +198,14 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
}
// create cached out_md
-    const mxnet::TShape oshape = output.shape();
-    dnnl::memory::dims out_dims(2);
-    if (oshape.ndim() == 2) {
-      out_dims[0] = static_cast<index_t>(oshape[0]);
-      out_dims[1] = static_cast<index_t>(oshape[1]);
-    } else {
-      if (!default_param.flatten) {
-        out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
-        out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
-      } else {
-        out_dims[0] = static_cast<index_t>(oshape[0]);
-        out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
-      }
-    }
-    dnnl::memory::desc out_md =
-        dnnl::memory::desc(out_dims,
-                           get_dnnl_type(output.dtype()),
-                           static_cast<dnnl::memory::format_tag>(GetDefaultFormat(2)));
-    cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine);
+    dnnl::memory::desc out_md = CreateOutputMemoryDesc(output.shape(), output.dtype());
+    cached_out_mem_ = std::make_shared<dnnl::memory>(out_md, engine);
bool support_channelwise_scale = false;
-    if (dnnl_param.quantized) {
-      CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
-
-      bool fuse_requantize = false;
-      // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
-      if (dnnl_param.min_calib_range.has_value() && dnnl_param.max_calib_range.has_value()) {
-        cached_output_min_ = dnnl_param.min_calib_range.value();
-        cached_output_max_ = dnnl_param.max_calib_range.value();
-        support_channelwise_scale = true;
-        fuse_requantize = true;
-      }
-      if (dnnl_param.enable_float_output) {
-        support_channelwise_scale = true;
-      }
-      // channel_wise  support_channelwise_scale  result
-      //     True                True              True
-      //     True                False             Error
-      //     False             True/False          False
-      if (channel_wise && !support_channelwise_scale) {
-        LOG(FATAL)
-            << "Currently, channel-wise quantization requires fuse requantize or dequantize."
-            << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
-            << " fuse requantize (outputs of FullyConnected are collected during calibration "
-               "phase),"
-            << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and "
-            << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true (default is false)";
-      }
-      support_channelwise_scale = support_channelwise_scale && channel_wise;
-
-      if (support_channelwise_scale) {
-        MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
-          weight_scales_ = GetWeightScales<DType>(cached_weight_,
-                                                  has_bias ? &cached_bias_ : nullptr,
-                                                  data_scale_,
-                                                  support_channelwise_scale);
-        });
-      } else {
-        weight_scales_.resize(1);
-        weight_scales_[0] =
-            GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_);
-        if (has_bias) {
-          if (cached_bias_.dtype() == mshadow::kInt8) {
-            float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_);
-
-            float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
-            // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
-            // the maximum value of bias to INT_MAX / 2.
-            float bias_max_rescale =
-                MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale;
-            if (bias_int32_rescale > bias_max_rescale) {
-              // avoid overflow on bias
-              bias_int32_rescale = bias_max_rescale;
-              float weight_rescale =
-                  bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
-              int8_t* weight_ptr = weight.data().dptr<int8_t>();
-              size_t weight_size = weight.shape().Size();
-#pragma omp parallel for num_threads(nthreads)
-              for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
-                weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
-              }
-              weight_scales_[0] *= weight_rescale;
-            }
-            NDArray bias = in_data[fullc::kBias];
-            cached_bias_ =
-                NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
-            int8_t* bias_ptr = bias.data().dptr<int8_t>();
-            int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
-            size_t bias_size = bias.shape().Size();
-#pragma omp parallel for num_threads(nthreads)
-            for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
-              quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
-            }
-          }
-        }
-      }
-
-      size_t num_channel = cached_weight_.shape()[0];
-      float out_scale = 1.0f;
-      if (fuse_requantize || dnnl_param.enable_float_output) {
-        float tmp_scale_ = 1.0f;
-        if (fuse_requantize) {
-          if (dnnl_param.with_eltwise) {
-            tmp_scale_ = 1.0 / data_scale_;
-            full_param_.eltwise_param.scale =
-                GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
-          } else {
-            out_scale = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
-            tmp_scale_ = out_scale / data_scale_;
-          }
-        } else {
-          tmp_scale_ = 1.0 / data_scale_;
-        }
-
-        if (support_channelwise_scale) {
-          full_param_.output_scales.resize(num_channel);
-#pragma omp parallel for num_threads(nthreads)
-          for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
-            full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
-          }
-        } else {
-          full_param_.output_scales.resize(1);
-          full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
-        }
-      } else {
-        Stream<cpu>* s = ctx.get_stream<cpu>();
-        if (data.dtype() == mshadow::kInt8) {
-          mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
-              s,
-              1,
-              &cached_output_min_,
-              &cached_output_max_,
-              &data_min,
-              &data_max,
-              &weight_min,
-              &weight_max);
-        } else {
-          mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
-              s,
-              1,
-              &cached_output_min_,
-              &cached_output_max_,
-              &data_min,
-              &data_max,
-              &weight_min,
-              &weight_max);
-        }
-        full_param_.output_scales.resize(0);
-        out_scale = data_scale_ * weight_scales_[0];
-      }
-
-      if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
-        float sum_in_scale =
-            GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
-        full_param_.sum_scale = out_scale / sum_in_scale;
-        if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
-            out_data[out_index].dtype() == mshadow::kInt8) {
-          // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
-          // scale it to s8 range, so sum_scale has to be rescaled as well
-          full_param_.sum_scale *= 2.0;
-        }
-      }
-    }  // if (dnnl_param.quantized)
+    if (dnnl_param.quantized) {
+      support_channelwise_scale = PrepareQuantization(ctx, in_data, output, min_max_vec);
+    }
fwd_.reset(new DNNLFullyConnectedForward(full_param_,
ctx.is_train,
@@ -424,33 +213,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
                                              cached_weight_,
                                              (has_bias ? &cached_bias_ : nullptr),
                                              out_md));
-
- // convert weight and bias to the format that DNNL requires
- if (!dnnl_param.quantized || support_channelwise_scale) {
- dnnl::memory::desc bias_md;
- if (has_bias)
- bias_md = fwd_->fwd_pd.bias_desc();
- ConvertWeightBias2DNNL(&cached_weight_,
- &cached_bias_,
- has_bias,
- fwd_->fwd_pd.weights_desc(),
- has_bias ? &bias_md : nullptr,
- 1,
- data_scale_,
- weight_scales_,
- false);
- } else {
-      const auto def_weight_mem = static_cast<const dnnl::memory*>(weight.GetDNNLData());
-      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
-        auto weight_desc = fwd_->fwd_pd.weights_desc();
-        cached_weight_ = NDArray(&weight_desc);
-        auto cached_weight_mem = static_cast<const dnnl::memory*>(cached_weight_.GetDNNLData());
-        std::unordered_map<int, dnnl::memory> args(
-            {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}});
-        DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, *cached_weight_mem),
-                                            args);
-      }
-    }
+    GetCachedWeightsAndBias(weight, support_channelwise_scale, has_bias);
const auto data_mem = static_cast<const dnnl::memory*>(data.GetDNNLData());
   cached_data_mem_ = std::make_shared<dnnl::memory>(data_mem->get_desc(), engine);
@@ -499,6 +262,335 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
}
}
+NDArray SgDNNLFCOp::PrepareOutputWithSum(const NDArray& sum_input, const NDArray& output) {
+  if (!initialized_) {
+    // TODO(zhennan): Currently, dnnl fallback mechanism will break inplace option,
+    // which make check (req[out_index] == kWriteInplace) useless.
+    auto in_dnnl_mem = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
+    auto out_dnnl_mem = static_cast<const dnnl::memory*>(output.GetDNNLData());
+    if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle() &&
+        sum_input.dtype() == output.dtype()) {
+      inplace_ = true;
+    }
+  }
+  if (inplace_) {
+    return sum_input;
+  } else {
+    // Not in place: copy sum_input into output.
+    auto in_dnnl_mem = static_cast<const dnnl::memory*>(sum_input.GetDNNLData());
+    auto out_dnnl_mem = static_cast<const dnnl::memory*>(output.GetDNNLData());
+    if (output.dtype() == mshadow::kInt32) {
+      auto mem_desc = in_dnnl_mem->get_desc();
+      auto this_dtype = get_dnnl_type(mshadow::kInt32);
+      mem_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
+      dnnl_mem_ptr tmp_mem(new dnnl::memory(
+          mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
+      DNNLStream::Get()->RegisterMem(tmp_mem);
+      DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*in_dnnl_mem, *tmp_mem),
+                                          {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+      return NDArray(tmp_mem);
+    } else if (sum_input.dtype() == mshadow::kUint8 && output.dtype() == mshadow::kInt8) {
+      auto sum_mem_desc = in_dnnl_mem->get_desc();
+      auto out_dtype = get_dnnl_type(mshadow::kInt8);
+      sum_mem_desc.data.data_type = static_cast<dnnl_data_type_t>(out_dtype);
+      dnnl_mem_ptr tmp_mem(new dnnl::memory(
+          sum_mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
+      DNNLStream::Get()->RegisterMem(tmp_mem);
+      const float u8_reorder_scale = 0.5;
+      std::vector<float> reorder_scale = {u8_reorder_scale};
+      dnnl::primitive_attr reorder_attr;
+      reorder_attr.set_output_scales(0, reorder_scale);
+      const auto reorder_pd = dnnl::reorder::primitive_desc(CpuEngine::Get()->get_engine(),
+                                                            in_dnnl_mem->get_desc(),
+                                                            CpuEngine::Get()->get_engine(),
+                                                            sum_mem_desc,
+                                                            reorder_attr);
+      DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd),
+                                          {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+      return NDArray(tmp_mem);
+    } else {
+      dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
+                                            CpuEngine::Get()->get_engine(),
+                                            out_dnnl_mem->get_data_handle()));
+      DNNLStream::Get()->RegisterMem(tmp_mem);
+      DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get());
+      return NDArray(tmp_mem);
+    }
+  }
+}
+
+bool SgDNNLFCOp::CheckInitializationConditions(const std::vector<NDArray>& inputs,
+                                               const std::vector<float>& min_max_vec,
+                                               bool is_channel_wise) {
+  if (initialized_ && full_param_.dnnl_param.quantized &&
+      dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
+    bool has_bias = !full_param_.default_param.no_bias;
+    if (cached_data_min_ != min_max_vec[kDataMin] || cached_data_max_ != min_max_vec[kDataMax] ||
+        cached_sum_min_ != min_max_vec[kSumMin] || cached_sum_max_ != min_max_vec[kSumMax]) {
+      return false;
+    }
+
+    if (is_channel_wise) {
+      if (weight_ver_ != inputs[fullc::kWeight].version() ||
+          (has_bias && (bias_ver_ != inputs[fullc::kBias].version()))) {
+        return false;
+      }
+    } else {
+      if (cached_weight_min_ != min_max_vec[kWeightMin] ||
+          cached_weight_max_ != min_max_vec[kWeightMax] ||
+          (has_bias && (cached_bias_min_ != min_max_vec[kBiasMin] ||
+                        cached_bias_max_ != min_max_vec[kBiasMax]))) {
+        return false;
+      }
+    }
+    return true;
+  }
+  return false;
+}
+
+dnnl::memory::desc SgDNNLFCOp::CreateOutputMemoryDesc(const mxnet::TShape& oshape, int out_dtype) {
+  auto default_param = full_param_.default_param;
+  dnnl::memory::dims out_dims(2);
+  if (oshape.ndim() == 2) {
+    out_dims[0] = static_cast<index_t>(oshape[0]);
+    out_dims[1] = static_cast<index_t>(oshape[1]);
+  } else {
+    if (!default_param.flatten) {
+      out_dims[0] = static_cast<index_t>(oshape.ProdShape(0, oshape.ndim() - 1));
+      out_dims[1] = static_cast<index_t>(oshape[oshape.ndim() - 1]);
+    } else {
+      out_dims[0] = static_cast<index_t>(oshape[0]);
+      out_dims[1] = static_cast<index_t>(oshape.ProdShape(1, oshape.ndim()));
+    }
+  }
+  dnnl::memory::desc out_md =
+      dnnl::memory::desc(out_dims,
+                         get_dnnl_type(out_dtype),
+                         static_cast<dnnl::memory::format_tag>(GetDefaultFormat(2)));
+  return out_md;
+}
+
+bool SgDNNLFCOp::PrepareQuantization(const OpContext& ctx,
+                                     const std::vector<NDArray>& in_data,
+                                     const NDArray& output,
+                                     const std::vector<float>& min_max_vec) {
+  const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
+  const FCInputIndex idx(full_param_);
+  bool support_channelwise_scale = false;
+  auto dnnl_param = full_param_.dnnl_param;
+  bool has_bias = !full_param_.default_param.no_bias;
+  const NDArray& data = in_data[fullc::kData];
+  const NDArray& weight = in_data[fullc::kWeight];
+  const bool channel_wise = dnnl_param.quantized && dnnl_param.channel_wise_quantize.has_value() &&
+                            dnnl_param.channel_wise_quantize.value();
+
+  CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
+  data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, cached_data_max_);
+
+  bool fuse_requantize = false;
+  // Channelwise scaling is only supported when fusion is enabled (requantize or dequantize).
+  if (dnnl_param.min_calib_range.has_value() && dnnl_param.max_calib_range.has_value()) {
+    cached_output_min_ = dnnl_param.min_calib_range.value();
+    cached_output_max_ = dnnl_param.max_calib_range.value();
+    support_channelwise_scale = true;
+    fuse_requantize = true;
+  }
+  if (dnnl_param.enable_float_output) {
+    support_channelwise_scale = true;
+  }
+  // channel_wise  support_channelwise_scale  result
+  //     True                True              True
+  //     True                False             Error
+  //     False             True/False          False
+  if (channel_wise && !support_channelwise_scale) {
+    LOG(FATAL) << "Currently, channel-wise quantization requires fuse requantize or dequantize."
+               << " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
+               << " fuse requantize (outputs of FullyConnected are collected during calibration "
+                  "phase),"
+               << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and "
+               << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true (default is false)";
+  }
+  support_channelwise_scale = support_channelwise_scale && channel_wise;
+
+  if (support_channelwise_scale) {
+    MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
+      weight_scales_ = GetWeightScales<DType>(cached_weight_,
+                                              has_bias ? &cached_bias_ : nullptr,
+                                              data_scale_,
+                                              support_channelwise_scale);
+    });
+  } else {
+    weight_scales_.resize(1);
+    weight_scales_[0] =
+        GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, cached_weight_max_);
+    if (has_bias) {
+      if (cached_bias_.dtype() == mshadow::kInt8) {
+        float bias_scale = GetQuantizeScale(mshadow::kInt8, cached_bias_min_, cached_bias_max_);
+
+        float bias_int32_rescale = data_scale_ * weight_scales_[0] / bias_scale;
+        // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
+        // the maximum value of bias to INT_MAX / 2.
+        float bias_max_rescale =
+            MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, cached_bias_max_) / bias_scale;
+        if (bias_int32_rescale > bias_max_rescale) {
+          // avoid overflow on bias
+          bias_int32_rescale = bias_max_rescale;
+          float weight_rescale = bias_int32_rescale * bias_scale / data_scale_ / weight_scales_[0];
+          int8_t* weight_ptr = weight.data().dptr<int8_t>();
+          size_t weight_size = weight.shape().Size();
+#pragma omp parallel for num_threads(nthreads)
+          for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
+            weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+          }
+          weight_scales_[0] *= weight_rescale;
+        }
+        NDArray bias = in_data[fullc::kBias];
+        cached_bias_ =
+            NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, mshadow::kInt32);
+        int8_t* bias_ptr = bias.data().dptr<int8_t>();
+        int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+        size_t bias_size = bias.shape().Size();
+
+#pragma omp parallel for num_threads(nthreads)
+        for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
+          quantized_bias_ptr[i] = std::round(bias_ptr[i] * bias_int32_rescale);
+        }
+      }
+    }
+  }
+
+  size_t num_channel = cached_weight_.shape()[0];
+  float out_scale = 1.0f;
+  if (fuse_requantize || dnnl_param.enable_float_output) {
+    float tmp_scale_ = 1.0f;
+    if (fuse_requantize) {
+      if (dnnl_param.with_eltwise) {
+        tmp_scale_ = 1.0 / data_scale_;
+        full_param_.eltwise_param.scale =
+            GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
+      } else {
+        out_scale = GetQuantizeScale(output.dtype(), cached_output_min_, cached_output_max_);
+        tmp_scale_ = out_scale / data_scale_;
+      }
+    } else {
+      tmp_scale_ = 1.0 / data_scale_;
+    }
+
+    if (support_channelwise_scale) {
+      full_param_.output_scales.resize(num_channel);
+#pragma omp parallel for num_threads(nthreads)
+      for (index_t i = 0; i < static_cast<index_t>(num_channel); ++i) {
+        full_param_.output_scales[i] = tmp_scale_ / weight_scales_[i];
+      }
+    } else {
+      full_param_.output_scales.resize(1);
+      full_param_.output_scales[0] = tmp_scale_ / weight_scales_[0];
+    }
+  } else {
+    Stream<cpu>* s = ctx.get_stream<cpu>();
+    if (data.dtype() == mshadow::kInt8) {
+      mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, cpu>::Launch(
+          s,
+          1,
+          &cached_output_min_,
+          &cached_output_max_,
+          &(min_max_vec[kDataMin]),
+          &(min_max_vec[kDataMax]),
+          &(min_max_vec[kWeightMin]),
+          &(min_max_vec[kWeightMax]));
+    } else {
+      mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, cpu>::Launch(
+          s,
+          1,
+          &cached_output_min_,
+          &cached_output_max_,
+          &(min_max_vec[kDataMin]),
+          &(min_max_vec[kDataMax]),
+          &(min_max_vec[kWeightMin]),
+          &(min_max_vec[kWeightMax]));
+    }
+    full_param_.output_scales.resize(0);
+    out_scale = data_scale_ * weight_scales_[0];
+  }
+
+  if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+    float sum_in_scale =
+        GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
+    full_param_.sum_scale = out_scale / sum_in_scale;
+    if (in_data[idx.sum].dtype() == mshadow::kUint8 && output.dtype() == mshadow::kInt8) {
+      // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+      // scale it to s8 range, so sum_scale has to be rescaled as well
+      full_param_.sum_scale *= 2.0;
+    }
+  }
+  return support_channelwise_scale;
+}
+
+void SgDNNLFCOp::GetCachedWeightsAndBias(const NDArray& weight,
+                                         bool support_channelwise_scale,
+                                         bool has_bias) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<DNNLFullyconSignature, std::pair<NDArray, NDArray>, OpHash>
+      fcWeightsAndBias;
+#else
+  static MX_THREAD_LOCAL
+      std::unordered_map<DNNLFullyconSignature, std::pair<NDArray, NDArray>, OpHash>
+          fcWeightsAndBias;
+#endif
+  static const bool use_cache = !(dmlc::GetEnv("MXNET_ONEDNN_DISABLE_FC_CACHE", 0));
+  const bool has_id = attrs.dict.count("__identifier__");
+  bool read_from_cache = false;
+
+  DNNLFullyconSignature key(full_param_.default_param);
+  if (use_cache && has_id) {
+    key.AddSign(fwd_->fwd_pd.weights_desc());
+    key.AddSign(attrs.dict["__identifier__"]);
+    key.AddSign(attrs.name);
+
+    auto it = fcWeightsAndBias.find(key);
+    if (it != fcWeightsAndBias.end()) {
+      cached_weight_ = it->second.first;
+      cached_bias_ = it->second.second;
+      read_from_cache = true;
+      common::LogOnce(
+          "oneDNN optimized version of FullyConnected for inference is being used. Weights and "
+          "bias are cached and can not be dynamically changed during runtime. To disable caching "
+          "mechanism use MXNET_ONEDNN_DISABLE_FC_CACHE=1.");
+    }
+  }
+
+  if (!read_from_cache) {
+    // convert weight and bias to the format that oneDNN requires
+    if (!full_param_.dnnl_param.quantized || support_channelwise_scale) {
+      dnnl::memory::desc bias_md;
+      if (has_bias)
+        bias_md = fwd_->fwd_pd.bias_desc();
+      ConvertWeightBias2DNNL(&cached_weight_,
+                             &cached_bias_,
+                             has_bias,
+                             fwd_->fwd_pd.weights_desc(),
+                             has_bias ? &bias_md : nullptr,
+                             1,
+                             data_scale_,
+                             weight_scales_,
+                             false);
+    } else {
+      const auto def_weight_mem = weight.GetDNNLData();
+      if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
+        auto weight_desc = fwd_->fwd_pd.weights_desc();
+        cached_weight_ = NDArray(&weight_desc);
+        auto cached_weight_mem = cached_weight_.GetDNNLData();
+        std::unordered_map<int, dnnl::memory> args(
+            {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, *cached_weight_mem}});
+        DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, *cached_weight_mem),
+                                            args);
+      }
+    }
+    if (use_cache && has_id)
+      AddToCache(&fcWeightsAndBias, key, {cached_weight_, cached_bias_});
+  }
+}
+
static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) {
// For backward compatible, with_relu->with_eltwise
auto legacy = attrs->dict.find("with_relu");
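
GetCachedWeightsAndBias above keeps a thread-local map keyed by a DNNLFullyconSignature so the reordered weight and bias tensors are computed once per unique key and reused on later initializations; MXNET_ONEDNN_DISABLE_FC_CACHE=1 disables it. A minimal sketch of that lookup pattern with hypothetical stand-in types (Tensor and GetOrConvert are illustrative, not MXNet symbols):

  #include <functional>
  #include <string>
  #include <unordered_map>
  #include <utility>

  struct Tensor {};  // stand-in for NDArray

  // Run the expensive conversion only on a cache miss, then reuse the result.
  const std::pair<Tensor, Tensor>& GetOrConvert(
      const std::string& key, const std::function<std::pair<Tensor, Tensor>()>& convert) {
    static thread_local std::unordered_map<std::string, std::pair<Tensor, Tensor>> cache;
    auto it = cache.find(key);
    if (it == cache.end())
      it = cache.emplace(key, convert()).first;  // reorder weights/bias once per key
    return it->second;
  }
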
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index fd5272ef5a..ccf5401705 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -167,7 +167,9 @@ class SgDNNLFCProperty : public SubgraphProperty {
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
                                      const int subgraph_id = 0) const override {
- nnvm::ObjectPtr n = nnvm::Node::Create();
+    // distinguish between exactly the same node in different networks - for caching weights
+ static unsigned int node_identifier = 0;
+ nnvm::ObjectPtr n = nnvm::Node::Create();
// This op has single output, remove duplicated.
auto last_node = sym.outputs[0].node;
nnvm::Symbol new_sym;
@@ -189,6 +191,7 @@ class SgDNNLFCProperty : public SubgraphProperty {
n->attrs.name = node_name.str();
n->attrs.op = Op::Get("_sg_onednn_fully_connected");
CHECK(n->attrs.op);
+ n->attrs.dict["__identifier__"] = std::to_string(node_identifier++);
n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
n->op()->attr_parser(&(n->attrs));
return n;
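
The property change above is what makes that cache key unique: every created subgraph node gets a monotonically increasing __identifier__ attribute, so identically named FullyConnected nodes from different networks still map to different cache entries. A condensed sketch of the scheme (MakeSubgraphAttrs is illustrative, not the MXNet API):

  #include <map>
  #include <string>

  // Each call tags a node's attribute dictionary with a process-wide unique id.
  std::map<std::string, std::string> MakeSubgraphAttrs(const std::string& node_name) {
    static unsigned int node_identifier = 0;
    std::map<std::string, std::string> dict;
    dict["name"] = node_name;
    dict["__identifier__"] = std::to_string(node_identifier++);
    return dict;
  }
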