This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e9840b8  [FEATURE] Fuses FC + elemwise_add operators for oneDNN (#20821)
e9840b8 is described below

commit e9840b89f3305c446e4bab0402232abc68bde6dc
Author: Andrzej Kotłowski <[email protected]>
AuthorDate: Fri Jan 28 14:57:28 2022 +0100

    [FEATURE] Fuses FC + elemwise_add operators for oneDNN (#20821)
    
    * Fix elemwise_add post quantization pass
    
    * Align naming convention with convolution operator
    
    Convolution uses the data_name_[min|max] convention, which puts the
    object first and is more readable.
    
    * Fuse FC with elemwise_add
    
    * [TEST] Add functional tests for FC + add operator fusion
    
    * [TESTS] Disable quantization check for unsupported cases
    
    * Take into account already fused elemwise operation
    
    Adds a fix for fusing an already fused FC + relu/activation in the
    floating point case.
    Fusing elemwise_add with an FC that already has a fused relu/activation
    is blocked due to accuracy issues.
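    
    As a rough illustration (a sketch, not code from this commit), the Gluon
    pattern this fusion targets might look like the block below; the class name
    FCSum, the shapes, and the optimize_for call are assumptions based on the
    tests touched by this change:
    
        import mxnet as mx
        from mxnet.gluon import nn
    
        class FCSum(nn.HybridBlock):
            def __init__(self):
                super().__init__()
                self.fc = nn.Dense(16, use_bias=True, flatten=True)
    
            def forward(self, x, y):
                # FullyConnected followed by elemwise_add: the add is expected to
                # be folded into the oneDNN FC "sum" post-op by this fusion pass.
                return self.fc(x) + y
    
        net = FCSum()
        net.initialize()
        x = mx.np.ones((4, 8))
        y = mx.np.ones((4, 16))
        # Partition for the oneDNN backend; this is where the subgraph properties
        # registered in dnnl_subgraph_property.cc (including SgDNNLFCSumFuseProperty)
        # are applied.
        net.optimize_for(x, y, backend='ONEDNN')
        out = net(x, y)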
---
 src/operator/nn/dnnl/dnnl_fully_connected-inl.h    |  85 +++++
 src/operator/nn/dnnl/dnnl_fully_connected.cc       |   3 +
 src/operator/subgraph/build_subgraph.cc            |  11 +
 src/operator/subgraph/dnnl/dnnl_fc.cc              | 424 +++++++++++++--------
 src/operator/subgraph/dnnl/dnnl_fc_property.h      |   3 +
 src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h      | 291 ++++++++++++++
 .../subgraph/dnnl/dnnl_post_quantize_property.h    |   8 +-
 .../subgraph/dnnl/dnnl_subgraph_property.cc        |   4 +
 tests/python/dnnl/subgraphs/subgraph_common.py     | 154 ++++++--
 tests/python/dnnl/subgraphs/test_conv_subgraph.py  |  40 +-
 tests/python/dnnl/subgraphs/test_fc_subgraph.py    | 144 ++++++-
 11 files changed, 936 insertions(+), 231 deletions(-)

diff --git a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h 
b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
index c30ad4b..4196b15 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
+++ b/src/operator/nn/dnnl/dnnl_fully_connected-inl.h
@@ -28,6 +28,8 @@
 
 #if MXNET_USE_ONEDNN == 1
 
+#include <memory>
+#include <unordered_map>
 #include <string>
 #include <vector>
 
@@ -41,6 +43,8 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
   bool quantized;
   bool enable_float_output;
   bool with_eltwise;
+  bool with_sum;
+  bool first_quantization_pass;  // True for operator created during first 
quantization pass
   dmlc::optional<float> min_calib_range;  // min float value calculated from 
calibration dataset
   dmlc::optional<float> max_calib_range;  // max float value calculated from 
calibration dataset
   dmlc::optional<bool> channel_wise_quantize;
@@ -54,6 +58,10 @@ struct DNNLFCParam : public dmlc::Parameter<DNNLFCParam> {
     DMLC_DECLARE_FIELD(with_eltwise)
         .set_default(false)
         .describe("Whether there's a post with_eltwise after FullyConnected 
operator");
+    DMLC_DECLARE_FIELD(with_sum).set_default(false).describe("Add post sum");
+    DMLC_DECLARE_FIELD(first_quantization_pass)
+        .set_default(false)
+        .describe("True for first quantization pass");
     DMLC_DECLARE_FIELD(min_calib_range)
         .set_default(dmlc::optional<float>())
         .describe(
@@ -76,9 +84,86 @@ struct DNNLFCFullParam {
   FullyConnectedParam default_param;
   DNNLFCParam dnnl_param;
   DNNLPostEltwiseParam eltwise_param;
+  float sum_scale                  = {1.0f};
   std::vector<float> output_scales = {0.0f};
 };
 
+static inline size_t GetInSumIndex(const DNNLFCFullParam& param) {
+  assert(param.dnnl_param.with_sum);
+  return fullc::kWeight + 1 + (param.default_param.no_bias ? 0 : 1);
+}
+
+class FCInputIndex {
+ public:
+  explicit FCInputIndex(const DNNLFCFullParam full_param) {
+    auto& dnnl_param     = full_param.dnnl_param;
+    const bool has_bias  = !full_param.default_param.no_bias;
+    const bool quantized = dnnl_param.quantized;
+    const bool sum_input_quantized =
+        quantized && dnnl_param.with_sum && !dnnl_param.enable_float_output;
+    const bool channel_wise = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
+                              dnnl_param.channel_wise_quantize.value();
+
+    // Calculate position of particular input in the input vector:
+    int index = 0;
+    data      = index++;
+    weight    = index++;
+    bias      = has_bias ? index++ : 0;
+    sum       = dnnl_param.with_sum ? index++ : 0;
+    num_base  = index;  // note number of base inputs
+
+    data_min   = quantized ? index++ : 0;
+    data_max   = quantized ? index++ : 0;
+    weight_min = (quantized && !channel_wise) ? index++ : 0;
+    weight_max = (quantized && !channel_wise) ? index++ : 0;
+    bias_min   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    bias_max   = (quantized && !channel_wise && has_bias) ? index++ : 0;
+    sum_min    = sum_input_quantized ? index++ : 0;
+    sum_max    = sum_input_quantized ? index++ : 0;
+    num_total  = index;  // note number of total inputs
+  }
+
+  // Returns true if sum input exists
+  bool IsSumExist() const {
+    return sum;
+  }
+
+  // Returns true if bias input exists
+  bool IsBiasExist() const {
+    return bias;
+  }
+
+  // Returns true if the sum input exists and it is a float number
+  bool IsSumInputFloat() const {
+    return (sum && !sum_min);
+  }
+  int GetTotal() const {
+    return num_total;
+  }
+  int GetBase() const {
+    return num_base;
+  }
+
+  // Index of each input in the input vector:
+  int data;
+  int weight;
+  int bias;
+  int sum;
+  int data_min;
+  int data_max;
+  int weight_min;
+  int weight_max;
+  int bias_min;
+  int bias_max;
+  int sum_min;
+  int sum_max;
+
+ private:
+  int num_base;   // Number of standard inputs
+  int num_total;  // Number of total inputs: standard + additional needed for
+                  // quantization
+};
+
 dnnl::inner_product_forward::primitive_desc GetFCFwdImpl(const 
DNNLFCFullParam& full_param,
                                                          const bool is_train,
                                                          const NDArray& data,
diff --git a/src/operator/nn/dnnl/dnnl_fully_connected.cc 
b/src/operator/nn/dnnl/dnnl_fully_connected.cc
index eca90b7..6f04b19 100644
--- a/src/operator/nn/dnnl/dnnl_fully_connected.cc
+++ b/src/operator/nn/dnnl/dnnl_fully_connected.cc
@@ -53,6 +53,9 @@ dnnl::inner_product_forward::primitive_desc 
GetFCFwdImpl(const DNNLFCFullParam&
                        full_param.eltwise_param.alpha,
                        full_param.eltwise_param.beta);
   }
+  if (full_param.dnnl_param.with_sum) {
+    ops.append_sum(full_param.sum_scale);
+  }
   attr.set_post_ops(ops);
 
   if (full_param.dnnl_param.quantized && full_param.output_scales.size()) {
diff --git a/src/operator/subgraph/build_subgraph.cc 
b/src/operator/subgraph/build_subgraph.cc
index ef1218b..4acaa22 100644
--- a/src/operator/subgraph/build_subgraph.cc
+++ b/src/operator/subgraph/build_subgraph.cc
@@ -749,6 +749,17 @@ void CreateSubgraphNode(nnvm::Graph* g,
         for (BiDirectedNode* dest_node : subgraph_nodes) {
           sn->outputs.erase(dest_node->node);
         }
+      }
+    }
+
+    // Set outputs according to current inputs
+    for (size_t i = 0; i < n->inputs.size(); ++i) {
+      auto& e = n->inputs[i];
+      // update input entries' source simple nodes' outputs map
+      nnvm::Node* node = e.node.get();
+      if (indexed_graph.exist(node)) {
+        const auto nid     = indexed_graph.node_id(node);
+        BiDirectedNode* sn = simple_nodes[nid].get();
         sn->outputs[n.get()].push_back(i);
       }
     }
diff --git a/src/operator/subgraph/dnnl/dnnl_fc.cc 
b/src/operator/subgraph/dnnl/dnnl_fc.cc
index 49887fe..8ead3e7 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc.cc
+++ b/src/operator/subgraph/dnnl/dnnl_fc.cc
@@ -25,7 +25,10 @@
 
 #if MXNET_USE_ONEDNN == 1
 
+#include <memory>
 #include <string>
+#include <unordered_map>
+#include <unordered_set>
 #include <utility>
 #include <vector>
 
@@ -62,8 +65,8 @@ class SgDNNLFCOp {
 
  private:
   bool initialized_{false};
-  bool channel_wise_runtime_{false};
   bool reorder_data_{false};
+  bool inplace_{false};
   nnvm::Symbol subgraph_sym_;
   DNNLFCFullParam full_param_;
   dnnl_args_map_t args_;
@@ -72,83 +75,123 @@ class SgDNNLFCOp {
   std::shared_ptr<dnnl::memory> cached_out_mem_;
   NDArray cached_weight_;
   NDArray cached_bias_;
-  float cached_min_data_;
-  float cached_max_data_;
-  float cached_min_weight_;
-  float cached_max_weight_;
-  float cached_min_bias_;
-  float cached_max_bias_;
+  float cached_data_min_;
+  float cached_data_max_;
+  float cached_weight_min_;
+  float cached_weight_max_;
+  float cached_sum_min_;
+  float cached_sum_max_;
+  float cached_bias_min_;
+  float cached_bias_max_;
   size_t weight_ver_;
   size_t bias_ver_;
-  float cached_min_output_;
-  float cached_max_output_;
+  float cached_output_min_;
+  float cached_output_max_;
   float data_scale_{0.0f};
   std::vector<float> weight_scales_;
-  size_t total_num_inputs_;
-  size_t total_num_outputs_;
 };
 
 void SgDNNLFCOp::Forward(const OpContext& ctx,
                          const std::vector<NDArray>& in_data,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& out_data) {
-  auto& dnnl_param        = full_param_.dnnl_param;
-  auto& default_param     = full_param_.default_param;
-  bool has_bias           = !default_param.no_bias;
-  size_t base_num_inputs  = has_bias ? 3 : 2;
-  size_t base_num_outputs = 1;
-
-  float min_data   = 0.0f;
-  float max_data   = 0.0f;
-  float min_weight = 0.0f;
-  float max_weight = 0.0f;
-  float min_bias   = 0.0f;
-  float max_bias   = 0.0f;
-
-  if (!initialized_) {
-    if (dnnl_param.channel_wise_quantize.has_value() && 
dnnl_param.channel_wise_quantize) {
-      channel_wise_runtime_ = true;
+  auto& dnnl_param         = full_param_.dnnl_param;
+  auto& default_param      = full_param_.default_param;
+  const bool has_bias      = !default_param.no_bias;
+  const bool quantized     = dnnl_param.quantized;
+  const bool out_quantized = dnnl_param.quantized && 
!dnnl_param.enable_float_output;
+  const bool channel_wise  = quantized && 
dnnl_param.channel_wise_quantize.has_value() &&
+                            dnnl_param.channel_wise_quantize.value();
+
+  const FCInputIndex idx(full_param_);
+
+  CHECK_EQ(in_data.size(), idx.GetTotal());
+
+  int index               = 0;
+  const int out_index     = index++;
+  const int out_min_index = out_quantized ? index++ : 0;
+  const int out_max_index = out_quantized ? index++ : 0;
+  CHECK_EQ(out_data.size(), index);  // index is equal to total number of 
outputs
+
+  float data_min   = 0.0f;
+  float data_max   = 0.0f;
+  float weight_min = 0.0f;
+  float weight_max = 0.0f;
+  float bias_min   = 0.0f;
+  float bias_max   = 0.0f;
+
+  const float sum_min   = idx.sum_min ? 
in_data[idx.sum_min].data().dptr<float>()[0] : 0.0;
+  const float sum_max   = idx.sum_max ? 
in_data[idx.sum_max].data().dptr<float>()[0] : 0.0;
+  NDArray data          = in_data[idx.data];
+  const NDArray& weight = in_data[idx.weight];
+  NDArray output;
+
+  if (dnnl_param.with_sum) {
+    if (!initialized_) {
+      // TODO(zhennan): Currently, the dnnl fallback mechanism breaks the inplace option,
+      // which makes the check (req[out_index] == kWriteInplace) useless.
+      auto in_dnnl_mem  = static_cast<const 
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+      auto out_dnnl_mem = static_cast<const 
dnnl::memory*>(out_data[out_index].GetDNNLData());
+      if (in_dnnl_mem->get_data_handle() == out_dnnl_mem->get_data_handle()) {
+        inplace_ = true;
+      }
     }
-
-    total_num_inputs_  = base_num_inputs;
-    total_num_outputs_ = base_num_outputs;
-    if (dnnl_param.quantized) {
-      total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : 
(base_num_inputs * 3);
-      total_num_outputs_ =
-          dnnl_param.enable_float_output ? base_num_outputs : 
(base_num_outputs * 3);
+    if (inplace_) {
+      output = in_data[idx.sum];
+    } else {
+      // Not in place: copy in_data[idx.sum] into outputs[out_index].
+      auto in_dnnl_mem  = static_cast<const 
dnnl::memory*>(in_data[idx.sum].GetDNNLData());
+      auto out_dnnl_mem = static_cast<const 
dnnl::memory*>(out_data[out_index].GetDNNLData());
+      if (out_data[out_index].dtype() == mshadow::kInt32) {
+        auto mem_desc           = in_dnnl_mem->get_desc();
+        auto this_dtype         = get_dnnl_type(mshadow::kInt32);
+        mem_desc.data.data_type = static_cast<dnnl_data_type_t>(this_dtype);
+        dnnl_mem_ptr tmp_mem(new dnnl::memory(
+            mem_desc, CpuEngine::Get()->get_engine(), 
out_dnnl_mem->get_data_handle()));
+        DNNLStream::Get()->RegisterMem(tmp_mem);
+        DNNLStream::Get()->RegisterPrimArgs(
+            dnnl::reorder(*in_dnnl_mem, *tmp_mem),
+            {{DNNL_ARG_FROM, *in_dnnl_mem}, {DNNL_ARG_TO, *tmp_mem}});
+        output = NDArray(tmp_mem);
+      } else {
+        dnnl_mem_ptr tmp_mem(new dnnl::memory(in_dnnl_mem->get_desc(),
+                                              CpuEngine::Get()->get_engine(),
+                                              
out_dnnl_mem->get_data_handle()));
+        DNNLStream::Get()->RegisterMem(tmp_mem);
+        DNNLMemoryCopy(*in_dnnl_mem, tmp_mem.get());
+        output = NDArray(tmp_mem);
+      }
     }
+  } else {
+    output = out_data[out_index];
   }
-  CHECK_EQ(in_data.size(), total_num_inputs_);
-  CHECK_EQ(out_data.size(), total_num_outputs_);
-
-  NDArray data          = in_data[fullc::kData];
-  const NDArray& weight = in_data[fullc::kWeight];
-  const NDArray& output = out_data[fullc::kOut];
 
   if (dnnl_param.quantized) {
-    if (!channel_wise_runtime_) {
-      min_weight = in_data[base_num_inputs + 
quantized_fullc::kWeightMin].data().dptr<float>()[0];
-      max_weight = in_data[base_num_inputs + 
quantized_fullc::kWeightMax].data().dptr<float>()[0];
+    if (!channel_wise) {
+      weight_min = in_data[idx.weight_min].data().dptr<float>()[0];
+      weight_max = in_data[idx.weight_max].data().dptr<float>()[0];
       if (has_bias) {
-        min_bias = in_data[base_num_inputs + 
quantized_fullc::kBiasMin].data().dptr<float>()[0];
-        max_bias = in_data[base_num_inputs + 
quantized_fullc::kBiasMax].data().dptr<float>()[0];
+        bias_min = in_data[idx.bias_min].data().dptr<float>()[0];
+        bias_max = in_data[idx.bias_max].data().dptr<float>()[0];
       }
     }
-    min_data = in_data[base_num_inputs + 
quantized_fullc::kDataMin].data().dptr<float>()[0];
-    max_data = in_data[base_num_inputs + 
quantized_fullc::kDataMax].data().dptr<float>()[0];
+    data_min = in_data[idx.data_min].data().dptr<float>()[0];
+    data_max = in_data[idx.data_max].data().dptr<float>()[0];
   }
 
   if (initialized_ && dnnl_param.quantized && 
dmlc::GetEnv("MXNET_ONEDNN_QFC_DYNAMIC_PARAMS", 0)) {
-    if (channel_wise_runtime_) {
-      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
+    if (channel_wise) {
+      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
+          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
           weight_ver_ != weight.version() ||
-          (has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
+          (has_bias && (bias_ver_ != in_data[idx.bias].version()))) {
         initialized_ = false;
       }
     } else {
-      if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
-          cached_min_weight_ != min_weight || cached_max_weight_ != max_weight 
||
-          (has_bias && (cached_min_bias_ != min_bias || cached_max_bias_ != 
max_bias))) {
+      if (cached_data_min_ != data_min || cached_data_max_ != data_max ||
+          cached_sum_min_ != sum_min || cached_sum_max_ != sum_max ||
+          cached_weight_min_ != weight_min || cached_weight_max_ != weight_max 
||
+          (has_bias && (cached_bias_min_ != bias_min || cached_bias_max_ != 
bias_max))) {
         initialized_ = false;
       }
     }
@@ -157,17 +200,19 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   if (!initialized_) {
     const auto nthreads = 
engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
     const auto engine   = CpuEngine::Get()->get_engine();
-    cached_min_data_    = min_data;
-    cached_max_data_    = max_data;
-    cached_min_weight_  = min_weight;
-    cached_max_weight_  = max_weight;
+    cached_data_min_    = data_min;
+    cached_data_max_    = data_max;
+    cached_weight_min_  = weight_min;
+    cached_weight_max_  = weight_max;
     weight_ver_         = weight.version();
     cached_weight_      = weight;
+    cached_sum_min_     = sum_min;
+    cached_sum_max_     = sum_max;
     if (has_bias) {
-      cached_min_bias_ = min_bias;
-      cached_max_bias_ = max_bias;
-      bias_ver_        = in_data[fullc::kBias].version();
-      cached_bias_     = in_data[fullc::kBias];
+      cached_bias_min_ = bias_min;
+      cached_bias_max_ = bias_max;
+      bias_ver_        = in_data[idx.bias].version();
+      cached_bias_     = in_data[idx.bias];
     } else {
       cached_bias_ = NDArray();
     }
@@ -210,13 +255,13 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
     bool support_channelwise_scale = false;
     if (dnnl_param.quantized) {
       CHECK(data.dtype() == mshadow::kInt8 || data.dtype() == mshadow::kUint8);
-      data_scale_ = GetQuantizeScale(data.dtype(), cached_min_data_, 
cached_max_data_);
+      data_scale_ = GetQuantizeScale(data.dtype(), cached_data_min_, 
cached_data_max_);
 
       bool fuse_requantize = false;
       // Channelwise scaling is only supported when fusion is enabled 
(requantize or dequantize).
       if (dnnl_param.min_calib_range.has_value() && 
dnnl_param.max_calib_range.has_value()) {
-        cached_min_output_        = dnnl_param.min_calib_range.value();
-        cached_max_output_        = dnnl_param.max_calib_range.value();
+        cached_output_min_        = dnnl_param.min_calib_range.value();
+        cached_output_max_        = dnnl_param.max_calib_range.value();
         support_channelwise_scale = true;
         fuse_requantize           = true;
       }
@@ -227,7 +272,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
       // True          True                       True
       // True          False                      Error
       // False         True/False                 False
-      if (channel_wise_runtime_ && !support_channelwise_scale) {
+      if (channel_wise && !support_channelwise_scale) {
         LOG(FATAL)
             << "Currently, channel-wise quantization requires fuse requantize 
or dequantize."
             << " Please make sure the `min_calib_range` and `max_calib_range` 
are set when only"
@@ -236,7 +281,7 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
             << " or the env var of `MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT` and 
"
             << " `MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL` are not set to true 
(default is false)";
       }
-      support_channelwise_scale = support_channelwise_scale && 
channel_wise_runtime_;
+      support_channelwise_scale = support_channelwise_scale && channel_wise;
 
       if (support_channelwise_scale) {
         MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -248,51 +293,56 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
       } else {
         weight_scales_.resize(1);
         weight_scales_[0] =
-            GetQuantizeScale(cached_weight_.dtype(), cached_min_weight_, 
cached_max_weight_);
+            GetQuantizeScale(cached_weight_.dtype(), cached_weight_min_, 
cached_weight_max_);
         if (has_bias) {
-          float bias_scale = GetQuantizeScale(mshadow::kInt8, 
cached_min_bias_, cached_max_bias_);
-          float bias_int32_rescale = data_scale_ * weight_scales_[0] / 
bias_scale;
-          // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set the 
maximum value
-          // of bias to INT_MAX / 2.
-          float bias_max_rescale =
-              MaxValue<int32_t>() / 2 / MaxAbs(cached_min_bias_, 
cached_max_bias_) / bias_scale;
-          if (bias_int32_rescale > bias_max_rescale) {
-            // avoid overflow on bias
-            bias_int32_rescale = bias_max_rescale;
-            float weight_rescale =
-                bias_int32_rescale * bias_scale / data_scale_ / 
weight_scales_[0];
-            int8_t* weight_ptr = weight.data().dptr<int8_t>();
-            size_t weight_size = weight.shape().Size();
+          if (cached_bias_.dtype() == mshadow::kInt8) {
+            float bias_scale = GetQuantizeScale(mshadow::kInt8, 
cached_bias_min_, cached_bias_max_);
+
+            float bias_int32_rescale = data_scale_ * weight_scales_[0] / 
bias_scale;
+            // TODO(zhennan): dnnl has bug to handle INT_MAX in bias, so set
+            // the maximum value of bias to INT_MAX / 2.
+            float bias_max_rescale =
+                MaxValue<int32_t>() / 2 / MaxAbs(cached_bias_min_, 
cached_bias_max_) / bias_scale;
+            if (bias_int32_rescale > bias_max_rescale) {
+              // avoid overflow on bias
+              bias_int32_rescale = bias_max_rescale;
+              float weight_rescale =
+                  bias_int32_rescale * bias_scale / data_scale_ / 
weight_scales_[0];
+              int8_t* weight_ptr = weight.data().dptr<int8_t>();
+              size_t weight_size = weight.shape().Size();
 #pragma omp parallel for num_threads(nthreads)
-            for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
-              weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+              for (index_t i = 0; i < static_cast<index_t>(weight_size); ++i) {
+                weight_ptr[i] = std::round(weight_ptr[i] * weight_rescale);
+              }
+              weight_scales_[0] *= weight_rescale;
             }
-            weight_scales_[0] *= weight_rescale;
-          }
-          NDArray bias = in_data[fullc::kBias];
-          cached_bias_ =
-              NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, 
mshadow::kInt32);
-          int8_t* bias_ptr            = bias.data().dptr<int8_t>();
-          int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
-          size_t bias_size            = bias.shape().Size();
+            NDArray bias = in_data[fullc::kBias];
+            cached_bias_ =
+                NDArray(bias.storage_type(), bias.shape(), bias.ctx(), true, 
mshadow::kInt32);
+            int8_t* bias_ptr            = bias.data().dptr<int8_t>();
+            int32_t* quantized_bias_ptr = cached_bias_.data().dptr<int32_t>();
+            size_t bias_size            = bias.shape().Size();
+
 #pragma omp parallel for num_threads(nthreads)
-          for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
-            quantized_bias_ptr[i] = std::round(bias_ptr[i] * 
bias_int32_rescale);
+            for (index_t i = 0; i < static_cast<index_t>(bias_size); ++i) {
+              quantized_bias_ptr[i] = std::round(bias_ptr[i] * 
bias_int32_rescale);
+            }
           }
         }
       }
 
       size_t num_channel = cached_weight_.shape()[0];
+      float out_scale    = 1.0f;
       if (fuse_requantize || dnnl_param.enable_float_output) {
         float tmp_scale_ = 1.0f;
         if (fuse_requantize) {
           if (dnnl_param.with_eltwise) {
             tmp_scale_ = 1.0 / data_scale_;
             full_param_.eltwise_param.scale =
-                GetQuantizeScale(output.dtype(), cached_min_output_, 
cached_max_output_);
+                GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
           } else {
-            tmp_scale_ = GetQuantizeScale(output.dtype(), cached_min_output_, 
cached_max_output_) /
-                         data_scale_;
+            out_scale  = GetQuantizeScale(output.dtype(), cached_output_min_, 
cached_output_max_);
+            tmp_scale_ = out_scale / data_scale_;
           }
         } else {
           tmp_scale_ = 1.0 / data_scale_;
@@ -314,26 +364,33 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
           mxnet_op::Kernel<QuantizationRangeForS8S8MultiplicationStruct, 
cpu>::Launch(
               s,
               1,
-              &cached_min_output_,
-              &cached_max_output_,
-              &min_data,
-              &max_data,
-              &min_weight,
-              &max_weight);
+              &cached_output_min_,
+              &cached_output_max_,
+              &data_min,
+              &data_max,
+              &weight_min,
+              &weight_max);
         } else {
           mxnet_op::Kernel<QuantizationRangeForS8U8MultiplicationStruct, 
cpu>::Launch(
               s,
               1,
-              &cached_min_output_,
-              &cached_max_output_,
-              &min_data,
-              &max_data,
-              &min_weight,
-              &max_weight);
+              &cached_output_min_,
+              &cached_output_max_,
+              &data_min,
+              &data_max,
+              &weight_min,
+              &weight_max);
         }
         full_param_.output_scales.resize(0);
+        out_scale = data_scale_ * weight_scales_[0];
       }
-    }
+
+      if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+        float sum_in_scale =
+            GetQuantizeScale(in_data[idx.sum].dtype(), cached_sum_min_, 
cached_sum_max_);
+        full_param_.sum_scale = out_scale / sum_in_scale;
+      }
+    }  // if (dnnl_param.quantized)
 
     fwd_.reset(new DNNLFullyConnectedForward(full_param_,
                                              ctx.is_train,
@@ -357,10 +414,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
                              weight_scales_,
                              false);
     } else {
-      const auto def_weight_mem = weight.GetDNNLData();
+      const auto def_weight_mem = static_cast<const 
dnnl::memory*>(weight.GetDNNLData());
       if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
-        cached_weight_         = NDArray(fwd_->fwd_pd.weights_desc());
-        auto cached_weight_mem = cached_weight_.GetDNNLData();
+        auto weight_desc       = fwd_->fwd_pd.weights_desc();
+        cached_weight_         = NDArray(weight_desc);
+        auto cached_weight_mem = static_cast<const 
dnnl::memory*>(cached_weight_.GetDNNLData());
         std::unordered_map<int, dnnl::memory> args(
             {{DNNL_ARG_FROM, *def_weight_mem}, {DNNL_ARG_TO, 
*cached_weight_mem}});
         DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(*def_weight_mem, 
*cached_weight_mem),
@@ -368,17 +426,32 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
       }
     }
 
-    const auto data_mem = data.GetDNNLData();
+    const auto data_mem = static_cast<const dnnl::memory*>(data.GetDNNLData());
     cached_data_mem_    = std::make_shared<dnnl::memory>(data_mem->get_desc(), 
engine);
 
     args_[DNNL_ARG_SRC]     = *cached_data_mem_;
-    args_[DNNL_ARG_WEIGHTS] = *cached_weight_.GetDNNLData();
+    args_[DNNL_ARG_WEIGHTS] = *static_cast<const 
dnnl::memory*>(cached_weight_.GetDNNLData());
     if (has_bias)
-      args_[DNNL_ARG_BIAS] = *cached_bias_.GetDNNLData();
+      args_[DNNL_ARG_BIAS] = *static_cast<const 
dnnl::memory*>(cached_bias_.GetDNNLData());
     args_[DNNL_ARG_DST] = *cached_out_mem_;
     initialized_        = true;
   }
 
+  if (dnnl_param.with_sum) {
+    const auto& output_mem   = output.GetDNNLData();
+    const auto& out_mem_desc = output_mem->get_desc();
+    auto dst_mem_desc        = fwd_->fwd_pd.dst_desc();
+    if (out_mem_desc != dst_mem_desc) {
+      auto tmp_out_mem            = output.GetDNNLDataReorder(dst_mem_desc);
+      dst_mem_desc.data.data_type = out_mem_desc.data.data_type;
+      dnnl_mem_ptr new_out_mem(new dnnl::memory(
+          dst_mem_desc, CpuEngine::Get()->get_engine(), 
output_mem->get_data_handle()));
+      DNNLStream::Get()->RegisterMem(new_out_mem);
+      DNNLMemoryCopy(*tmp_out_mem, new_out_mem.get());
+      output = NDArray(new_out_mem);
+    }
+  }
+
   if (reorder_data_) {
     data = data.Reorder2Default();
   }
@@ -392,10 +465,11 @@ void SgDNNLFCOp::Forward(const OpContext& ctx,
   DNNLStream::Get()->Submit();
 
   if (dnnl_param.quantized && !dnnl_param.enable_float_output) {
-    float* min_output_ptr = 
out_data[quantized_fullc::kOutMin].data().dptr<float>();
-    float* max_output_ptr = 
out_data[quantized_fullc::kOutMax].data().dptr<float>();
-    *min_output_ptr       = cached_min_output_;
-    *max_output_ptr       = cached_max_output_;
+    float* output_min_ptr = out_data[out_min_index].data().dptr<float>();
+    float* output_max_ptr = out_data[out_max_index].data().dptr<float>();
+
+    *output_min_ptr = cached_output_min_;
+    *output_max_ptr = cached_output_max_;
   }
 }
 
@@ -450,23 +524,25 @@ static void SgDNNLFCParamParser(nnvm::NodeAttrs* attrs) {
 
 static std::vector<std::string> SgDNNLFCListInputNames(const NodeAttrs& attrs) 
{
   auto const& full_param               = 
nnvm::get<DNNLFCFullParam>(attrs.parsed);
+  auto const& dnnl_param               = full_param.dnnl_param;
   std::vector<std::string> input_names = DefaultSubgraphOpListInputs(attrs);
-  if (full_param.dnnl_param.quantized) {
-    bool channel_wise = false;
-    if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
-        full_param.dnnl_param.channel_wise_quantize) {
-      channel_wise = true;
-    }
-    input_names.emplace_back("min_data");
-    input_names.emplace_back("max_data");
+  if (dnnl_param.quantized) {
+    const bool channel_wise =
+        dnnl_param.channel_wise_quantize.has_value() && 
dnnl_param.channel_wise_quantize;
+    input_names.emplace_back("data_min");
+    input_names.emplace_back("data_max");
     if (!channel_wise) {
-      input_names.emplace_back("min_weight");
-      input_names.emplace_back("max_weight");
+      input_names.emplace_back("weight_min");
+      input_names.emplace_back("weight_max");
       if (!full_param.default_param.no_bias) {
-        input_names.emplace_back("min_bias");
-        input_names.emplace_back("max_bias");
+        input_names.emplace_back("bias_min");
+        input_names.emplace_back("bias_max");
       }
     }
+    if (dnnl_param.with_sum && !dnnl_param.enable_float_output) {
+      input_names.emplace_back("sum_min");
+      input_names.emplace_back("sum_max");
+    }
   }
   return input_names;
 }
@@ -477,19 +553,19 @@ static std::vector<std::string> 
SgDNNLFCListOutputNames(const NodeAttrs& attrs)
     if (full_param.dnnl_param.enable_float_output)
       return std::vector<std::string>{"output"};
     else
-      return std::vector<std::string>{"output", "min_output", "max_output"};
+      return std::vector<std::string>{"output", "output_min", "output_max"};
   } else {
     return std::vector<std::string>{"output"};
   }
 }
 
 template <typename T>
-static inline void FillBaseInputOutputInfo(const FullyConnectedParam& param,
+static inline void FillBaseInputOutputInfo(const DNNLFCFullParam& param,
                                            std::vector<T>* base_in_attrs,
                                            std::vector<T>* base_out_attrs,
                                            std::vector<T>* in_attrs,
                                            std::vector<T>* out_attrs) {
-  auto base_num_inputs = param.no_bias ? 2 : 3;
+  auto base_num_inputs = FCInputIndex(param).GetBase();
 
   base_out_attrs->push_back(out_attrs->at(0));
   for (int i = 0; i < base_num_inputs; ++i) {
@@ -504,8 +580,7 @@ static bool SgDNNLFCInferShape(const nnvm::NodeAttrs& attrs,
   if (full_param.dnnl_param.quantized) {
     mxnet::ShapeVector base_in_shapes;
     mxnet::ShapeVector base_out_shapes;
-    FillBaseInputOutputInfo(
-        full_param.default_param, &base_in_shapes, &base_out_shapes, 
in_shapes, out_shapes);
+    FillBaseInputOutputInfo(full_param, &base_in_shapes, &base_out_shapes, 
in_shapes, out_shapes);
     bool ret = DefaultSubgraphOpShape(attrs, &base_in_shapes, 
&base_out_shapes);
 
     for (size_t i = 0; i < in_shapes->size(); ++i) {
@@ -531,26 +606,43 @@ static bool SgDNNLFCInferType(const nnvm::NodeAttrs& 
attrs,
                               std::vector<int>* out_types) {
   auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
   if (full_param.dnnl_param.quantized) {
-    bool channel_wise = false;
-    if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
-        full_param.dnnl_param.channel_wise_quantize) {
-      channel_wise = true;
-    }
-    size_t base_num_inputs = full_param.default_param.no_bias ? 2 : 3;
-    CHECK(in_types->at(0) == mshadow::kInt8 || in_types->at(0) == 
mshadow::kUint8)
-        << "QuantizedFullyConnected only supports int8/uint8 input, while " << 
in_types->at(0)
-        << " is given.";
-    for (size_t i = 1; i < in_types->size(); ++i) {
-      if (channel_wise) {
-        TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
-      } else {
-        if (i < base_num_inputs) {
-          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kInt8);
+    const bool channel_wise = 
full_param.dnnl_param.channel_wise_quantize.has_value() &&
+                              full_param.dnnl_param.channel_wise_quantize;
+    const FCInputIndex idx(full_param);
+
+    CHECK(in_types->at(idx.data) == mshadow::kInt8 || in_types->at(idx.data) 
== mshadow::kUint8)
+        << "QuantizedFullyConnected  data input only supports int8/uint8, 
while "
+        << in_types->at(idx.data) << " is given.";
+    if (channel_wise) {
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kFloat32);
+      if (idx.IsBiasExist()) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kFloat32);
+      }
+    } else {
+      TYPE_ASSIGN_CHECK(*in_types, idx.weight, mshadow::kInt8);
+      if (idx.IsBiasExist()) {
+        if (in_types->at(idx.bias) == -1) {
+          TYPE_ASSIGN_CHECK(*in_types, idx.bias, mshadow::kInt32);
         } else {
-          TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+          CHECK(in_types->at(idx.bias) == mshadow::kInt8 ||
+                in_types->at(idx.bias) == mshadow::kInt32)
+              << "QuantizedFullyConnected bias input only supports int8/int32, 
while "
+              << in_types->at(idx.bias) << " is given.";
         }
       }
     }
+    if (idx.IsSumExist()) {
+      if (full_param.dnnl_param.enable_float_output) {
+        TYPE_ASSIGN_CHECK(*in_types, idx.sum, mshadow::kFloat32);
+      } else {
+        CHECK(in_types->at(idx.sum) == mshadow::kInt8 || in_types->at(idx.sum) 
== mshadow::kUint8)
+            << "QuantizedFullyConnected sum input only supports int8/uint8, 
while "
+            << in_types->at(idx.sum) << " is given.";
+      }
+    }
+    for (size_t i = idx.data_min; i < in_types->size(); ++i) {
+      TYPE_ASSIGN_CHECK(*in_types, i, mshadow::kFloat32);
+    }
 
     if (full_param.dnnl_param.enable_float_output) {
       TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kFloat32);
@@ -583,8 +675,7 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& 
attrs,
   if (full_param.dnnl_param.quantized) {
     std::vector<int> base_in_attrs;
     std::vector<int> base_out_attrs;
-    FillBaseInputOutputInfo(
-        full_param.default_param, &base_in_attrs, &base_out_attrs, in_attrs, 
out_attrs);
+    FillBaseInputOutputInfo(full_param, &base_in_attrs, &base_out_attrs, 
in_attrs, out_attrs);
     bool ret = DefaultSubgraphOpStorageType(
         attrs, dev_mask, dispatch_mode, &base_in_attrs, &base_out_attrs);
 
@@ -606,6 +697,15 @@ static bool SgDNNLFCStorageType(const nnvm::NodeAttrs& 
attrs,
   }
 }
 
+std::vector<std::pair<int, int>> SgDNNLFCInplaceOption(const NodeAttrs& attrs) 
{
+  auto const& param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
+  if (param.dnnl_param.with_sum) {
+    return std::vector<std::pair<int, int>>{{FCInputIndex(param).sum, 0}};
+  } else {
+    return std::vector<std::pair<int, int>>();
+  }
+}
+
 static OpStatePtr CreateSgDNNLFCState(const nnvm::NodeAttrs& attrs,
                                       Context ctx,
                                       const mxnet::ShapeVector& in_shapes,
@@ -641,13 +741,16 @@ static bool SgDNNLAvoidFCQuantizeInput(const NodeAttrs& 
attrs,
                                        const std::string quantize_granularity) 
{
   auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
   std::unordered_set<size_t> avoid_indexes;
+  FCInputIndex idx(full_param);
   if (quantize_granularity == "channel-wise") {
     avoid_indexes.insert(fullc::kWeight);  // weight
     if (!full_param.default_param.no_bias) {
       avoid_indexes.insert(fullc::kBias);  // bias
     }
   }
-
+  if (idx.IsSumInputFloat()) {
+    avoid_indexes.insert(idx.sum);
+  }
   return avoid_indexes.count(index_to_check);
 }
 
@@ -656,17 +759,7 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
     .describe(R"code(_sg_onednn_fully_connected)code" ADD_FILELINE)
     .set_num_inputs([](const NodeAttrs& attrs) {
       auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
-      auto num_inputs        = full_param.default_param.no_bias ? 2 : 3;
-      if (full_param.dnnl_param.quantized) {
-        if (full_param.dnnl_param.channel_wise_quantize.has_value() &&
-            full_param.dnnl_param.channel_wise_quantize) {
-          return num_inputs + 2;  // min_data, max_data
-        } else {
-          return num_inputs * 3;
-        }
-      } else {
-        return num_inputs;
-      }
+      return FCInputIndex(full_param).GetTotal();
     })
     .set_num_outputs([](const NodeAttrs& attrs) {
       auto const& full_param = nnvm::get<DNNLFCFullParam>(attrs.parsed);
@@ -691,6 +784,7 @@ NNVM_REGISTER_OP(_sg_onednn_fully_connected)
                                 })
     .set_attr<nnvm::FMutateInputs>("FMutateInputs", 
DefaultSubgraphOpMutableInputs)
     .set_attr<std::string>("key_var_num_args", "num_args")
+    .set_attr<nnvm::FInplaceOption>("FInplaceOption", SgDNNLFCInplaceOption)
     .set_attr<FQuantizable>("FQuantizable",
                             [](const NodeAttrs& attrs) { return 
QuantizeType::kMust; })
     .set_attr<FQuantizedOp>("FQuantizedOp", SgDNNLFCQuantizedOp)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_property.h 
b/src/operator/subgraph/dnnl/dnnl_fc_property.h
index 64fd507..b22c5ef 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_property.h
@@ -194,6 +194,9 @@ class SgDNNLFCProperty : public SubgraphProperty {
       auto& sub_name = node->op()->name;
       if (sub_name == "FullyConnected") {
         node_name << "fully_connected_";
+        if (HasAttr("quantize") && GetAttr<bool>("quantize")) {
+          n->attrs.dict["first_quantization_pass"] = "True";
+        }
       } else if (SupportDNNLFCEltwiseFusion(sub_name)) {
         node_name << "eltwise_";
         n->attrs.dict["with_eltwise"] = "True";
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h 
b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
new file mode 100644
index 0000000..4af89c9
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
@@ -0,0 +1,291 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+  \file
+  \brief For fusing FullyConnected operator with element-wise add.
+
+  The element-wise add operator is replaced by the DNNL FC "sum" post operator,
+  which adds the FC result to the values already present in the output. For the
+  quantized integer version this output is scaled to the proper range.
+*/
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#if MXNET_USE_ONEDNN == 1
+
+#include <memory>
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include <vector>
+
+#include "../../tensor/matrix_op-inl.h"
+#include "../common.h"
+#include "dnnl_fc-inl.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool EndsWith(std::string const& value, std::string const& ending) {
+  if (ending.size() > value.size()) {
+    return false;
+  } else {
+    return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
+  }
+}
+
+class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
+ private:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kStart,
+    kSuccess,
+  };
+
+  bool quantized_;
+  SelectStatus status_ = kFail;
+  std::vector<const BiDirectedNode*> matched_list_;
+
+ public:
+  explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
+
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    const auto n = seed_node.node;
+    if (n->op() == Op::Get("_sg_onednn_fully_connected") && 
SupportDNNLAttr(node_attr) &&
+        (seed_node.outputs.size() == 1)) {
+      auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+      if ((!quantized_ && !fc_param.dnnl_param.first_quantization_pass) ||
+          (fc_param.dnnl_param.quantized && 
!fc_param.dnnl_param.with_eltwise)) {
+        // Start subgraph when fusing for floats (quantized_ is false for the DNNL
+        // backend) or when FC is already quantized (second pass for DNNL_QUANTIZE)
+        // but not already fused with an elemwise operator.
+        status_ = kStart;
+        matched_list_.clear();
+        matched_list_.push_back(&seed_node);
+        return true;
+      }
+    }
+    return false;
+  }
+
+  bool SelectInput(const BiDirectedNode& cur_node, const BiDirectedNode& 
input_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& 
output_node) override {
+    const auto cur_n    = cur_node.node;
+    const auto output_n = output_node.node;
+    if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+      return false;
+    }
+    // If n isn't the last matched node, then we encountered an internal
+    // branch; we should pop out the nodes behind n and stop fusion.
+    if (matched_list_.back() != &cur_node) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != 
matched_list_.end()) {
+        while (matched_list_.back() != &cur_node) {
+          matched_list_.pop_back();
+        }
+      }
+      status_ = kSuccess;
+      return false;
+    }
+
+    switch (status_) {
+      case kStart:
+        // Find _contrib_quantized_elemwise_add or elemwise_add
+        if (EndsWith(output_n->op()->name, "elemwise_add")) {
+          if (quantized_) {
+            auto const& fc_param = 
nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+            if (!fc_param.dnnl_param.enable_float_output) {
+              // For a quantized graph, when FC floating point output is not enabled,
+              // the elementwise add must also be quantized (its min and max values
+              // must already be stored in the elementwise add node).
+              CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
+            }
+          }
+          matched_list_.push_back(&output_node);
+          status_ = kSuccess;
+          return true;
+        }
+      default:
+        status_ = kFail;
+        return false;
+    }
+  }
+
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& 
candidates) override {
+    if (status_ == kSuccess) {
+      return candidates;
+    } else {
+      return std::vector<BiDirectedNode*>(0);
+    }
+  }
+
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
+    new_selector.Select(*matched_list_[0], nullptr);
+    *this = new_selector;
+  }
+};
+
+class SgDNNLFCSumFuseProperty : public SubgraphProperty {
+ public:
+  SgDNNLFCSumFuseProperty() {}
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string& name = "DNNL fuse FullyConnected with sum";
+    auto property                  = 
std::make_shared<SgDNNLFCSumFuseProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override 
{
+    nnvm::ObjectPtr fc_node     = nullptr;
+    nnvm::ObjectPtr ew_add_node = nullptr;
+
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
+      if (node->is_variable()) {
+        return;
+      }
+      auto& sub_name = node->op()->name;
+      if (sub_name == "_sg_onednn_fully_connected") {
+        fc_node = node;
+      } else if (EndsWith(sub_name, "elemwise_add")) {
+        ew_add_node = node;
+      }
+    });
+
+    CHECK_NOTNULL(fc_node);
+    if (ew_add_node != nullptr) {
+      CHECK_NOTNULL(fc_node->attrs.subgraphs[0]);
+      auto subgraph_output_node = fc_node->attrs.subgraphs[0]->outputs[0].node;
+      nnvm::Symbol new_sym;
+      // Create a new elemwise_add node to not alter the original one.
+      // It is needed in subgraph to properly calculate InferShape.
+      nnvm::ObjectPtr n = nnvm::Node::Create();
+      n->attrs.op       = Op::Get("elemwise_add");
+      n->attrs.name     = ew_add_node->attrs.name;
+
+      if (ew_add_node->inputs[0].node == fc_node) {
+        n->inputs.emplace_back(subgraph_output_node);
+        n->inputs.emplace_back(ew_add_node->inputs[1]);
+      } else {
+        n->inputs.emplace_back(ew_add_node->inputs[0]);
+        n->inputs.emplace_back(subgraph_output_node);
+      }
+      new_sym.outputs.emplace_back(n);
+      fc_node->attrs.subgraphs.clear();
+      
fc_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(new_sym));
+      fc_node->attrs.dict["with_sum"] = "True";
+      fc_node->attrs.dict.erase("first_quantization_pass");  // Removed as not 
needed any longer
+      fc_node->op()->attr_parser(&(fc_node->attrs));
+    }
+    return fc_node;
+  }
+
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+    bool quantized = HasAttr("quantize") ? GetAttr<bool>("quantize") : false;
+    auto selector  = std::make_shared<SgDNNLFCSumFuseSelector>(quantized);
+    return selector;
+  }
+
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+                              std::vector<nnvm::NodeEntry*>* output_entries) 
const override {
+    // Connect all extern output entries to output[0]
+    for (size_t i = 0; i < output_entries->size(); ++i) {
+      auto entry_ptr = output_entries->at(i);
+      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
+    }
+  }
+
+  void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+                             std::vector<nnvm::NodeEntry*>* input_entries,
+                             std::vector<nnvm::NodeEntry>* orig_input_entries) 
const override {
+    auto sym             = n->attrs.subgraphs[0];
+    auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+    std::unordered_set<const nnvm::Node*> node_sets;
+    DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
+      if (node->is_variable()) {
+        return;
+      }
+      node_sets.insert(node.get());
+      if (EndsWith(node->op()->name, "elemwise_add")) {
+        const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
+        // Make sure the FC output is the left operand of the add operator; if not:
+        // - swap the inputs of the add operator,
+        // - reorder the input entries so that the tensor (sum_tensor) to which
+        //   the FC output is added becomes the last base input.
+        if (node_sets.count(node->inputs[1].node.get())) {
+          // Example of input_entries reordering for channel-wise quantized 
graph:
+          // sum_tensor.data    -->   fc.data
+          // fc.data            -->   fc.weight0
+          // fc.weight0         -->   fc.bias0
+          // fc.bias0           -->   sum_tensor.data
+          // fc_out.min         -->   fc_out.min
+          // fc_out.max         -->   fc_out.max
+          // sum_tensor.min     -->   sum_tensor.min
+          // sum_tensor.max     -->   sum_tensor.max
+          std::swap(node->inputs[0], node->inputs[1]);
+          std::rotate(input_entries->begin(),
+                      input_entries->begin() + 1,
+                      input_entries->begin() + base_inputs);
+          std::rotate(orig_input_entries->begin(),
+                      orig_input_entries->begin() + 1,
+                      orig_input_entries->begin() + base_inputs);
+        } else {
+          // Example of input_entries reordering for channel-wise quantized 
graph:
+          // fc.data            -->   fc.data
+          // fc.weight0         -->   fc.weight0
+          // fc.bias0           -->   fc.bias0
+          // fc_out.min         -->   sum_tensor.data
+          // fc_out.max         -->   fc_out.min
+          // sum_tensor.data    -->   fc_out.max
+          // sum_tensor.min     -->   sum_tensor.min
+          // sum_tensor.max     -->   sum_tensor.max
+          const int not_rotated_end =
+              (fc_param.dnnl_param.quantized && 
!fc_param.dnnl_param.enable_float_output) ? 2 : 0;
+
+          std::rotate(input_entries->begin() + base_inputs - 1,
+                      input_entries->end() - 1 - not_rotated_end,
+                      input_entries->end() - not_rotated_end);
+          std::rotate(orig_input_entries->begin() + base_inputs - 1,
+                      orig_input_entries->end() - 1 - not_rotated_end,
+                      orig_input_entries->end() - not_rotated_end);
+        }
+      }
+    });
+    n->inputs = *orig_input_entries;
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h 
b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index d9cb6c0..1aa52ca 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -45,6 +45,9 @@ const std::set<std::string> support_req_fusion_op = 
{"_contrib_quantized_elemwis
                                                      "_sg_onednn_selfatt_qk",
                                                      
"_sg_onednn_selfatt_valatt",
                                                      "_sg_onednn_batch_dot"};
+
+const std::set<const Op*> no_enable_float_output = 
{Op::Get("_contrib_quantized_elemwise_add"),
+                                                    
Op::Get("_sg_onednn_conv")};
 }  // namespace
 
 class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
@@ -110,7 +113,8 @@ class SgDNNLPostQuantizeSelector : public 
SubgraphSelectorV2 {
           if (param.min_calib_range.has_value() && 
param.max_calib_range.has_value()) {
             matched_list.emplace_back(&new_node);
             status = SelectStatus::kRequantize;
-            if (raw_node->op() == Op::Get("_sg_onednn_conv")) {
+            if ((raw_node->op() == Op::Get("_sg_onednn_conv")) ||
+                (raw_node->op() == 
Op::Get("_contrib_quantized_elemwise_add"))) {
               status = SelectStatus::kSuccess;
             }
             return true;
@@ -210,7 +214,7 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
 
     // When only fused quantized operator and requantize, set 
min/max_cablib_range,
     // When fused quantized operator + requantize + dequantize, set dequantize 
flag to true.
-    if (dequantize_node != nullptr) {
+    if ((dequantize_node != nullptr) && 
(no_enable_float_output.count(fuse_node->op()) == 0)) {
       fuse_node->attrs.dict["enable_float_output"] = "True";
     } else {
       fuse_node->attrs.dict["min_calib_range"] =
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc 
b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 8f8fc44..b3b23e1 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -28,6 +28,7 @@
 #include "dnnl_post_quantize_property.h"
 #include "dnnl_transformer_qk_property.h"
 #include "dnnl_transformer_valatt_property.h"
+#include "dnnl_fc_sum_fuse.h"
 
 namespace mxnet {
 namespace op {
@@ -43,6 +44,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, 
SgDNNLBNReLUProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCSumFuseProperty);
 
 MXNET_REGISTER_SUBGRAPH_BACKEND(ONEDNN_QUANTIZE).set_attr("context", 
Context::CPU());
 
@@ -55,6 +57,8 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLBatchDotProperty)
     .set_attr("quantize", true);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, 
SgDNNLPostQuantizeAlignScaleProperty);
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCSumFuseProperty)
+    .set_attr("quantize", true);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/tests/python/dnnl/subgraphs/subgraph_common.py 
b/tests/python/dnnl/subgraphs/subgraph_common.py
index 37b14c8..b3bf5b0 100644
--- a/tests/python/dnnl/subgraphs/subgraph_common.py
+++ b/tests/python/dnnl/subgraphs/subgraph_common.py
@@ -114,8 +114,27 @@ def check_qsym_scale_align(qsym):
         assert max_calib_range == v['max_calib_range']
 
 
-def check_quantize(net_original, data_shape, out_type, name='conv',
-                   check_calibration=True, check_scale_align=False):
+def check_fusion_parameter(sym, attrs_dict):
+  for name, attrs in attrs_dict.items():
+    if name in config:
+      op_name = config[name][OP_NAME]
+    else:
+      op_name = name
+    assert ''.join(sym.get_internals().list_outputs()).find(op_name) != -1
+    if len(attrs):
+      found = False
+      for k, v in sym.attr_dict().items():
+        if k.find('_quantize') != -1:
+          continue
+        if k.find(op_name) != -1:
+          found = True
+          for attr_name, attr_value in attrs.items():
+            assert v[attr_name].lower() == attr_value.lower()
+      assert found
+
+def check_quantize(net_original, data_shapes, out_type, name='conv',
+                   check_calibration=True, check_scale_align=False, 
quantize_mode='full',
+                   attrs_dict={}):
   quantize_granularity_list = ['tensor-wise']
   if name == 'fc':
     quantize_granularity_list += ['channel-wise']
@@ -125,92 +144,108 @@ def check_quantize(net_original, data_shape, out_type, 
name='conv',
 
   net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
   min_value = -1 if out_type != 'uint8' else 0
-  data = mx.np.random.uniform(min_value, 1.0, size=data_shape, 
dtype='float32', ctx=mx.current_device())
-
-  outputs = net_original(data)
+  one_shape = isinstance(data_shapes, tuple)
+  if one_shape:
+    # wrap a single shape in a one-element list so the same code path is used below
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(min_value, 1.0, size=shape, 
dtype='float32', device=mx.cpu()))
+
+  outputs = net_original(*data)
   for output in outputs:
       output.wait_to_read()
   ref_out = outputs
 
-  calib_data = mx.gluon.data.DataLoader(data, batch_size=1)
+  dataArray = mx.gluon.data.ArrayDataset(*data)
+
+  calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
   for quantize_granularity in quantize_granularity_list:
     qnet = quantization.quantize_net(net_original,
-                                     ctx=mx.current_device(),
+                                     device=mx.cpu(),
                                      exclude_layers=None,
                                      exclude_operators=None,
                                      quantized_dtype=out_type,
                                      calib_mode='naive',
                                      calib_data=calib_data,
                                      num_calib_batches=1,
-                                     quantize_mode='full',
+                                     quantize_mode=quantize_mode,
                                      quantize_granularity=quantize_granularity)
     qsym, _ = qnet.export(None)
+    check_fusion_parameter(qsym, attrs_dict)
     if check_calibration:
       check_qsym_calibrated(qsym, out_type, name=name)
     if check_scale_align:
       check_qsym_scale_align(qsym)
 
-    quantized_out = qnet(data)
+    quantized_out = qnet(*data)
     for i in range(len(ref_out)):
       min_range = mx.np.min(ref_out[i]).item()
       max_range = mx.np.max(ref_out[i]).item()
       atol = 0.1 * max(abs(min_range), abs(max_range))
-      assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(), rtol=0.1, atol=atol, etol=0.2)
+      assert_almost_equal_with_err(quantized_out.asnumpy(), ref_out.asnumpy(),
+                                   rtol=0.1, atol=atol, etol=0.2)
 
 
-def check_fusion(net_original, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True,
-                 out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True):
+def check_fusion(net_original, data_shapes, attrs_dict, check_fp32_fusion=True,
+                 check_quantization=True, out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True,
+                 quantize_mode='full'):
   net_original.initialize()
   net_original.hybridize(static_alloc=False, static_shape=False)
-  data = mx.np.random.uniform(size=data_shape, dtype='float32', ctx=mx.current_device())
-  net_original(data)
+  one_shape = isinstance(data_shapes, tuple)
+  data_min = -1.0
+  data_max = 1.0
+
+  if one_shape:
+    # wrap a single shape in a list so the rest of the function can treat both cases the same way
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(size=shape, dtype='float32', device=mx.cpu(),
+                low=data_min, high=data_max))
+  net_original(*data)
   net_fusion = copy.copy(net_original)
   sym, params = net_original.export(None)
 
   if check_fp32_fusion:
-    data_min = -1.0
-    data_max = 1.0
     if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1:
       check_quantization = False
       data_min = 0
 
     sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True)
-    for name, attrs in attrs_dict.items():
-      if name in config:
-        op_name = config[name][OP_NAME]
-      else:
-        op_name = name
-      assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1
-      if len(attrs):
-          found = False
-          for k, v in sym_sg.attr_dict().items():
-            if k.find(op_name) != -1:
-              found = True
-              for attr_name, attr_value in attrs.items():
-                assert v[attr_name].lower() == attr_value.lower()
-          assert found
-
-    data = mx.np.random.uniform(size=data_shape, low=data_min, high=data_max)
-    out_unfused = net_original(data)
-
-    net_fusion.optimize_for(data, backend=SG_PASS_NAME)
-    out_fused = net_fusion(data)
+    check_fusion_parameter(sym_sg, attrs_dict)
+    if data_min == 0 and mx.npx.is_np_default_dtype():
+      # regenerate inputs if they have different range or data type
+      data = []
+      for shape in data_shapes:
+        data.append(mx.np.random.uniform(size=shape, device=mx.cpu(), low=data_min, high=data_max))
+    out_unfused = net_original(*data)
+
+    net_fusion.optimize_for(*data, backend=SG_PASS_NAME)
+    out_fused = net_fusion(*data)
 
     assert_almost_equal(out_unfused.asnumpy(), out_fused.asnumpy(), rtol=1e-3, atol=1e-1)
 
   if check_quantization:
     # fp32 to int8
     for out_type in out_types:
-      check_quantize(net_original, data_shape, out_type, name=name)
+      check_quantize(net_original, data_shapes, out_type, name=list(attrs_dict.keys())[0],
+                     quantize_mode=quantize_mode, attrs_dict=attrs_dict)
 
 def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
-                     data_shapes=(4,4,10,10), name='conv'):
+                     data_shapes=[(4,4,10,10)], name='conv'):
   op_name = config[name][OP_NAME]
+  one_shape = isinstance(data_shapes, tuple)
+  if one_shape:
+    # wrap a single shape in a list so the rest of the function can treat both cases the same way
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(size=shape))
 
-  data_nd = mx.np.random.uniform(size=data_shapes)
   net_original.initialize()
   net_original.hybridize()
-  net_original(data_nd)
+  net_original(*data)
 
   sym, _ = net_original.export(None)
   sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True)
@@ -221,4 +256,41 @@ def check_neg_fusion(net_original, attrs_name=None, excluded_attrs=None,
       for attr in attrs_name:
         assert v[attr] == 'true'
       for exc_attr in excluded_attrs:
-        assert exc_attr not in v.keys()
+        assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
+
+
+def check_neg_fusion_quantized(net_original, attrs_name=None, excluded_attrs=None,
+                     data_shapes=[(4,4,10,10)], name='conv'):
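+  '''Quantize net_original and check that every op_name node has the attributes from
+  attrs_name set to 'true' and none of the excluded_attrs present.'''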
+  op_name = config[name][OP_NAME]
+  net_original.initialize(init=mx.init.Normal(0.5), force_reinit=True)
+  one_shape = isinstance(data_shapes, tuple)
+  if one_shape:
+    # wrap a single shape in a list so the rest of the function can treat both cases the same way
+    data_shapes = [data_shapes]
+  data = []
+  for shape in data_shapes:
+    data.append(mx.np.random.uniform(size=shape, dtype='float32', device=mx.cpu()))
+
+  dataArray = mx.gluon.data.ArrayDataset(*data)
+  calib_data = mx.gluon.data.DataLoader(dataArray, batch_size=1)
+
+  qnet = quantization.quantize_net(net_original,
+                                    device=mx.cpu(),
+                                    exclude_layers=None,
+                                    exclude_operators=None,
+                                    quantized_dtype='int8',
+                                    calib_mode='naive',
+                                    calib_data=calib_data,
+                                    num_calib_batches=1,
+                                    quantize_mode='full',
+                                    quantize_granularity='tensor-wise')
+  qsym, _ = qnet.export(None)
+  attrs_dict = qsym.attr_dict()
+  for k, v in attrs_dict.items():
+    if k.find(op_name) != -1:
+      for attr in attrs_name:
+        assert v[attr] == 'true'
+      for exc_attr in excluded_attrs:
+        assert exc_attr not in v.keys(), exc_attr + " attribute shouldn't exist"
+
diff --git a/tests/python/dnnl/subgraphs/test_conv_subgraph.py b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
index 6b6169b..e7dac8f 100644
--- a/tests/python/dnnl/subgraphs/test_conv_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_conv_subgraph.py
@@ -91,7 +91,7 @@ def test_pos_conv_add(use_bias, data_shape):
     
   attr = {'conv': {'with_sum': 'true'}}
   net = ConvAdd(use_bias=use_bias)
-  check_fusion(net, data_shape, attr)
+  check_fusion(net, data_shape, attr, check_quantization=False)
 
 
 @mx.util.use_np
@@ -112,14 +112,14 @@ def test_pos_conv_add2(no_bias, data_shape):
 
   attr = {'conv': {'with_sum': 'true'}}
   net = ConvAdd(use_bias=True)
-  check_fusion(net, data_shape, attr)
+  check_fusion(net, data_shape, attr, check_quantization=False)
 
 
 @mx.util.use_np
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", False), #TODO(bgawrych): investigate
-    ("sigmoid", True),
+    ("sigmoid", False),
     ("log_sigmoid", False),
     ("mish", False),
     ("tanh", False), #TODO(bgawrych): investigate
@@ -162,11 +162,11 @@ def test_pos_conv_act_add(data_shape, alg, quantize, use_bias):
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", True),
-    ("sigmoid", True),
-    ("log_sigmoid", True),
-    ("mish", True),
-    ("tanh", True),
-    ("softrelu", True),
+    ("sigmoid", False),
+    ("log_sigmoid", False),
+    ("mish", False),
+    ("tanh", False),
+    ("softrelu", False),
     ("relu6", True),
     ("leakyrelu", True),
     ("gelu", True)
@@ -200,14 +200,14 @@ def test_pos_conv_bn_act(use_bias, data_shape, alg, quantize):
 @mx.util.use_np
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [
-    ("relu", True),
-    ("sigmoid", True),
-    ("log_sigmoid", True),
-    ("mish", True),
-    ("tanh", True),
+    ("relu", False),
+    ("sigmoid", False),
+    ("log_sigmoid", False),
+    ("mish", False),
+    ("tanh", False),
     #("softrelu", True), #TODO(bgawrych): failing fusion check - difference in 
random single element
-    ("relu6", True),
-    ("leakyrelu", True),
+    ("relu6", False),
+    ("leakyrelu", False),
     ("gelu", False) #TODO: for True we get assert instead of not fusing pattern
 ])
 @pytest.mark.parametrize('use_bias', [True, False])
@@ -321,11 +321,11 @@ def test_pos_concat_scale_align(data_shape, out_type):
 @pytest.mark.parametrize('data_shape', DATA_SHAPE)
 @pytest.mark.parametrize('alg,quantize', [
     ("relu", True),
-    ("sigmoid", True),
-    ("log_sigmoid", True),
-    ("mish", True),
-    ("tanh", True),
-    ("softrelu", True),
+    ("sigmoid", False),
+    ("log_sigmoid", False),
+    ("mish", False),
+    ("tanh", False),
+    ("softrelu", False),
     ("relu6", True),
     ("leakyrelu", True),
     ("gelu", True)
diff --git a/tests/python/dnnl/subgraphs/test_fc_subgraph.py b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
index 223a55d..c63bb9a 100644
--- a/tests/python/dnnl/subgraphs/test_fc_subgraph.py
+++ b/tests/python/dnnl/subgraphs/test_fc_subgraph.py
@@ -17,7 +17,7 @@
 
 import mxnet as mx
 import pytest
-from subgraph_common import check_fusion, check_neg_fusion
+from subgraph_common import check_fusion, check_neg_fusion, check_neg_fusion_quantized
 from subgraph_common import CustomNormalInit, DATA_SHAPE, TailNegBlock
 from mxnet.contrib import quantization
 from mxnet.gluon import nn
@@ -89,9 +89,11 @@ def test_fc_eltwise(data_shape, use_bias, flatten, alg):
         out = mx.np.clip(fc_out, 0, 1.0)
       return out
 
+  not_quant_fuse = ['sigmoid', 'log_sigmoid', 'softrelu', 'tanh', 'mish', 'square', 'square_root',
+                    'exp']
   attrs = {'fc': {'with_eltwise': 'true'}}
   net = FCEltwise(use_bias, flatten, alg)
-  check_fusion(net, data_shape, attrs, check_quantization=flatten)
+  check_fusion(net, data_shape, attrs, check_quantization=flatten and alg not in not_quant_fuse)
 
 
 @mx.util.use_np
@@ -148,7 +150,7 @@ def test_quantized_fc_bias_overflow(data_min, data_max, weight_min, weight_max):
         conv1 = mx.npx.fully_connected(x, num_hidden=64, weight=self.weight.data(x.device),
                                        no_bias=False, bias=self.bias.data(x.device))
         return conv1
-    
+
     def infer_shape(self, x, *args):
         self.weight.shape = (64, x.shape[x.ndim-1])
         self.bias.shape = (64,)
@@ -232,3 +234,139 @@ def test_fc_identity_eltwise(identity_node):
            'sg_onednn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}}
   net = FCIdentityEltwise(identity_node)
   check_fusion(net, data_shape, attrs, check_quantization=False)
+
+
+def function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, flatten, relu, out_type):
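+  '''Build an FC + add block (optionally with leaky_relu/gelu in between) and run check_fusion
+  on it; quantize_mode of None checks only the fp32 fusion, otherwise the quantized model is checked.'''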
+  class FCWithSumExample(nn.HybridBlock):
+    def __init__(self, num_hidden, add_op, fc_out_add, **kwargs):
+      super(FCWithSumExample, self).__init__(**kwargs)
+      self.fca = nn.Dense(units=num_hidden, flatten=flatten)
+      self.elemwise_add = (add_op == 'elemwise_add')
+      self.fc_out_as_rhs = (fc_out_add == 'rhs')
+      self.relu = (relu == 'leaky_relu')
+
+    def forward(self, data1a, data2):
+      fc_out = self.fca(data1a)
+      if self.relu:
+        fc_out = mx.npx.leaky_relu(fc_out, act_type='gelu')
+      if self.fc_out_as_rhs:
+        if self.elemwise_add:
+          sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray()
+        else:
+          sum1 = data2 + fc_out
+      else:
+        if self.elemwise_add:
+          sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray()
+        else:
+          sum1 = fc_out + data2
+      return sum1
+
+  attrs = {'fc': {'with_sum': 'true'}}
+  if quantize_mode is not None:
+    attrs['fc']['quantized'] = 'true'
+    if quantize_mode == 'smart':
+      attrs['fc']['enable_float_output'] = 'true'
+  num_hidden = 10
+  net = FCWithSumExample(num_hidden, add_op, fc_out_add)
+  if flatten:
+    data_shapes = [data_shape, (data_shape[0], num_hidden)]
+  else:
+    data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
+  check_fusion(net, data_shapes, attrs,
+               out_types=[out_type],
+               check_fp32_fusion=(quantize_mode is None),
+               check_quantization=(quantize_mode is not None) and flatten,
+               quantize_mode=quantize_mode)
+
+@mx.util.use_np
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('relu', ['noleaky_re', 'leaky_relu'])
+@pytest.mark.parametrize('flatten', ['flat', 'nofl'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+def test_fc_add(data_shape, add_op, fc_out_add, flatten, relu):
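+  '''Check that FC followed by 'add_op' gets the 'with_sum' attribute in the fp32 fused model.'''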
+  function_fc_add(data_shape, add_op, None, fc_out_add, flatten == 'flat', relu, None)
+
+@mx.util.use_np
+@pytest.mark.seed(1234) # Seed set because the test is not robust enough to operate on random data
+@pytest.mark.parametrize('data_shape', DATA_SHAPE)
+@pytest.mark.parametrize('quantize_mode', ['full', 'smart'])
+@pytest.mark.parametrize('out_type', ['int8', 'auto'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+def test_fc_add_quantized(data_shape, add_op, quantize_mode, fc_out_add, out_type):
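+  '''Check that FC followed by 'add_op' is fused and quantized ('with_sum' and 'quantized' set).'''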
+  function_fc_add(data_shape, add_op, quantize_mode, fc_out_add, True, 'noleaky_re', out_type)
+
+
+class NegFCAdd(nn.HybridBlock):
+  #
+  #  data  --------------------------> 'add_op'  ------------>
+  #                                   /                        \
+  #  sg_onednn_fully_connected   ---->                           npi_add -->
+  #                                   \                        /
+  #                                    npi_multiply_scalar -->
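+  # fc_out feeds both 'add_op' and the scalar multiply, so the FC + add fusion must not be applied.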
+  def __init__(self, num_hidden, add_op, fc_out_add, scaled_fc_out, flatten, **kwargs):
+    super(NegFCAdd, self).__init__(**kwargs)
+    self.fca = nn.Dense(units=num_hidden, flatten=flatten)
+    self.elemwise_add = (add_op == 'elemwise_add')
+    self.fc_out_as_rhs = (fc_out_add == 'rhs')
+    self.scaled_fc_out_as_rhs = (scaled_fc_out == 's_rhs')
+
+  def forward(self, data1a, data2):
+    fc_out = self.fca(data1a)
+    scaled_fc_out = fc_out * 200.0
+    if self.fc_out_as_rhs:
+      if self.elemwise_add:
+        sum1 = mx.nd.elemwise_add(data2.as_nd_ndarray(), fc_out.as_nd_ndarray()).as_np_ndarray()
+      else:
+        sum1 = data2 + fc_out
+    else:
+      if self.elemwise_add:
+        sum1 = mx.nd.elemwise_add(fc_out.as_nd_ndarray(), data2.as_nd_ndarray()).as_np_ndarray()
+      else:
+        sum1 = fc_out + data2
+    if self.scaled_fc_out_as_rhs:
+      sum2 = sum1 + scaled_fc_out
+    else:
+      sum2 = scaled_fc_out + sum1
+    return sum2
+
+@mx.util.use_np
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+@pytest.mark.parametrize('data_shape', [DATA_SHAPE[0]])
+@pytest.mark.parametrize('flatten', ['flat', 'nofl'])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs'])
+def test_neg_fc_add(data_shape, add_op, flatten, fc_out_add, scaled_fc_out):
+  '''
+  Check that a FullyConnected operator whose output is not used solely as one 'add_op' input
+  is not fused.
+  See NegFCAdd for the graph used in this test.
+  '''
+  flatten = (flatten == 'flat')
+  num_hidden = 10
+  net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, flatten)
+  if flatten:
+    data_shapes = [data_shape, (data_shape[0], num_hidden)]
+  else:
+    data_shapes = [data_shape, (*data_shape[0:-1], num_hidden)]
+  attrs = []
+  excluded_attrs = ['with_sum']
+  check_neg_fusion(net, attrs, excluded_attrs, data_shapes, name='fc')
+
+@mx.util.use_np
+@pytest.mark.parametrize('add_op', ['elemwise_add'])
+@pytest.mark.parametrize('data_shape', [DATA_SHAPE[1]])
+@pytest.mark.parametrize('fc_out_add', ['lhs', 'rhs'])
+@pytest.mark.parametrize('scaled_fc_out', ['s_lhs', 's_rhs'])
+def test_neg_fc_add_quantized(data_shape, add_op, fc_out_add, scaled_fc_out):
+  '''
+  Check that a FullyConnected operator whose output is not used solely as one 'add_op' input
+  is not fused in the quantized model.
+  See NegFCAdd for the graph used in this test.
+  '''
+  num_hidden = 10
+  net = NegFCAdd(num_hidden, add_op, fc_out_add, scaled_fc_out, True)
+  data_shapes = [data_shape, (data_shape[0], num_hidden)]
+  attrs = []
+  excluded_attrs = ['with_sum']
+  check_neg_fusion_quantized(net, attrs, excluded_attrs, data_shapes, name='fc')
