This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new 702e47594b [v1.x] Fix for fc with sum when types are incompatible (#21042)
702e47594b is described below
commit 702e47594b7c615678d5873bc984a1447bec8271
Author: DominikaJedynak <[email protected]>
AuthorDate: Wed Jul 20 11:25:32 2022 +0200
[v1.x] Fix for fc with sum when types are incompatible (#21042)
* Type sum fix
* Incompatible fc and sum type fix
* Clang formatting
---
src/operator/subgraph/mkldnn/mkldnn_fc.cc | 37 +++++++++++++++++++++++++++----
1 file changed, 33 insertions(+), 4 deletions(-)
diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 9c481d675a..352e638c90 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -135,7 +135,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
// which make check (req[out_index] == kWriteInplace) useless.
auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
- if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
+ if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()
+ && in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
inplace_ = true;
}
}
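A note on the hunk above: the in-place fast path now requires both that the buffers alias and that the dtypes agree. A minimal standalone sketch of that guard, with hypothetical names rather than the MXNet API:

    // In-place reuse of the output buffer is only safe when the pointer AND
    // the element type match: a u8 and an s8 view of the same bytes disagree
    // on any stored value above 127.
    bool CanWriteInplace(const void* in_ptr, const void* out_ptr,
                         int in_dtype, int out_dtype) {
      return in_ptr == out_ptr && in_dtype == out_dtype;
    }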
@@ -146,8 +147,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
auto in_mkl_mem = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
if (out_data[out_index].dtype() == mshadow::kInt32) {
- auto mem_desc = in_mkl_mem->get_desc();
- auto this_dtype = get_mkldnn_type(mshadow::kInt32);
+ auto mem_desc = in_mkl_mem->get_desc();
+ auto this_dtype = get_mkldnn_type(mshadow::kInt32);
mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(
mem_desc, CpuEngine::Get()->get_engine(),
out_mkl_mem->get_data_handle()));
@@ -156,6 +157,27 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
mkldnn::reorder(*in_mkl_mem, *tmp_mem),
{{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
output = NDArray(tmp_mem);
+ } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+ out_data[out_index].dtype() == mshadow::kInt8) {
+ auto sum_mem_desc = in_mkl_mem->get_desc();
+ auto out_dtype = get_mkldnn_type(mshadow::kInt8);
+ sum_mem_desc.data.data_type =
+ static_cast<mkldnn_data_type_t>(out_dtype);
+ mkldnn_mem_ptr tmp_mem(
+ new mkldnn::memory(sum_mem_desc, CpuEngine::Get()->get_engine(),
+ out_mkl_mem->get_data_handle()));
+ MKLDNNStream::Get()->RegisterMem(tmp_mem);
+ const float u8_reorder_scale = 0.5;
+ std::vector<float> reorder_scale = {u8_reorder_scale};
+ mkldnn::primitive_attr reorder_attr;
+ reorder_attr.set_output_scales(0, reorder_scale);
+ const auto reorder_pd = mkldnn::reorder::primitive_desc(
+ CpuEngine::Get()->get_engine(), in_mkl_mem->get_desc(),
+ CpuEngine::Get()->get_engine(), sum_mem_desc, reorder_attr);
+ MKLDNNStream::Get()->RegisterPrimArgs(
+ mkldnn::reorder(reorder_pd),
+ {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
+ output = NDArray(tmp_mem);
} else {
mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(),
CpuEngine::Get()->get_engine(),
@@ -393,6 +415,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
float sum_in_scale =
GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_,
cached_sum_max_);
mkldnn_param.sum_scale = out_scale / sum_in_scale;
+ if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+ out_data[out_index].dtype() == mshadow::kInt8) {
+ // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+ // scale it to s8 range, so sum_scale has to be rescaled as well
+ mkldnn_param.sum_scale *= 2.0;
+ }
}
} // if (mkldnn_param.quantized)
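The 0.5 reorder scale in the earlier hunk and the sum_scale *= 2.0 in this one are two halves of a single rescaling. A minimal sketch of the arithmetic, using a hypothetical calibration value for sum_scale (not taken from the commit):

    #include <cstdio>

    int main() {
      // The u8 sum input is halved by the reorder so that its maximum,
      // 255 * 0.5 = 127.5, saturates to 127 and fits the s8 range.
      const float u8_reorder_scale = 0.5f;
      std::printf("rescaled u8 max: %.1f\n", 255.0f * u8_reorder_scale);
      // The fused sum computes out = fc_out + sum_scale * sum_input; since
      // sum_input was halved, sum_scale is doubled to keep the term intact.
      float sum_scale = 1.25f;               // hypothetical out_scale / sum_in_scale
      sum_scale *= 1.0f / u8_reorder_scale;  // the commit's sum_scale *= 2.0
      std::printf("compensated sum_scale: %.2f\n", sum_scale);
      return 0;
    }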
@@ -659,7 +687,8 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
} else {
if (full_param.mkldnn_param.min_calib_range.has_value() &&
full_param.mkldnn_param.max_calib_range.has_value()) {
- if (IsOutputUint8(full_param)) {
+ if (IsOutputUint8(full_param) &&
+ (!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
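The final hunk only keeps a u8 output when no sum input is fused, or when that input is itself u8; otherwise the output falls back to s8 so the fused sum operates on compatible types. A minimal sketch of that decision, with stand-in names rather than the MXNet helpers:

    enum DType { kUint8, kInt8 };

    // Stand-in for the condition added to SgMKLDNNFCInferType: prefer a u8
    // output only when no fused sum input exists or it is u8 as well.
    DType PickOutputType(bool output_could_be_uint8,
                         bool has_fused_sum,
                         DType sum_input_dtype) {
      if (output_could_be_uint8 &&
          (!has_fused_sum || sum_input_dtype == kUint8)) {
        return kUint8;
      }
      return kInt8;
    }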