This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 63aea9e031 [FEATURE] Add quantization for npi_add with oneDNN (#21041)
63aea9e031 is described below

commit 63aea9e0310ed672b8816d332ae9ec25f9d8a893
Author: Andrzej Kotłowski <[email protected]>
AuthorDate: Wed Jun 22 08:27:26 2022 +0200

    [FEATURE] Add quantization for npi_add with oneDNN (#21041)
    
    * [FEATURE] Add quantization for npi_add with oneDNN
    
    + tests for it,
    + benchmark for FC + npi_add with broadcasted input
    
    * Fix tests in test_fc_subgraph.py
    
    The initialisation of model weights was wrong,
    as negative numbers couldn't be represented with uint8.
    
    * Fix merge
    
    * Fix calibration and scales
    
    + some rearrangement to make the code more readable
    
    * Add tests for full path of quantization of add operators
    
    * Fix test_fc_subgraph.py::test_fc_eltwise[exp-True-True]
    
    After the previous fix, which initializes FC weights only with positive
    values for uint8, the input to the exp function was too high.
---
 benchmark/python/dnnl/fc_add.py                    |  22 +-
 src/operator/numpy/np_elemwise_broadcast_op.h      |   3 +
 .../dnnl/dnnl_quantized_elemwise_add.cc            | 228 ++++++++++++++-------
 .../quantization/quantized_elemwise_add-inl.h      |  15 ++
 .../quantization/quantized_elemwise_add.cc         |  50 +++++
 .../subgraph/dnnl/dnnl_post_quantize_property.h    |   4 +-
 src/operator/tensor/elemwise_binary_broadcast_op.h |  16 +-
 tests/python/dnnl/subgraphs/subgraph_common.py     |  22 +-
 tests/python/dnnl/subgraphs/test_fc_subgraph.py    |  55 +++++
 tests/python/quantization/test_quantization.py     |  77 +++++++
 10 files changed, 400 insertions(+), 92 deletions(-)
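
A minimal usage sketch (not part of the patch) of the feature this commit adds:
quantizing a Gluon block whose '+' on mx.np arrays lowers to npi_add. It assumes
the mx.contrib.quantization.quantize_net entry point and reuses the keyword
arguments visible in the test helper below; the block and data shapes are
illustrative only, and additional quantize_net arguments may be required.

    import mxnet as mx
    from mxnet.gluon import nn
    from mxnet.contrib import quantization

    class FCPlusAdd(nn.HybridBlock):
        # FullyConnected followed by '+' on mx.np arrays (lowers to npi_add)
        def __init__(self, units, **kwargs):
            super().__init__(**kwargs)
            self.fc = nn.Dense(units)

        def forward(self, data, addend):
            return self.fc(data) + addend

    net = FCPlusAdd(16)
    net.initialize()
    data = mx.np.random.uniform(-1.0, 1.0, size=(4, 32))
    addend = mx.np.random.uniform(-1.0, 1.0, size=(1, 1))  # broadcast addend
    out_fp32 = net(data, addend)                           # run the float net once

    calib = mx.gluon.data.DataLoader(
        mx.gluon.data.ArrayDataset(data, addend), batch_size=1)
    qnet = quantization.quantize_net(net,
                                     quantized_dtype='auto',
                                     quantize_mode='smart',
                                     calib_mode='naive',
                                     calib_data=calib,
                                     num_calib_batches=1)
    out_int8 = qnet(data, addend)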

diff --git a/benchmark/python/dnnl/fc_add.py b/benchmark/python/dnnl/fc_add.py
index 8bdefd39ef..6cf2f929ec 100644
--- a/benchmark/python/dnnl/fc_add.py
+++ b/benchmark/python/dnnl/fc_add.py
@@ -97,8 +97,8 @@ class FCWithSum(nn.HybridBlock):
             _sum1 = _fc1 + _sum0
         return _sum1
 
-def benchmark_float(elemwise_add):
-    header = operator_string(elemwise_add) + ', float'
+def benchmark_float(elemwise_add, broadcast=False):
+    header = operator_string(elemwise_add) + ', float' + (' , broadcast' if broadcast else "")
     print_header(header)
     for shape, nhid in sizes:
         net = FCWithSum(shape[1], nhid, elemwise_add)
@@ -107,6 +107,9 @@ def benchmark_float(elemwise_add):
         data0 = mx.np.random.uniform(size=shape, low=-1.0, high=1.0)
         data1 = mx.np.random.uniform(size=shape, low=-1.0, high=1.0)
         shape2 = (shape[0], nhid)
+        if broadcast and not elemwise_add:
+            # broadcast is allowed only for npi_add version
+            shape2 = (1, 1)
         data2 = mx.np.random.uniform(size=shape2, low=-1.0, high=1.0)
         net.optimize_for(data0, data1, data2, backend='ONEDNN')
         measure(net, data0, data1, data2, shape, nhid)
@@ -126,9 +129,9 @@ class CalibIter(mx.io.DataIter):
     def __iter__(self):
         yield self.batch
 
-def benchmark_int8(quantize_mode, quantize_granularity, elemwise_add):
+def benchmark_int8(quantize_mode, quantize_granularity, elemwise_add, broadcast = False):
     header = operator_string(elemwise_add) + ', mode = ' + quantize_mode + \
-             ', granularity = ' + quantize_granularity
+             ', granularity = ' + quantize_granularity + (' , broadcast' if broadcast else "")
     print_header(header)
     for shape, nhid in sizes:
         net = FCWithSum(shape[1], nhid, elemwise_add)
@@ -137,6 +140,9 @@ def benchmark_int8(quantize_mode, quantize_granularity, elemwise_add):
         data0 = mx.np.random.uniform(size=shape, low=-1.0, high=1.0)
         data1 = mx.np.random.uniform(size=shape, low=-1.0, high=1.0)
         shape2 = (shape[0], nhid)
+        if broadcast and not elemwise_add:
+            # broadcast is allowed only for npi_add
+            shape2 = (shape[0], 1)
         data2 = mx.np.random.uniform(size=shape2, low=-1.0, high=1.0)
         data = mx.gluon.data.ArrayDataset(data0, data1, data2)
         calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
@@ -162,3 +168,11 @@ for quantize_mode in ['smart', 'full']:
     for quantize_granularity in ['tensor-wise', 'channel-wise']:
         for elemwise_add in [True, False]:
             benchmark_int8(quantize_mode, quantize_granularity, elemwise_add)
+
+# Benchmark FC + npi_add with broadcasted input
+benchmark_float(False, True)
+
+# Benchmark quantized FC + npi_add with broadcasted input
+for quantize_mode in ['smart', 'full']:
+    for quantize_granularity in ['tensor-wise', 'channel-wise']:
+        benchmark_int8(quantize_mode, quantize_granularity, False, True)
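A small standalone illustration (not from the patch) of the broadcast shapes the
new benchmark variants feed to npi_add: the float run uses a fully broadcast
addend of shape (1, 1), the int8 run broadcasts only over the hidden dimension;
batch and hidden sizes below are made up.

    import mxnet as mx
    batch, nhid = 4, 8
    fc_out = mx.np.zeros((batch, nhid))
    print((fc_out + mx.np.zeros((1, 1))).shape)      # (4, 8): fully broadcast addend
    print((fc_out + mx.np.zeros((batch, 1))).shape)  # (4, 8): broadcast over hidden dim
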
diff --git a/src/operator/numpy/np_elemwise_broadcast_op.h b/src/operator/numpy/np_elemwise_broadcast_op.h
index 1b931e01c2..f27b9a7772 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op.h
+++ b/src/operator/numpy/np_elemwise_broadcast_op.h
@@ -974,6 +974,9 @@ inline bool NumpyBinaryMixedPrecisionType(const nnvm::NodeAttrs& attrs,
                                        [](const NodeAttrs& attrs) {                              \
                                          return std::vector<std::string>{"lhs", "rhs"};           \
                                        })                                                        \
+      .set_attr<nnvm::FListOutputNames>(                                                          \
+          "FListOutputNames",                                                                     \
+          [](const NodeAttrs& attrs) { return std::vector<std::string>{"output"}; })              \
       .set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)                          \
       .set_attr<nnvm::FInferType>("FInferType", NumpyBinaryMixedPrecisionType)                    \
       .set_attr<nnvm::FInplaceOption>("FInplaceOption",                                           \
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
index 82e66a1ed2..979d2cbd53 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_elemwise_add.cc
@@ -32,18 +32,13 @@ namespace op {
 
 DMLC_REGISTER_PARAMETER(QuantizeElemwiseAddParam);
 
-static inline float GetScale(const NDArray& data, float min, float max) {
-  auto data_range = (data.dtype() == mshadow::kInt8) ? kInt8Range : kUint8Range;
-  return data_range / MaxAbs(min, max);
-}
-
-class DNNLQuantizedElemwiseSumFwd {
+class DNNLQuantizedSumFwd {
  public:
   dnnl::sum::primitive_desc fwd_pd;
 
-  DNNLQuantizedElemwiseSumFwd(const dnnl::memory::desc& output_md,
-                              const std::vector<float>& scales,
-                              const std::vector<dnnl::memory::desc>& inputs_md)
+  DNNLQuantizedSumFwd(const dnnl::memory::desc& output_md,
+                      const std::vector<float>& scales,
+                      const std::vector<dnnl::memory::desc>& inputs_md)
       : fwd_pd(output_md, scales, inputs_md, CpuEngine::Get()->get_engine()) {
     fwd_ = std::make_shared<dnnl::sum>(fwd_pd);
   }
@@ -52,19 +47,23 @@ class DNNLQuantizedElemwiseSumFwd {
     return *fwd_;
   }
 
+  static DNNLQuantizedSumFwd& GetCached(const dnnl::memory::desc& output_md,
+                                        const std::vector<float>& scales,
+                                        const std::vector<dnnl::memory::desc>& inputs_md);
+
  private:
   std::shared_ptr<dnnl::sum> fwd_;
   std::shared_ptr<dnnl::memory> out_;
 };
 
-static DNNLQuantizedElemwiseSumFwd& GetQuantizedElemwiseSumForward(
+DNNLQuantizedSumFwd& DNNLQuantizedSumFwd::GetCached(
     const dnnl::memory::desc& output_md,
     const std::vector<float>& scales,
     const std::vector<dnnl::memory::desc>& inputs_md) {
 #if DMLC_CXX11_THREAD_LOCAL
-  static thread_local std::unordered_map<OpSignature, DNNLQuantizedElemwiseSumFwd, OpHash> fwds;
+  static thread_local std::unordered_map<OpSignature, DNNLQuantizedSumFwd, OpHash> fwds;
 #else
-  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLQuantizedElemwiseSumFwd, OpHash> fwds;
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLQuantizedSumFwd, OpHash> fwds;
 #endif
   OpSignature key;
   key.AddSign(output_md);
@@ -72,7 +71,56 @@ static DNNLQuantizedElemwiseSumFwd& GetQuantizedElemwiseSumForward(
   key.AddSign(inputs_md);
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    DNNLQuantizedElemwiseSumFwd fwd(output_md, scales, inputs_md);
+    DNNLQuantizedSumFwd fwd(output_md, scales, inputs_md);
+    it = AddToCache(&fwds, key, fwd);
+  }
+  return it->second;
+}
+
+class DNNLQuantizedBinAddFwd {
+ public:
+  dnnl::binary::primitive_desc fwd_pd;
+
+  DNNLQuantizedBinAddFwd(const dnnl::memory::desc& output_md,
+                         const std::vector<float>& scales,
+                         const std::vector<dnnl::memory::desc>& inputs_md) {
+    dnnl::binary::desc fwd_desc(dnnl::algorithm::binary_add, inputs_md[0], inputs_md[1], output_md);
+    dnnl::primitive_attr input_scales;
+    input_scales.set_scales(DNNL_ARG_SRC_0, 0, {scales[0]});
+    input_scales.set_scales(DNNL_ARG_SRC_1, 0, {scales[1]});
+    fwd_pd = dnnl::binary::primitive_desc(fwd_desc, input_scales, CpuEngine::Get()->get_engine());
+    fwd_   = std::make_shared<dnnl::binary>(fwd_pd);
+  }
+
+  const dnnl::binary& GetFwd() const {
+    return *fwd_;
+  }
+
+  static DNNLQuantizedBinAddFwd& GetCached(const dnnl::memory::desc& output_md,
+                                           const std::vector<float>& scales,
+                                           const std::vector<dnnl::memory::desc>& inputs_md);
+
+ private:
+  std::shared_ptr<dnnl::binary> fwd_;
+  std::shared_ptr<dnnl::memory> out_;
+};
+
+DNNLQuantizedBinAddFwd& DNNLQuantizedBinAddFwd::GetCached(
+    const dnnl::memory::desc& output_md,
+    const std::vector<float>& scales,
+    const std::vector<dnnl::memory::desc>& inputs_md) {
+#if DMLC_CXX11_THREAD_LOCAL
+  static thread_local std::unordered_map<OpSignature, DNNLQuantizedBinAddFwd, OpHash> fwds;
+#else
+  static MX_THREAD_LOCAL std::unordered_map<OpSignature, DNNLQuantizedBinAddFwd, OpHash> fwds;
+#endif
+  OpSignature key;
+  key.AddSign(output_md);
+  key.AddSign(scales);
+  key.AddSign(inputs_md);
+  auto it = fwds.find(key);
+  if (it == fwds.end()) {
+    DNNLQuantizedBinAddFwd fwd(output_md, scales, inputs_md);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
@@ -89,88 +137,113 @@ static void DNNLQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs,
   // C, C_min, C_max
   CHECK_EQ(outputs.size(), 3U) << "should be C, C_min, C_max";
   // Collect data min,max,absmax
-  const float A_min    = inputs[q_elemwise_add::kAMin].data().dptr<float>()[0];
-  const float B_min    = inputs[q_elemwise_add::kBMin].data().dptr<float>()[0];
-  const float A_max    = inputs[q_elemwise_add::kAMax].data().dptr<float>()[0];
-  const float B_max    = inputs[q_elemwise_add::kBMax].data().dptr<float>()[0];
-  const float A_absmax = MaxAbs(A_min, A_max);
-  const float B_absmax = MaxAbs(B_min, B_max);
-  const bool is_A_int8 = (inputs[q_elemwise_add::kDataA].dtype() == mshadow::kInt8);
-  const float A_scale  = GetScale(inputs[q_elemwise_add::kDataA], A_min, A_max);
-  const float B_scale  = GetScale(inputs[q_elemwise_add::kDataB], B_min, B_max);
-  auto A_mem           = inputs[q_elemwise_add::kDataA].GetDNNLData();
-  auto B_mem           = inputs[q_elemwise_add::kDataB].GetDNNLData();
-  float in_range;
+  const float A_min        = inputs[q_elemwise_add::kAMin].data().dptr<float>()[0];
+  const float B_min        = inputs[q_elemwise_add::kBMin].data().dptr<float>()[0];
+  const float A_max        = inputs[q_elemwise_add::kAMax].data().dptr<float>()[0];
+  const float B_max        = inputs[q_elemwise_add::kBMax].data().dptr<float>()[0];
+  const bool is_A_int8     = (inputs[q_elemwise_add::kDataA].dtype() == mshadow::kInt8);
+  const bool is_B_int8     = (inputs[q_elemwise_add::kDataB].dtype() == mshadow::kInt8);
+  const float A_type_range = is_A_int8 ? kInt8Range : kUint8Range;
+  const float B_type_range = is_B_int8 ? kInt8Range : kUint8Range;
+  const float A_absmax     = MaxAbs(A_min, A_max);
+  const float B_absmax     = MaxAbs(B_min, B_max);
+  const float A_scale      = A_type_range / A_absmax;
+  const float B_scale      = B_type_range / B_absmax;
+  auto A_mem               = inputs[q_elemwise_add::kDataA].GetDNNLData();
+  auto B_mem               = inputs[q_elemwise_add::kDataB].GetDNNLData();
+  bool diff_in_types       = (is_A_int8 != is_B_int8);
+  assert(diff_in_types ==
+         (inputs[q_elemwise_add::kDataA].dtype() != inputs[q_elemwise_add::kDataB].dtype()));
   dnnl::memory* rescaled_mem;              // rescaled_mem is for reorder dnnl memory
   double output_data_range = kInt32Range;  // output default set as int32
+
   if (outputs[q_elemwise_add::kOut].dtype() == mshadow::kInt8) {
     output_data_range = kInt8Range;
   } else if (outputs[q_elemwise_add::kOut].dtype() == mshadow::kUint8) {
     output_data_range = kUint8Range;
   }
 
-  float output_min     = 0;
-  float output_max     = 0;
-  float output_scale   = 0;
+  float output_min   = 0;
+  float output_max   = 0;
+  float output_scale = 0;
   if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
     output_min     = params.min_calib_range.value();
     output_max     = params.max_calib_range.value();
     output_scale   = output_data_range / MaxAbs(output_min, output_max);
   } else {
-    output_max = A_absmax + B_absmax;
-    output_min = -output_max;
-  }
-  // 2: scale 0 for input A, scale 1 for input B
-  const int scales_num = 2;
-  std::vector<float> scales(scales_num, 1);
-  auto engine = CpuEngine::Get()->get_engine();
-  if (inputs[q_elemwise_add::kDataA].dtype() != inputs[q_elemwise_add::kDataB].dtype()) {
-    auto s8_desc                     = is_A_int8 ? A_mem->get_desc() : B_mem->get_desc();
-    rescaled_mem = TmpMemMgr::Get()->Alloc(s8_desc);
-    const float u8_reorder_scale     = 0.5;
-    std::vector<float> reorder_scale = {u8_reorder_scale};
-    dnnl::primitive_attr reorder_attr;
-    reorder_attr.set_output_scales(0, reorder_scale);
-    auto u8_mem = (is_A_int8 == true) ? B_mem : A_mem;
-    const auto reorder_pd =
-        dnnl::reorder::primitive_desc(engine, u8_mem->get_desc(), engine, s8_desc, reorder_attr);
-    dnnl_args_map_t args({{DNNL_ARG_FROM, *u8_mem}, {DNNL_ARG_TO, *rescaled_mem}});
-    DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), args);
-
-    if (is_A_int8) {
-      B_mem = rescaled_mem;
-    } else {
-      A_mem = rescaled_mem;
-    }
-    in_range = kInt8Range;  // range after conversion uint8 to common int8
-  } else {
-    // both inputs has the same type
-    in_range = is_A_int8 ? kInt8Range : kUint8Range;
+    output_max   = A_absmax + B_absmax;
+    output_min   = -output_max;
+    output_scale = output_data_range / output_max;
   }
 
-  if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
-    scales[0] = output_scale / A_scale;
-    scales[1] = output_scale / B_scale;
-  } else {
-    scales[0] = A_absmax * output_data_range / (output_max * in_range);
-    scales[1] = B_absmax * output_data_range / (output_max * in_range);
+  std::vector<float> scales(2);  // 2: scale 0 for input A, scale 1 for input B
+  scales[0] = output_scale / A_scale;
+  scales[1] = output_scale / B_scale;
+
+  // We can use more efficient sum kernel when there is no broadcast - when shapes are the same
+  const bool sum_kernel =
+      (inputs[q_elemwise_add::kDataA].shape() == inputs[q_elemwise_add::kDataB].shape());
+
+  if (diff_in_types) {
+    if (sum_kernel) {
+      // rescale uint8 to int8 by reorder to temporary memory
+      auto s8_desc                     = is_A_int8 ? A_mem->get_desc() : B_mem->get_desc();
+      rescaled_mem                     = TmpMemMgr::Get()->Alloc(s8_desc);
+      const float u8_reorder_scale     = 0.5;
+      std::vector<float> reorder_scale = {u8_reorder_scale};
+      auto engine                      = CpuEngine::Get()->get_engine();
+      dnnl::primitive_attr reorder_attr;
+      reorder_attr.set_output_scales(0, reorder_scale);
+      auto u8_mem = (is_A_int8 == true) ? B_mem : A_mem;
+      const auto reorder_pd =
+          dnnl::reorder::primitive_desc(engine, u8_mem->get_desc(), engine, s8_desc, reorder_attr);
+      dnnl_args_map_t args({{DNNL_ARG_FROM, *u8_mem}, {DNNL_ARG_TO, *rescaled_mem}});
+      DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), args);
+      // Modify scale to restore original uint8 values:
+      if (is_A_int8) {
+        B_mem = rescaled_mem;
+        scales[1] *= 1.0 / u8_reorder_scale;
+      } else {
+        A_mem = rescaled_mem;
+        scales[0] *= 1.0 / u8_reorder_scale;
+      }
+    }
   }
 
   std::vector<dnnl::memory::desc> in_desc;
   in_desc.push_back(A_mem->get_desc());
   in_desc.push_back(B_mem->get_desc());
 
-  auto output_md                   = outputs[q_elemwise_add::kOut].GetDNNLData()->get_desc();
-  DNNLStream* stream               = DNNLStream::Get();
-  DNNLQuantizedElemwiseSumFwd& fwd = GetQuantizedElemwiseSumForward(output_md, scales, in_desc);
-  auto out_mem                     = CreateDNNLMem(outputs[q_elemwise_add::kOut],
-                               fwd.fwd_pd.dst_desc(),
-                               req[q_elemwise_add::kOut],
-                               &inputs[q_elemwise_add::kDataA]);
-  dnnl_args_map_t args({{DNNL_ARG_MULTIPLE_SRC, *A_mem},
-                        {DNNL_ARG_MULTIPLE_SRC + 1, *B_mem},
-                        {DNNL_ARG_DST, *out_mem.second}});
-  stream->RegisterPrimArgs(fwd.GetFwd(), args);
+  dnnl_output_t out_mem;
+  auto output_md     = outputs[q_elemwise_add::kOut].GetDNNLData()->get_desc();
+  DNNLStream* stream = DNNLStream::Get();
+
+  if (sum_kernel) {
+    const auto& fwd = DNNLQuantizedSumFwd::GetCached(output_md, scales, in_desc);
+    out_mem         = CreateDNNLMem(outputs[q_elemwise_add::kOut],
+                            fwd.fwd_pd.dst_desc(),
+                            req[q_elemwise_add::kOut],
+                            &inputs[q_elemwise_add::kDataA]);
+    const dnnl_args_map_t args({{DNNL_ARG_MULTIPLE_SRC, *A_mem},
+                                {DNNL_ARG_MULTIPLE_SRC + 1, *B_mem},
+                                {DNNL_ARG_DST, *out_mem.second}});
+    stream->RegisterPrimArgs(fwd.GetFwd(), args);
+  } else {
+    const auto& fwd = DNNLQuantizedBinAddFwd::GetCached(output_md, scales, in_desc);
+    const auto potentially_inplace_input =
+        (outputs[q_elemwise_add::kOut].GetDNNLData()->get_data_handle() ==
+         inputs[q_elemwise_add::kDataB].GetDNNLData()->get_data_handle()) ?
+            q_elemwise_add::kDataB :
+            q_elemwise_add::kDataA;
+    out_mem = CreateDNNLMem(outputs[q_elemwise_add::kOut],
+                            fwd.fwd_pd.dst_desc(),
+                            req[q_elemwise_add::kOut],
+                            &inputs[potentially_inplace_input]);
+
+    const dnnl_args_map_t args(
+        {{DNNL_ARG_SRC_0, *A_mem}, {DNNL_ARG_SRC_1, *B_mem}, {DNNL_ARG_DST, *out_mem.second}});
+    stream->RegisterPrimArgs(fwd.GetFwd(), args);
+  }
   CommitOutput(outputs[q_elemwise_add::kOut], out_mem);
   stream->Submit();
 
@@ -197,6 +270,13 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_add)
     .set_attr<bool>("TIsDNNL", true)
     .set_attr_parser(ParamParser<QuantizeElemwiseAddParam>)
     .add_arguments(QuantizeElemwiseAddParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_contrib_quantized_npi_add)
+    .set_attr<FInferStorageType>("FInferStorageType", ElemwiseAddStorageType)
+    .set_attr<FComputeEx>("FComputeEx<cpu>", DNNLQuantizedElemwiseAddForward)
+    .set_attr<bool>("TIsDNNL", true)
+    .set_attr_parser(ParamParser<QuantizeElemwiseAddParam>)
+    .add_arguments(QuantizeElemwiseAddParam::__FIELDS__());
 }  // namespace op
 }  // namespace mxnet
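
The forward function above picks the oneDNN sum kernel when the two input shapes
match and the binary_add kernel otherwise, and in both cases folds the input and
output quantization parameters into two scales. A plain-Python sketch of that
arithmetic with illustrative numbers (the exact kInt8Range/kUint8Range/kInt32Range
constants live in quantization_utils.h):

    # int8 input A in [-2, 2], uint8 input B in [0, 4], no calibration info
    kInt8Range, kUint8Range, kInt32Range = 127.0, 255.0, float(0x7fffffff)
    A_absmax, B_absmax = 2.0, 4.0
    A_scale = kInt8Range / A_absmax        # quantization scale of input A
    B_scale = kUint8Range / B_absmax       # quantization scale of input B

    # without calibration the output range is the sum of the input ranges
    output_max = A_absmax + B_absmax
    output_scale = kInt32Range / output_max
    scales = [output_scale / A_scale, output_scale / B_scale]

    # mixed int8/uint8 with the sum kernel: the uint8 side is reordered to int8
    # with scale 0.5, so its sum scale is doubled to compensate
    u8_reorder_scale = 0.5
    scales[1] *= 1.0 / u8_reorder_scale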
 
diff --git a/src/operator/quantization/quantized_elemwise_add-inl.h b/src/operator/quantization/quantized_elemwise_add-inl.h
index 7935a1471e..4b103dee8c 100644
--- a/src/operator/quantization/quantized_elemwise_add-inl.h
+++ b/src/operator/quantization/quantized_elemwise_add-inl.h
@@ -26,6 +26,7 @@
 #define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_ELEMWISE_ADD_INL_H_
 
 #include "../tensor/elemwise_unary_op.h"
+#include "../tensor/elemwise_binary_broadcast_op.h"
 
 namespace mxnet {
 namespace op {
@@ -54,6 +55,20 @@ enum QuantizedElemwiseAddOutputs { kOut, kMin, kMax };
 enum QuantizedElemwiseAddInputs { kDataA, kDataB, kAMin, kAMax, kBMin, kBMax };
 }  // namespace q_elemwise_add
 
+inline bool QuantizedBinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
+                                          mxnet::ShapeVector* in_attrs,
+                                          mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 6U);
+  CHECK_EQ(out_attrs->size(), 3U);
+  SHAPE_ASSIGN_CHECK(*in_attrs, 2, TShape{1});
+  SHAPE_ASSIGN_CHECK(*in_attrs, 3, TShape{1});
+  SHAPE_ASSIGN_CHECK(*in_attrs, 4, TShape{1});
+  SHAPE_ASSIGN_CHECK(*in_attrs, 5, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape{1});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape{1});
+  return BinaryBroadcastShapeCommon(attrs, in_attrs, out_attrs);
+}
+
 }  // namespace op
 }  // namespace mxnet
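
QuantizedBinaryBroadcastShape above pins the four min/max inputs and the two
min/max outputs to scalars and delegates the data shapes to the ordinary
broadcast logic. A small illustration of the resulting contract for
_contrib_quantized_npi_add (example shapes only; np.broadcast_shapes needs
NumPy >= 1.20):

    import numpy as np
    # inputs: A, B, A_min, A_max, B_min, B_max ; outputs: out, out_min, out_max
    in_shapes = [(32, 1), (32, 64), (1,), (1,), (1,), (1,)]
    out_shape = np.broadcast_shapes(in_shapes[0], in_shapes[1])
    assert out_shape == (32, 64)
    out_shapes = [out_shape, (1,), (1,)]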
 
diff --git a/src/operator/quantization/quantized_elemwise_add.cc b/src/operator/quantization/quantized_elemwise_add.cc
index 262f6e8158..0e10dfd6fd 100644
--- a/src/operator/quantization/quantized_elemwise_add.cc
+++ b/src/operator/quantization/quantized_elemwise_add.cc
@@ -138,5 +138,55 @@ NNVM_REGISTER_OP(elemwise_add).set_attr<FQuantizedOp>("FQuantizedOp", [](const N
   return node;
 });
 
+NNVM_REGISTER_OP(_contrib_quantized_npi_add)
+    .add_alias("_npx_quantized_npi_add")
+    .describe(R"code(elemwise_add operator for input dataA and input dataB 
data type of int8,
+and accumulates in type int32 for the output. For each argument, two more 
arguments of type
+float32 must be provided representing the thresholds of quantizing argument 
from data
+type float32 to int8. The final outputs contain result in int32, and min
+and max thresholds representing the threholds for quantizing the float32 
output into int32.
+
+.. Note::
+    This operator only supports forward propogation. DO NOT use it in training.
+
+)code")
+    .set_num_inputs([](const NodeAttrs& attrs) {
+      // A, B, A_min, A_max, B_min, B_max
+      return 6;
+    })
+    // C, C_min, C_max
+    .set_num_outputs(3)
+    .set_attr<nnvm::FListInputNames>(
+        "FListInputNames",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::string>{"lhs", "rhs", "lhs_min", "lhs_max", "rhs_min", "rhs_max"};
+        })
+    .set_attr<nnvm::FListOutputNames>(
+        "FListOutputNames",
+        [](const NodeAttrs& attrs) {
+          return std::vector<std::string>{"output", "min_output", "max_output"};
+        })
+    .set_attr<nnvm::FInferType>("FInferType", ElemwiseAddType)
+    .set_attr<mxnet::FInferShape>("FInferShape", QuantizedBinaryBroadcastShape)
+    .set_attr<FCompute>("FCompute<cpu>", QuantizedElemwiseAddForward)
+    .set_attr<FNeedRequantize>("FNeedRequantize", [](const NodeAttrs& attrs) { return true; })
+    .add_argument("lhs", "NDArray-or-Symbol", "first input")
+    .add_argument("rhs", "NDArray-or-Symbol", "second input")
+    .add_argument("lhs_min", "NDArray-or-Symbol", "3rd input")
+    .add_argument("lhs_max", "NDArray-or-Symbol", "4th input")
+    .add_argument("rhs_min", "NDArray-or-Symbol", "5th input")
+    .add_argument("rhs_max", "NDArray-or-Symbol", "6th input");
+
+NNVM_REGISTER_OP(_npi_add).set_attr<FQuantizedOp>("FQuantizedOp", [](const NodeAttrs& attrs) {
+  nnvm::ObjectPtr node = nnvm::Node::Create();
+  node->attrs.op       = Op::Get("_contrib_quantized_npi_add");
+  node->attrs.name     = "quantized_" + attrs.name;
+  node->attrs.dict     = attrs.dict;
+  if (node->op()->attr_parser != nullptr) {
+    node->op()->attr_parser(&(node->attrs));
+  }
+  return node;
+});
+
 }  // namespace op
 }  // namespace mxnet
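
After this registration the quantized kernel is also reachable directly through
the npx namespace, which is what the new unit test further down does (under
mx.util.use_np semantics). A trimmed sketch with made-up data and ranges:

    import mxnet as mx
    from mxnet import npx

    a = mx.np.random.uniform(-127, 127, size=(4, 6)).astype('int32').astype('float32')
    b = mx.np.random.uniform(0, 255, size=(1, 6)).astype('int32').astype('float32')
    out, out_min, out_max = npx.quantized_npi_add(
        a.astype('int8'), b.astype('uint8'),            # (1, 6) broadcasts over rows
        mx.np.array([-127.0]), mx.np.array([127.0]),    # lhs_min, lhs_max
        mx.np.array([0.0]), mx.np.array([255.0]))       # rhs_min, rhs_max
    # out is int32; out_min/out_max report the float range the int32 values span
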
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 0a7439ba4b..2c21db8ad6 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -41,7 +41,7 @@ bool SupportsRequantizeFusion(const Op* op) {
   static const std::set<const Op*> support_requantize_fusion_ops = {
       Op::Get("_contrib_quantized_elemwise_add"),
       Op::Get("_contrib_quantized_elemwise_mul"),
-      // Op::Get("_contrib_quantized_npi_add") - to be added later on
+      Op::Get("_contrib_quantized_npi_add"),
       Op::Get("_sg_onednn_conv"),
       Op::Get("_sg_onednn_fully_connected"),
       Op::Get("_sg_onednn_selfatt_qk"),
@@ -92,7 +92,7 @@ class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
     const nnvm::Node* raw_new_node = new_node.node;
 
     static const std::set<const Op*> dequantize_fusion_unsupported_ops = {
-        Op::Get("_contrib_quantized_elemwise_add")};
+        Op::Get("_contrib_quantized_elemwise_add"), 
Op::Get("_contrib_quantized_npi_add")};
 
     if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
         raw_new_node->is_variable())
diff --git a/src/operator/tensor/elemwise_binary_broadcast_op.h b/src/operator/tensor/elemwise_binary_broadcast_op.h
index 3ea6ff78b5..56b29f1911 100644
--- a/src/operator/tensor/elemwise_binary_broadcast_op.h
+++ b/src/operator/tensor/elemwise_binary_broadcast_op.h
@@ -38,11 +38,9 @@
 
 namespace mxnet {
 namespace op {
-inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
-                                 mxnet::ShapeVector* in_attrs,
-                                 mxnet::ShapeVector* out_attrs) {
-  CHECK_EQ(in_attrs->size(), 2U);
-  CHECK_EQ(out_attrs->size(), 1U);
+static inline bool BinaryBroadcastShapeCommon(const nnvm::NodeAttrs& attrs,
+                                              mxnet::ShapeVector* in_attrs,
+                                              mxnet::ShapeVector* out_attrs) {
   mxnet::TShape& lhs = (*in_attrs)[0];
   mxnet::TShape& rhs = (*in_attrs)[1];
 
@@ -79,6 +77,14 @@ inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
   return shape_is_known(lhs) && shape_is_known(rhs) && shape_is_known(out);
 }
 
+inline bool BinaryBroadcastShape(const nnvm::NodeAttrs& attrs,
+                                 mxnet::ShapeVector* in_attrs,
+                                 mxnet::ShapeVector* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 2U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  return BinaryBroadcastShapeCommon(attrs, in_attrs, out_attrs);
+}
+
 inline bool BinaryBroadcastMulStorageType(const nnvm::NodeAttrs& attrs,
                                           const int dev_mask,
                                           DispatchMode* dispatch_mode,
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py b/tests/python/dnnl/subgraphs/subgraph_common.py
index e3a102a634..c1224d25e3 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -90,7 +90,7 @@ def check_qsym_calibrated(qsym, out_type, name='conv'):
     if k.find('_quantize') != -1:
       assert v['out_type'] == out_type
     if k.find(quantized_op_name) != -1:
-      if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected") 
+      if (quantized_op_name.startswith("quantized_sg_onednn_fully_connected")
           or quantized_op_name.startswith("quantized_sg_onednn_conv")) and 'enable_float_output' in v:
         continue
       assert 'min_calib_range' in v
@@ -132,7 +132,7 @@ def check_fusion_parameter(sym, attrs_dict):
 
 def check_quantize(net_original, data_shapes, out_type, name='conv',
                    check_calibration=True, check_scale_align=False, quantize_mode='full',
-                   attrs_dict={}):
+                   attrs_dict={}, calib_mode='naive', check_fusion=True):
   quantize_granularity_list = ['tensor-wise']
   if name == 'fc':
     quantize_granularity_list += ['channel-wise']
@@ -140,11 +140,18 @@ def check_quantize(net_original, data_shapes, out_type, name='conv',
   if name in config:
     name = config[name][OP_NAME]
 
-  sigma = 0.3 if hasattr(net_original, 'alg') is True and net_original.alg == 'exp' else 0.5
+  sigma = 0.01 if hasattr(net_original, 'alg') is True and net_original.alg == 'exp' else 0.5
+  if out_type == 'uint8':
+    # Initialize weights and tensors only with positive values to be sure
+    # that results are always positive
+    init = CustomNormalInit(sigma=sigma, bounded=True)
+    min_value = 0
+  else:
+    init = mx.init.Normal(sigma)
+    min_value = -1
 
-  net_original.initialize(init=mx.init.Normal(sigma), force_reinit=True)
+  net_original.initialize(init=init, force_reinit=True)
 
-  min_value = -1 if out_type != 'uint8' else 0
   one_shape = isinstance(data_shapes, tuple)
   if one_shape:
     # replace one shape with list of shapes with one element inside to follow later the same schema
@@ -188,13 +195,14 @@ def check_quantize(net_original, data_shapes, out_type, name='conv',
                                      exclude_layers=None,
                                      exclude_operators=None,
                                      quantized_dtype=out_type,
-                                     calib_mode='naive',
+                                     calib_mode=calib_mode,
                                      calib_data=calib_data,
                                      num_calib_batches=1,
                                      quantize_mode=quantize_mode,
                                      quantize_granularity=quantize_granularity)
     qsym, _ = qnet.export(None)
-    check_fusion_parameter(qsym, attrs_dict)
+    if check_fusion:
+      check_fusion_parameter(qsym, attrs_dict)
     if check_calibration:
       check_qsym_calibrated(qsym, out_type, name=name)
     if check_scale_align:
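
The initializer change above follows the earlier fix in the commit message: for
uint8 output, weights and inputs are drawn only from positive values, because a
uint8 quantized tensor spans [0, max] and cannot represent negative numbers. A
standalone NumPy illustration (not the actual MXNet kernel, which works from
min/max thresholds):

    import numpy as np
    x = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)
    scale = 255.0 / x.max()                      # uint8 maps [0, max] onto [0, 255]
    q = np.clip(np.round(x * scale), 0, 255).astype(np.uint8)
    print(q)            # [  0   0 128 255] -> the negative value collapses to 0
    print(q / scale)    # [0.   0.   0.502 1.  ] -> information about -0.5 is lost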
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index 5028750c74..af5defa8df 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -416,3 +416,58 @@ def test_neg_fc_add_quantized(data_shape, add_op, fc_out_add, scaled_fc_out):
   attrs = []
   excluded_attrs = ['with_sum']
   check_neg_fusion_quantized(net, attrs, excluded_attrs, data_shapes, name='fc')
+
+
+def function_add_quantized(data_shape, add_op, quantize_mode, relu, out_type, broadcast, calib_mode):
+  class SumExample(nn.HybridBlock):
+    def __init__(self,  add_op, **kwargs):
+      super(SumExample, self).__init__(**kwargs)
+      self.elemwise_add = (add_op == 'ele_add')
+      self.relu = (relu == 'relu')
+
+    def forward(self, data1a, data2):
+      fc_out = data1a
+      if self.relu:
+        fc_out = mx.npx.relu(fc_out)
+      if  self.elemwise_add:
+        sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray()
+      else:
+        sum1 = data2 + fc_out
+      return sum1
+
+  attrs = {add_op: {}}
+  net = SumExample(add_op)
+  if broadcast:
+    broadcasted_shape = (1,) + data_shape[1:-1] + (1,)
+    data_shapes = [broadcasted_shape, data_shape]
+  else:
+    data_shapes = [data_shape, data_shape]
+
+  # check_calibration could be enabled if check_qsym_calibrated will be reimplemented
+  # to find operator names instead of node names
+  check_quantize(net, data_shapes, out_type, name="contrib_quantized_" + add_op,
+                 quantize_mode=quantize_mode, attrs_dict=attrs, calib_mode=calib_mode,
+                 check_calibration=(calib_mode != 'none') and False, check_fusion=False)
+
+
[email protected]_np
[email protected]('out_type', ['int8', 'auto'])
[email protected]('calib_mode', ['naive', 'none'])
[email protected]('quantize_mode', ['full', 'smart'])
[email protected]('relu', ['nore', 'relu'])
[email protected]('broadcast', ['broadcast', 'no_broadc'])
[email protected]('add_op', ['ele_add', 'npi_add'])
+def test_add_quantized(add_op, quantize_mode, out_type, relu, broadcast, calib_mode):
+  """
+  The test check results from quantization of simple graph
+  with npi_add or elemwise_add with additional relu which force
+  unsigned representation of one inputs to the add operator.
+  Due to construction of quantization code unsigned int8 is never choosen
+  for scenario without calibration as operators always raports min = -max
+  """
+  broadcastB = (broadcast ==  'broadcast')
+  if broadcastB and add_op == 'ele_add':
+    # elemwise_add doesn't support broadcasting
+    pytest.skip()
+  data_shape = DATA_SHAPE[0]
+  function_add_quantized(data_shape, add_op, quantize_mode, relu, out_type, broadcastB, calib_mode)
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index 48c7e606ff..c62391f3bb 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -432,6 +432,83 @@ def test_quantized_elemwise_add():
     check_quantized_elemwise_add((3, 4, 56, 56), 'int8', 'uint8')
     check_quantized_elemwise_add((32, 56, 64, 11), 'int8', 'int8')
 
+@use_np
+def test_quantized_npi_add():
+    def check_quantized_npi_add(data_shape,  qdtypeA, qdtypeB, broadcast=None):
+        if is_test_for_native_cpu():
+            print('skipped testing quantized_npi_add for native cpu since it is not supported yet')
+            return
+        elif (qdtypeA != 'uint8' and qdtypeA != 'int8') or (qdtypeB != 'uint8' and qdtypeB != 'int8'):
+            print('skipped testing quantized_npi_add for not supported data type')
+            return
+        elif is_test_for_gpu():
+            print('skipped testing quantized_npi_add for gpu since it is not supported yet')
+            return
+
+        class ElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(ElemwiseSumBlock, self).__init__(**kwargs)
+
+            def forward(self, dataA, dataB):
+                return dataA + dataB
+
+        class QuantElemwiseSumBlock(mx.gluon.nn.HybridBlock):
+            def __init__(self, **kwargs):
+                super(QuantElemwiseSumBlock, self).__init__(**kwargs)
+
+            def forward(self, dataA, dataB, dataA_min, dataA_max, dataB_min, dataB_max):
+                return npx.quantized_npi_add(dataA, dataB, dataA_min, dataA_max, dataB_min, dataB_max)
+
+        elemwise_add_fp32 = ElemwiseSumBlock()
+
+        dataA_low, dataA_high = get_low_high(qdtypeA)
+        dataB_low, dataB_high = get_low_high(qdtypeB)
+
+        data_shapeA = data_shape
+        data_shapeB = data_shape
+
+        if broadcast :
+            if broadcast == 'A':
+                data_shapeA = ()
+                for index in range(len(data_shape)):
+                    data_shapeA += (1,)
+            else:
+                data_shapeB = ()
+                for index in range(len(data_shape)):
+                    data_shapeB += (1,)
+
+        dataA_val = mx.np.random.uniform(low=dataA_low, high=dataA_high, size=data_shapeA).astype('int32').astype('float32')
+        dataB_val = mx.np.random.uniform(low=dataB_low, high=dataB_high, size=data_shapeB).astype('int32').astype('float32')
+
+        output = elemwise_add_fp32(dataA_val, dataB_val)
+
+        #run quantized
+        quantized_elemwise_add = QuantElemwiseSumBlock()
+        dataA_val_int8 = dataA_val.astype(qdtypeA)
+        dataB_val_int8 = dataB_val.astype(qdtypeB)
+        quantized_range = 127.0
+        min_dataA = mx.np.array([dataA_low])
+        max_dataA = mx.np.array([dataA_high])
+        min_dataB = mx.np.array([dataB_low])
+        max_dataB = mx.np.array([dataB_high])
+        qoutput, min_range, max_range = quantized_elemwise_add(dataA_val_int8, dataB_val_int8,
+                                                               min_dataA, max_dataA,
+                                                               min_dataB, max_dataB)
+        int8_rslt = qoutput.astype(output.dtype) * max_range / 0x7fffffff
+        diff = mx.np.abs(output - int8_rslt)
+        cond = mx.np.less(2, diff).sum().item()
+        assert cond == 0
+
+    check_quantized_npi_add((4, 6), 'uint8', 'int8')
+    check_quantized_npi_add((13, 74, 52), 'uint8', 'uint8')
+    check_quantized_npi_add((3, 4, 56, 56), 'int8', 'uint8')
+    check_quantized_npi_add((32, 56, 64, 11), 'int8', 'int8')
+
+    check_quantized_npi_add((4, 6), 'uint8', 'int8', 'A')
+    check_quantized_npi_add((13, 74, 52), 'uint8', 'uint8', 'B')
+    check_quantized_npi_add((3, 4, 56, 56), 'int8', 'uint8', 'A')
+    check_quantized_npi_add((32, 56, 64, 11), 'int8', 'int8', 'B')
+
 
 @use_np
 def test_quantized_elemwise_mul():
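
For reference, the assertion in test_quantized_npi_add above dequantizes the
int32 result with the reported max_range and compares it to the fp32 reference
with an absolute tolerance of 2. A worked example with made-up numbers:

    qout = 1073741824                      # one int32 value produced by the op
    max_range = 254.0                      # float range reported by the op
    deq = qout * max_range / 0x7fffffff    # ~127.0 in float
    # the test asserts abs(fp32_reference - deq) <= 2 element-wise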
