This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new afbef154ed Type fix for FullyConnected with sum (#21043)
afbef154ed is described below

commit afbef154ed09810e8addccd37169ffed8b4f7cda
Author: DominikaJedynak <[email protected]>
AuthorDate: Thu Jun 23 15:29:23 2022 +0200

    Type fix for FullyConnected with sum (#21043)
---
 src/operator/subgraph/dnnl/dnnl_fc.cc | 32 ++++++++++++++++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 24b7ec6883..2371ce8bd9 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -131,7 +131,8 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
       // which make check (req[out_index] == kWriteInplace) useless.
       auto in_dnnl_mem  = static_cast<const dnnl::memory*>(in_data[idx.sum].GetDNNLData());
       auto out_dnnl_mem = static_cast<const dnnl::memory*>(out_data[out_index].GetDNNLData());
-      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) {
+      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle() &&
+          in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
         inplace_ = true;
       }
     }
@@ -152,6 +153,26 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
             dnnl::reorder(*in_dnnl_mem, *tmp_mem),
             {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
         output = NDArray(tmp_mem);
+      } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+                 out_data[out_index].dtype() == mshadow::kInt8) {
+        auto sum_mem_desc           = in_dnnl_mem->get_desc();
+        auto out_dtype              = get_dnnl_type(mshadow::kInt8);
+        sum_mem_desc.data.data_type = static_cast<dnnl_data_type_t>(out_dtype);
+        dnnl_mem_ptr tmp_mem(new dnnl::memory(
+            sum_mem_desc, CpuEngine::Get()->get_engine(), out_dnnl_mem->get_data_handle()));
+        DNNLStream::Get()->RegisterMem(tmp_mem);
+        const float u8_reorder_scale     = 0.5;
+        std::vector<float> reorder_scale = {u8_reorder_scale};
+        dnnl::primitive_attr reorder_attr;
+        reorder_attr.set_output_scales(0, reorder_scale);
+        const auto reorder_pd = dnnl::reorder::primitive_desc(CpuEngine::Get()->get_engine(),
+                                                              in_dnnl_mem->get_desc(),
+                                                              CpuEngine::Get()->get_engine(),
+                                                              sum_mem_desc,
+                                                              reorder_attr);
+        DNNLStream::Get()->RegisterPrimArgs(
+            dnnl::reorder(reorder_pd), {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+        output = NDArray(tmp_mem);
       } else {
         dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
                                               CpuEngine::Get()->get_engine(),
@@ -388,6 +409,12 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
         float sum_in_scale =
             GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
         full_param_.sum_scale = out_scale / sum_in_scale;
+        if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+            out_data[out_index].dtype() == mshadow::kInt8) {
+          // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+          // scale it to s8 range, so sum_scale has to be rescaled as well
+          full_param_.sum_scale *= 2.0;
+        }
       }
     }  // if (dnnl_param.quantized)
 
@@ -652,7 +679,8 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& attrs,
     } else {
       if (full_param.dnnl_param.min_calib_range.has_value() &&
           full_param.dnnl_param.max_calib_range.has_value()) {
-        if (IsOutputUint8(full_param)) {
+        if (IsOutputUint8(full_param) &&
+            (!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
           TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
         } else {
           TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
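
For context, below is a minimal, self-contained sketch of the
reorder-with-output-scale idea the hunks above rely on: the u8 "sum"
input is halved while being converted to s8, and the quantization scale
applied to it afterwards is doubled to compensate. The sketch is not
part of the commit; it assumes oneDNN 2.x (where
dnnl::primitive_attr::set_output_scales is still available, as in the
patched code), uses made-up buffer values, and reorders into a separate
s8 buffer, whereas the patch writes in place into the FC output memory.

    // Illustrative sketch only (not from the commit). Assumes oneDNN 2.x.
    #include <dnnl.hpp>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
      dnnl::engine eng(dnnl::engine::kind::cpu, 0);
      dnnl::stream strm(eng);

      const int n = 4;
      std::vector<uint8_t> u8_buf = {0, 100, 200, 255};  // u8 range [0, 255]
      std::vector<int8_t>  s8_buf(n, 0);                 // s8 range [-128, 127]

      dnnl::memory::desc u8_md({n}, dnnl::memory::data_type::u8,
                               dnnl::memory::format_tag::a);
      dnnl::memory::desc s8_md({n}, dnnl::memory::data_type::s8,
                               dnnl::memory::format_tag::a);
      dnnl::memory src(u8_md, eng, u8_buf.data());
      dnnl::memory dst(s8_md, eng, s8_buf.data());

      // Halve every element while converting u8 -> s8, so the result always
      // fits into the s8 range; this mirrors the attr built in dnnl_fc.cc.
      dnnl::primitive_attr attr;
      attr.set_output_scales(0, {0.5f});

      auto rpd = dnnl::reorder::primitive_desc(eng, u8_md, eng, s8_md, attr);
      dnnl::reorder(rpd).execute(strm, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
      strm.wait();

      // Because the stored integers were halved, any quantization scale that
      // later maps them back to real values has to be doubled, which is why
      // the patch multiplies full_param_.sum_scale by 2.0.
      for (int i = 0; i < n; ++i)
        std::printf("u8 %3u -> s8 %4d\n", static_cast<unsigned>(u8_buf[i]),
                    static_cast<int>(s8_buf[i]));
      return 0;
    }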
