This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new afbef154ed Type fix for FullyConnected with sum (#21043)
afbef154ed is described below
commit afbef154ed09810e8addccd37169ffed8b4f7cda
Author: DominikaJedynak <[email protected]>
AuthorDate: Thu Jun 23 15:29:23 2022 +0200
Type fix for FullyConnected with sum (#21043)
---
src/operator/subgraph/dnnl/dnnl_fc.cc | 32 ++++++++++++++++++++++++++++++--
1 file changed, 30 insertions(+), 2 deletions(-)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 24b7ec6883..2371ce8bd9 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -131,7 +131,8 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
// which make check (req[out_index] == kWriteInplace) useless.
auto in_dnnl_mem = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
- if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) {
+ if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle() &&
+ in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
inplace_ = true;
}
}
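[Editor's note, not part of the patch] This first hunk tightens the in-place check: the output may reuse the sum input's buffer only when the two tensors also share a dtype; a u8 sum input feeding an s8 output is now routed through the scaled reorder added below instead. A minimal sketch of the tightened condition, using hypothetical helper and parameter names rather than the operator's actual code:

    // Hypothetical helper: in-place accumulation is allowed only when the fused
    // sum input aliases the output buffer AND has the same element type.
    inline bool CanAccumulateInplace(const dnnl::memory* sum_in, const dnnl::memory* out,
                                     int sum_in_dtype, int out_dtype) {
      return sum_in->get_data_handle() == out->get_data_handle() &&
             sum_in_dtype == out_dtype;
    }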
@@ -152,6 +153,26 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
dnnl::reorder(*in_dnnl_mem, *tmp_mem),
{{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
output = NDArray(tmp_mem);
+ } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+ out_data[out_index].dtype() == mshadow::kInt8) {
+ auto sum_mem_desc = in_dnnl_mem->get_desc();
+ auto out_dtype = get_dnnl_type(mshadow::kInt8);
+ sum_mem_desc.data.data_type = static_cast<dnnl_data_type_t>(out_dtype);
+ dnnl_mem_ptr tmp_mem(new dnnl::memory(
+ sum_mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
+ DNNLStream::Get()->RegisterMem(tmp_mem);
+ const float u8_reorder_scale = 0.5;
+ std::vector<float> reorder_scale = {u8_reorder_scale};
+ dnnl::primitive_attr reorder_attr;
+ reorder_attr.set_output_scales(0, reorder_scale);
+ const auto reorder_pd = dnnl::reorder::primitive_desc(CpuEngine::Get()->get_engine(),
+ in_dnnl_mem->get_desc(),
+ CpuEngine::Get()->get_engine(),
+ sum_mem_desc,
+ reorder_attr);
+ DNNLStream::Get()->RegisterPrimArgs(
+ dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+ output = NDArray(tmp_mem);
} else {
dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
CpuEngine::Get()->get_engine(),
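[Editor's note, not part of the patch] The new branch above handles a u8 sum input feeding an s8 output: the u8 data is reordered into an s8 view of the output buffer with an output scale of 0.5, so the u8 range [0, 255] is compressed into the non-negative half of the s8 range. A standalone sketch of such a scaled reorder, assuming the oneDNN 2.x API (where primitive_attr::set_output_scales is still available) and made-up shapes; it is not the operator's code:

    // Sketch: reorder a u8 buffer into an s8 buffer while halving the values.
    #include <oneapi/dnnl/dnnl.hpp>

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);

      // Made-up 1x8 shape; the patch instead reuses the sum input's descriptor.
      dnnl::memory::desc src_md({1, 8}, dnnl::memory::data_type::u8,
                                dnnl::memory::format_tag::nc);
      dnnl::memory::desc dst_md({1, 8}, dnnl::memory::data_type::s8,
                                dnnl::memory::format_tag::nc);
      dnnl::memory src(src_md, eng);  // contents left uninitialized for brevity
      dnnl::memory dst(dst_md, eng);

      dnnl::primitive_attr attr;
      attr.set_output_scales(0, {0.5f});  // dst = 0.5 * src

      auto pd = dnnl::reorder::primitive_desc(eng, src_md, eng, dst_md, attr);
      dnnl::reorder(pd).execute(strm, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
      strm.wait();
      return 0;
    }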
@@ -388,6 +409,12 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
float sum_in_scale =
    GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
full_param_.sum_scale = out_scale / sum_in_scale;
+ if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+ out_data[out_index].dtype() == mshadow::kInt8) {
+ // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+ // scale it to s8 range, so sum_scale has to be rescaled as well
+ full_param_.sum_scale *= 2.0;
+ }
}
} // if (dnnl_param.quantized)
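[Editor's note, not part of the patch] The 2.0 factor is the counterpart of the 0.5 reorder: the reordered sum data carries half of its original quantized magnitude, so the scale applied to it during accumulation has to double, since out_scale / (0.5 * sum_in_scale) == 2 * (out_scale / sum_in_scale). A tiny sketch of that arithmetic with made-up scale values:

    // Illustrative arithmetic only; the scale values are invented.
    #include <cstdio>

    int main() {
      const float out_scale     = 12.7f;  // hypothetical output quantization scale
      const float sum_in_scale  = 25.5f;  // hypothetical scale of the u8 sum input
      const float reorder_scale = 0.5f;   // u8 -> s8 reorder applied to the sum input

      // Scale needed so the halved sum data still contributes at the output scale.
      const float effective = out_scale / (sum_in_scale * reorder_scale);
      // What the patch computes: the original ratio, then doubled.
      const float patched = (out_scale / sum_in_scale) * 2.0f;

      std::printf("effective = %f, patched = %f\n", effective, patched);
      return 0;
    }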
@@ -652,7 +679,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
} else {
if (full_param.dnnl_param.min_calib_range.has_value() &&
full_param.dnnl_param.max_calib_range.has_value()) {
- if (IsOutputUint8(full_param)) {
+ if (IsOutputUint8(full_param) &&
+ (!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);