anko-intel commented on code in PR #21041:
URL: https://github.com/apache/incubator-mxnet/pull/21041#discussion_r896751091


##########
src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc:
##########
@@ -113,65 +169,91 @@ static void DNNLQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs,
   float output_min     = 0;
   float output_max     = 0;
   float output_scale   = 0;
+  const int scales_num = 2;  // 2: scale 0 for input A, scale 1 for input B
+  std::vector<float> scales(scales_num, 1);
   if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
     output_min     = params.min_calib_range.value();
     output_max     = params.max_calib_range.value();
     output_scale   = output_data_range / MaxAbs(output_min, output_max);
+    scales[0]      = output_scale / A_scale;
+    scales[1]      = output_scale / B_scale;
   } else {
     output_max = A_absmax + B_absmax;
     output_min = -output_max;
+    scales[0]  = A_absmax * output_data_range / (output_max * in_range);
+    scales[1]  = B_absmax * output_data_range / (output_max * in_range);
   }
-  // 2: scale 0 for input A, scale 1 for input B
-  const int scales_num = 2;
-  std::vector<float> scales(scales_num, 1);
-  auto engine = CpuEngine::Get()->get_engine();
-  if (inputs[q_elemwise_add::kDataA].dtype() != inputs[q_elemwise_add::kDataB].dtype()) {
-    auto s8_desc                     = is_A_int8 ? A_mem->get_desc() : B_mem->get_desc();
-    rescaled_mem = TmpMemMgr::Get()->Alloc(s8_desc);
-    const float u8_reorder_scale     = 0.5;
-    std::vector<float> reorder_scale = {u8_reorder_scale};
-    dnnl::primitive_attr reorder_attr;
-    reorder_attr.set_output_scales(0, reorder_scale);
-    auto u8_mem = (is_A_int8 == true) ? B_mem : A_mem;
-    const auto reorder_pd =
-        dnnl::reorder::primitive_desc(engine, u8_mem->get_desc(), engine, s8_desc, reorder_attr);
-    dnnl_args_map_t args({{DNNL_ARG_FROM, *u8_mem}, {DNNL_ARG_TO, *rescaled_mem}});
-    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), args);
-
-    if (is_A_int8) {
-      B_mem = rescaled_mem;
+
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool sum_kernel =
+      (inputs[q_elemwise_add::kDataA].shape() == inputs[q_elemwise_add::kDataB].shape());
+
+  if (diff_in_types) {
+    if (sum_kernel) {
+      // rescale uint8 to int8 by reorder to temporary memory
+      auto s8_desc                     = is_A_int8 ? A_mem->get_desc() : B_mem->get_desc();
+      rescaled_mem                     = TmpMemMgr::Get()->Alloc(s8_desc);
+      const float u8_reorder_scale     = 0.5;
+      std::vector<float> reorder_scale = {u8_reorder_scale};
+      auto engine                      = CpuEngine::Get()->get_engine();
+      dnnl::primitive_attr reorder_attr;
+      reorder_attr.set_output_scales(0, reorder_scale);
+      auto u8_mem = (is_A_int8 == true) ? B_mem : A_mem;
+      const auto reorder_pd =
+          dnnl::reorder::primitive_desc(engine, u8_mem->get_desc(), engine, s8_desc, reorder_attr);
+      dnnl_args_map_t args({{DNNL_ARG_FROM, *u8_mem}, {DNNL_ARG_TO, *rescaled_mem}});
+      DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), args);
+      if (is_A_int8) {
+        B_mem = rescaled_mem;
+      } else {
+        A_mem = rescaled_mem;
+      }
     } else {
-      A_mem = rescaled_mem;
+      // take into account conversion from uint8 to int8 in binary operator input scales
+      if (is_A_int8) {
+        scales[1] *= 0.5;  // convert B from uint8
+      } else {
+        scales[0] *= 0.5;  // convert A from uint8
+      }
     }

Review Comment:
   No, but you are right that the common scale was not preserved. I have "rescaled" the u8 input to the original values in a new commit.
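
   For readers following the scale discussion, here is a minimal standalone sketch (not the PR code, and not using the oneDNN API) of the scale arithmetic visible in the hunk: the fallback branch that derives per-input scales from the summed absolute ranges, and the 0.5 factor that folds the uint8-to-int8 conversion into the scale of the uint8 input. The concrete numbers, `in_range = 255`, and `is_A_int8 = true` are illustrative assumptions; the hunk does not show how those values are computed.

   ```cpp
   // Minimal sketch of the quoted scale arithmetic; NOT the PR code.
   #include <cstdio>
   #include <vector>

   int main() {
     const float output_data_range = 127.0f;  // assumed: int8 output
     const float in_range          = 255.0f;  // assumed: uint8 input range
     const float A_absmax          = 3.0f;    // made-up calibration extrema
     const float B_absmax          = 5.0f;

     std::vector<float> scales(2, 1.0f);  // scale 0 for input A, scale 1 for input B

     // Fallback branch (no calibrated output range): the output range is the sum
     // of the input ranges, and each scale maps that input onto the output range.
     const float output_max = A_absmax + B_absmax;
     scales[0] = A_absmax * output_data_range / (output_max * in_range);
     scales[1] = B_absmax * output_data_range / (output_max * in_range);

     // Mixed u8/s8 inputs with the binary kernel: the uint8 -> int8 conversion is
     // folded into the uint8 input's scale (0.5 ~ 127/255) instead of a reorder.
     const bool is_A_int8 = true;  // assumed: A is int8, so B is the uint8 input
     if (is_A_int8) {
       scales[1] *= 0.5f;
     } else {
       scales[0] *= 0.5f;
     }

     std::printf("scale[A] = %.5f, scale[B] = %.5f\n", scales[0], scales[1]);
     return 0;
   }
   ```

   In the sum-kernel branch of the hunk the same 0.5 factor is instead applied by an explicit reorder of the uint8 input into temporary int8 memory.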


