This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 702e47594b [v1.x] Fix for fc with sum when types are incompatible (#21042)
702e47594b is described below

commit 702e47594b7c615678d5873bc984a1447bec8271
Author: DominikaJedynak <[email protected]>
AuthorDate: Wed Jul 20 11:25:32 2022 +0200

    [v1.x] Fix for fc with sum when types are incompatible (#21042)
    
    * Type sum fix
    
    * Incompatible fc and sum type fix
    
    * Clang formatting
---
 src/operator/subgraph/mkldnn/mkldnn_fc.cc | 37 +++++++++++++++++++++++++++----
 1 file changed, 33 insertions(+), 4 deletions(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
index 9c481d675a..352e638c90 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -135,7 +135,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       // which make check (req[out_index] == kWriteInplace) useless.
       auto in_mkl_mem  = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
       auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
-      if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()) {
+      if (in_mkl_mem->get_data_handle() == out_mkl_mem->get_data_handle()
+          && in_data[idx.sum].dtype() == out_data[out_index].dtype()) {
         inplace_ = true;
       }
     }
@@ -146,8 +147,8 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
       auto in_mkl_mem  = static_cast<const mkldnn::memory*>(in_data[idx.sum].GetMKLDNNData());
       auto out_mkl_mem = static_cast<const mkldnn::memory*>(out_data[out_index].GetMKLDNNData());
       if (out_data[out_index].dtype() == mshadow::kInt32) {
-        auto mem_desc           = in_mkl_mem->get_desc();
-        auto this_dtype         = get_mkldnn_type(mshadow::kInt32);
+        auto mem_desc = in_mkl_mem->get_desc();
+        auto this_dtype = get_mkldnn_type(mshadow::kInt32);
         mem_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
         mkldnn_mem_ptr tmp_mem(new mkldnn::memory(
             mem_desc, CpuEngine::Get()->get_engine(), out_mkl_mem->get_data_handle()));
@@ -156,6 +157,27 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
             mkldnn::reorder(*in_mkl_mem, *tmp_mem),
             {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
         output = NDArray(tmp_mem);
+      } else if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+                 out_data[out_index].dtype() == mshadow::kInt8) {
+        auto sum_mem_desc = in_mkl_mem->get_desc();
+        auto out_dtype = get_mkldnn_type(mshadow::kInt8);
+        sum_mem_desc.data.data_type =
+            static_cast<mkldnn_data_type_t>(out_dtype);
+        mkldnn_mem_ptr tmp_mem(
+            new mkldnn::memory(sum_mem_desc, CpuEngine::Get()->get_engine(),
+                               out_mkl_mem->get_data_handle()));
+        MKLDNNStream::Get()->RegisterMem(tmp_mem);
+        const float u8_reorder_scale = 0.5;
+        std::vector<float> reorder_scale = {u8_reorder_scale};
+        mkldnn::primitive_attr reorder_attr;
+        reorder_attr.set_output_scales(0, reorder_scale);
+        const auto reorder_pd = mkldnn::reorder::primitive_desc(
+            CpuEngine::Get()->get_engine(), in_mkl_mem->get_desc(),
+            CpuEngine::Get()->get_engine(), sum_mem_desc, reorder_attr);
+        MKLDNNStream::Get()->RegisterPrimArgs(
+            mkldnn::reorder(reorder_pd),
+            {{MKLDNN_ARG_FROM, *in_mkl_mem}, {MKLDNN_ARG_TO, *tmp_mem}});
+        output = NDArray(tmp_mem);
       } else {
         mkldnn_mem_ptr tmp_mem(new mkldnn::memory(in_mkl_mem->get_desc(),
                                                   CpuEngine::Get()->get_engine(),
@@ -393,6 +415,12 @@ void SgMKLDNNFCOp::Forward(const OpContext& ctx,
         float sum_in_scale =
             GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, cached_sum_max_);
         mkldnn_param.sum_scale = out_scale / sum_in_scale;
+        if (in_data[idx.sum].dtype() == mshadow::kUint8 &&
+            out_data[out_index].dtype() == mshadow::kInt8) {
+          // In this case, reorder with scale 0.5 is used on in_data[idx.sum] to
+          // scale it to s8 range, so sum_scale has to be rescaled as well
+          mkldnn_param.sum_scale *= 2.0;
+        }
       }
     }  // if (mkldnn_param.quantized)
 
@@ -659,7 +687,8 @@ static bool SgMKLDNNFCInferType(const nnvm::NodeAttrs& attrs,
     } else {
       if (full_param.mkldnn_param.min_calib_range.has_value() &&
           full_param.mkldnn_param.max_calib_range.has_value()) {
-        if (IsOutputUint8(full_param)) {
+        if (IsOutputUint8(full_param) &&
+            (!idx.IsSumExist() || in_types->at(idx.sum) == mshadow::kUint8)) {
           TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
         } else {
           TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);

Reply via email to