This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 26f9fa6  Unifying oneDNN post-quantization properties (#20724)
26f9fa6 is described below

commit 26f9fa6cdfc08de600bf0202534e17e64e4b9134
Author: DominikaJedynak <[email protected]>
AuthorDate: Thu Nov 25 12:16:24 2021 +0100

    Unifying oneDNN post-quantization properties (#20724)
    
    * * Unifying post-quantization properties
    
    * Compatibility and review fixes
    
    * Review changes
    
    * Small fix
---
 .../dnnl/dnnl_elemwisemul_post_quantize_property.h | 231 ---------------------
 .../subgraph/dnnl/dnnl_fc_post_quantize_property.h | 230 --------------------
 .../dnnl/dnnl_matmul_post_quantize_property.h      | 202 ------------------
 .../subgraph/dnnl/dnnl_post_quantize_property.h    | 189 +++++++++++------
 .../subgraph/dnnl/dnnl_subgraph_property.cc        |   7 -
 5 files changed, 125 insertions(+), 734 deletions(-)

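The unified SgDNNLPostQuantizeProperty pass gates its two fusions behind the environment variables MXNET_ONEDNN_FUSE_REQUANTIZE and MXNET_ONEDNN_FUSE_DEQUANTIZE, both defaulting to true (see dnnl_post_quantize_property.h below); they replace the per-operator MXNET_DISABLE_ONEDNN_Q*_FUSE_ALL / *_FLOAT_OUTPUT switches removed in this commit. As a minimal, self-contained C++ sketch of that gating logic (read_bool_env() is only a stand-in for dmlc::GetEnv and is not MXNet API):

    #include <cstdlib>
    #include <cstring>
    #include <iostream>

    // Stand-in for dmlc::GetEnv<bool>(name, default): reads an environment
    // variable and falls back to the given default when it is unset.
    static bool read_bool_env(const char* name, bool default_value) {
      const char* value = std::getenv(name);
      if (value == nullptr)
        return default_value;
      return std::strcmp(value, "0") != 0 && std::strcmp(value, "false") != 0;
    }

    int main() {
      // Defaults mirror SgDNNLPostQuantizeProperty(): both fusions stay enabled
      // unless the user explicitly turns them off.
      const bool fuse_requantize = read_bool_env("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
      const bool fuse_dequantize = read_bool_env("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
      std::cout << "fuse requantize: " << fuse_requantize
                << ", fuse dequantize (float output): " << fuse_dequantize << '\n';
      return 0;
    }

Exporting either variable as 0 before inference skips the corresponding requantize or dequantize fusion for every operator listed in support_req_fusion_op.
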
diff --git a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h
deleted file mode 100644
index 5e015cb..0000000
--- a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_elemwisemul_post_quantize_property.cc
- * \brief Partition graph property for oneDNN Quantized ElemwiseMul operator
- * \author Xinyu Chen
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "../../quantization/requantize-inl.h"
-#include "../../tensor/elemwise_binary_op-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-#define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul"
-
-class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 {
- public:
-  /*! \brief pattern match status */
-  enum SelectStatus {
-    kFail = 0,
-    kStart,
-    kRequantize,
-    kSuccess,
-  };
-
- private:
-  bool disable_all;
-  bool disable_float_output;
-  SelectStatus status;
-  std::vector<const BiDirectedNode*> matched_list;
-
- public:
-  explicit ElemwiseMulPostQuantizeSelector(const bool dis_all, const bool dis_float_output)
-      : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
-  bool Select(const BiDirectedNode& n) override {
-    const auto rawnode = n.node;
-    if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
-      status = disable_all ? kSuccess : kStart;
-      matched_list.clear();
-      matched_list.push_back(&n);
-      return true;
-    }
-    return false;
-  }
-
-  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
-    return false;
-  }
-
-  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
-    const auto raw_node     = n.node;
-    const auto raw_new_node = new_node.node;
-    if (status == kFail || status == kSuccess || raw_new_node->is_variable())
-      return false;
-    // If n isn't the last matched node, then we encountered an internal
-    // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list.back() != &n) {
-      if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) {
-        while (matched_list.back() != &n) {
-          matched_list.pop_back();
-        }
-      }
-
-      status = kSuccess;
-      return false;
-    }
-
-    switch (status) {
-      case kStart:
-        if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
-          auto const& param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
-          if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-            matched_list.push_back(&new_node);
-            status = kRequantize;
-            return true;
-          }
-        }
-      case kRequantize:
-        if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
-          CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
-          if (n.outputs.size() > 1) {
-            // check if requantize has other outputs than dequantize
-            // if it has, we can't fuse dequantize into elemwise_mul
-            for (auto kv : n.outputs) {
-              const auto& node = kv.first;
-              if (node->op() != Op::Get("_contrib_dequantize")) {
-                status = kSuccess;
-                return false;
-              }
-            }
-          }
-
-          matched_list.push_back(&new_node);
-          status = kSuccess;
-          return true;
-        }
-      default:
-        status = kSuccess;
-        return false;
-    }
-  }
-
-  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
-    if ((status != kSuccess) || (matched_list.size() <= 1)) {
-      return std::vector<BiDirectedNode*>(0);
-    } else {
-      std::vector<BiDirectedNode*> ret;
-      for (auto i : matched_list) {
-        auto non_const_i = const_cast<BiDirectedNode*>(i);
-        if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
-          ret.push_back(non_const_i);
-        }
-      }
-      return ret;
-    }
-  }
-
-  void Reset() override {
-    CHECK_GE(matched_list.size(), 1);
-    auto new_selector = ElemwiseMulPostQuantizeSelector(disable_all, disable_float_output);
-    new_selector.Select(*matched_list[0]);
-    *this = new_selector;
-  }
-};
-
-class ElemwiseMulPostQuantizeProperty : public SubgraphProperty {
- public:
-  ElemwiseMulPostQuantizeProperty() {
-    disable_fuse_all     = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FUSE_ALL", false);
-    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FLOAT_OUTPUT", false);
-  }
-
-  static SubgraphPropertyPtr Create() {
-    static const std::string& name = "oneDNN EltwiseMul post-quantization optimization pass";
-    auto property                  = std::make_shared<ElemwiseMulPostQuantizeProperty>();
-    property->SetAttr<std::string>("property_name", name);
-    property->SetAttr<bool>("inference_only", true);
-    return property;
-  }
-
-  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
-                                     const int subgraph_id = 0) const override {
-    nnvm::ObjectPtr em_node         = nullptr;
-    nnvm::ObjectPtr requantize_node = nullptr;
-    nnvm::ObjectPtr dequantize_node = nullptr;
-
-    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
-      if (node->is_variable())
-        return;
-      if (node->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
-        em_node = node;
-      } else if (node->op() == Op::Get("_contrib_requantize")) {
-        requantize_node = node;
-      } else if (node->op() == Op::Get("_contrib_dequantize")) {
-        dequantize_node = node;
-      }
-    });
-
-    CHECK_NOTNULL(em_node);
-    CHECK_NOTNULL(requantize_node);
-    auto const& requantize_param = nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
-    CHECK(requantize_param.min_calib_range.has_value());
-    CHECK(requantize_param.max_calib_range.has_value());
-
-    // When only quantized_elemwise_mul and requantize are fused, set min/max_calib_range;
-    // when quantized_elemwise_mul + requantize + dequantize are fused, set the dequantize flag to true.
-    if (dequantize_node != nullptr) {
-      em_node->attrs.dict["enable_float_output"] = "True";
-    } else {
-      em_node->attrs.dict["min_calib_range"] =
-          std::to_string(requantize_param.min_calib_range.value());
-      em_node->attrs.dict["max_calib_range"] =
-          std::to_string(requantize_param.max_calib_range.value());
-    }
-    em_node->op()->attr_parser(&(em_node->attrs));
-    return em_node;
-  }
-
-  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
-    auto selector =
-        std::make_shared<ElemwiseMulPostQuantizeSelector>(disable_fuse_all, disable_float_output);
-    return selector;
-  }
-
-  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
-                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
-    for (size_t i = 0; i < output_entries->size(); ++i) {
-      auto entry_ptr = output_entries->at(i);
-      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
-    }
-  }
-
- private:
-  bool disable_fuse_all;
-  bool disable_float_output;
-};
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h
deleted file mode 100644
index b1ae537..0000000
--- a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_fc_post_quantize_property.cc
- * \brief Partition graph property for oneDNN Quantized FullyConnected operator
- * \author Ciyong Chen
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "../../nn/fully_connected-inl.h"
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-#define QUANTIZED_FC_NAME "_sg_onednn_fully_connected"
-
-class SgDNNLFCPostQuantizeSelector : public SubgraphSelectorV2 {
- public:
-  /*! \brief pattern match status */
-  enum SelectStatus {
-    kFail = 0,
-    kStart,
-    kRequantize,
-    kSuccess,
-  };
-
- private:
-  bool disable_all;
-  bool disable_float_output;
-  SelectStatus status;
-  std::vector<const BiDirectedNode*> matched_list;
-
- public:
-  explicit SgDNNLFCPostQuantizeSelector(const bool dis_all, const bool dis_float_output)
-      : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
-  bool Select(const BiDirectedNode& n) override {
-    const auto rawnode = n.node;
-    if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) {
-      status = disable_all ? kSuccess : kStart;
-      matched_list.clear();
-      matched_list.push_back(&n);
-      return true;
-    }
-    return false;
-  }
-
-  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
-    return false;
-  }
-
-  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
-    const auto raw_node     = n.node;
-    const auto raw_new_node = new_node.node;
-    if (status == kFail || status == kSuccess || raw_new_node->is_variable())
-      return false;
-    // If n isn't the last matched node, then we encountered an internal
-    // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list.back() != &n) {
-      if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) {
-        while (matched_list.back() != &n) {
-          matched_list.pop_back();
-        }
-      }
-
-      status = kSuccess;
-      return false;
-    }
-
-    switch (status) {
-      case kStart:
-        if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
-          auto const& param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
-          if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-            matched_list.push_back(&new_node);
-            status = kRequantize;
-            return true;
-          }
-        }
-      case kRequantize:
-        if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
-          CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
-          if (n.outputs.size() > 1) {
-            // check if requantize has other outputs than dequantize
-            // if it has, we can't fuse dequantize into FC
-            for (auto kv : n.outputs) {
-              const auto& node = kv.first;
-              if (node->op() != Op::Get("_contrib_dequantize")) {
-                status = kSuccess;
-                return false;
-              }
-            }
-          }
-          matched_list.push_back(&new_node);
-          status = kSuccess;
-          return true;
-        }
-      default:
-        status = kSuccess;
-        return false;
-    }
-  }
-
-  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
-    if ((status != kSuccess) || (matched_list.size() <= 1)) {
-      return std::vector<BiDirectedNode*>(0);
-    } else {
-      std::vector<BiDirectedNode*> ret;
-      for (auto i : matched_list) {
-        auto non_const_i = const_cast<BiDirectedNode*>(i);
-        if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
-          ret.push_back(non_const_i);
-        }
-      }
-      return ret;
-    }
-  }
-
-  void Reset() override {
-    CHECK_GE(matched_list.size(), 1);
-    auto new_selector = SgDNNLFCPostQuantizeSelector(disable_all, disable_float_output);
-    new_selector.Select(*matched_list[0]);
-    *this = new_selector;
-  }
-};
-
-class SgDNNLFCPostQuantizeProperty : public SubgraphProperty {
- public:
-  SgDNNLFCPostQuantizeProperty() {
-    disable_fuse_all     = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL", false);
-    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT", false);
-  }
-
-  static SubgraphPropertyPtr Create() {
-    static const std::string& name = "oneDNN FullyConnected post-quantization optimization pass";
-    auto property                  = std::make_shared<SgDNNLFCPostQuantizeProperty>();
-    property->SetAttr<std::string>("property_name", name);
-    property->SetAttr<bool>("inference_only", true);
-    return property;
-  }
-
-  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
-                                     const int subgraph_id = 0) const override {
-    nnvm::ObjectPtr fc_node         = nullptr;
-    nnvm::ObjectPtr requantize_node = nullptr;
-    nnvm::ObjectPtr dequantize_node = nullptr;
-
-    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
-      if (node->is_variable())
-        return;
-      if (node->op() == Op::Get(QUANTIZED_FC_NAME)) {
-        fc_node = node;
-      } else if (node->op() == Op::Get("_contrib_requantize")) {
-        requantize_node = node;
-      } else if (node->op() == Op::Get("_contrib_dequantize")) {
-        dequantize_node = node;
-      }
-    });
-
-    CHECK_NOTNULL(fc_node);
-    CHECK_NOTNULL(requantize_node);
-    auto const& requantize_param = nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
-    CHECK(requantize_param.min_calib_range.has_value());
-    CHECK(requantize_param.max_calib_range.has_value());
-
-    // When only quantized_fullyconnected and requantize are fused, set min/max_calib_range;
-    // when quantized_fullyconnected + requantize + dequantize are fused, set the dequantize flag to true.
-    if (dequantize_node != nullptr) {
-      fc_node->attrs.dict["enable_float_output"] = "True";
-    } else {
-      fc_node->attrs.dict["min_calib_range"] =
-          std::to_string(requantize_param.min_calib_range.value());
-      fc_node->attrs.dict["max_calib_range"] =
-          std::to_string(requantize_param.max_calib_range.value());
-    }
-    fc_node->op()->attr_parser(&(fc_node->attrs));
-    return fc_node;
-  }
-
-  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
-    auto selector =
-        std::make_shared<SgDNNLFCPostQuantizeSelector>(disable_fuse_all, disable_float_output);
-    return selector;
-  }
-
-  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
-                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
-    for (size_t i = 0; i < output_entries->size(); ++i) {
-      auto entry_ptr = output_entries->at(i);
-      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
-    }
-  }
-
- private:
-  bool disable_fuse_all;
-  bool disable_float_output;
-};
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
deleted file mode 100644
index 6c384a1..0000000
--- a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <string>
-#include <vector>
-
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector {
- public:
-  /*! \brief pattern match status */
-  enum SelectStatus {
-    kFail = 0,
-    kStart,
-    kRequantize,
-    kSuccess,
-  };
-
- private:
-  bool disable_all;
-  bool disable_float_output;
-  SelectStatus status;
-  std::vector<const nnvm::Node*> matched_list;
-
- public:
-  explicit SgDNNLMatmulPostQuantizeSelector(const bool dis_all, const bool dis_float_output)
-      : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
-  bool Select(const nnvm::Node& n) override {
-    if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") ||
-                           n.op() == Op::Get("_sg_onednn_selfatt_valatt") ||
-                           n.op() == Op::Get("_sg_onednn_batch_dot"))) {
-      status = disable_all ? kSuccess : kStart;
-      matched_list.clear();
-      matched_list.push_back(&n);
-      return true;
-    }
-    return false;
-  }
-
-  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
-    return false;
-  }
-
-  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
-    if (status == kFail || status == kSuccess || new_node.is_variable())
-      return false;
-    // If n isn't the last matched node, then we encountered an internal
-    // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list.back() != &n) {
-      if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) {
-        while (matched_list.back() != &n) {
-          matched_list.pop_back();
-        }
-      }
-
-      status = kSuccess;
-      return false;
-    }
-
-    switch (status) {
-      case kStart:
-        if (new_node.op() == Op::Get("_contrib_requantize")) {
-          auto const& param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
-          if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-            matched_list.push_back(&new_node);
-            status = kRequantize;
-            return true;
-          }
-        }
-      case kRequantize:
-        if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
-          matched_list.push_back(&new_node);
-          status = kSuccess;
-          return true;
-        }
-      default:
-        status = kSuccess;
-        return false;
-    }
-  }
-
-  std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
-    if ((status != kSuccess) || (matched_list.size() <= 1)) {
-      return std::vector<nnvm::Node*>(0);
-    } else {
-      std::vector<nnvm::Node*> ret;
-      for (auto i : matched_list) {
-        auto non_const_i = const_cast<nnvm::Node*>(i);
-        if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
-          ret.push_back(non_const_i);
-        }
-      }
-      return ret;
-    }
-  }
-
-  void Reset() override {
-    CHECK_GE(matched_list.size(), 1);
-    auto new_selector = SgDNNLMatmulPostQuantizeSelector(disable_all, disable_float_output);
-    new_selector.Select(*matched_list[0]);
-    *this = new_selector;
-  }
-};
-
-class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty {
- public:
-  SgDNNLMatmulPostQuantizeProperty() {
-    disable_fuse_all     = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FUSE_ALL", false);
-    disable_float_output = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FLOAT_OUTPUT", false);
-  }
-
-  static SubgraphPropertyPtr Create() {
-    static const std::string& name = "oneDNN Matmul post-quantization optimization pass";
-    auto property                  = std::make_shared<SgDNNLMatmulPostQuantizeProperty>();
-    property->SetAttr<std::string>("property_name", name);
-    property->SetAttr<bool>("inference_only", true);
-    return property;
-  }
-
-  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
-                                     const int subgraph_id = 0) const override {
-    nnvm::ObjectPtr interleaved_node = nullptr;
-    nnvm::ObjectPtr requantize_node  = nullptr;
-    nnvm::ObjectPtr dequantize_node  = nullptr;
-
-    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
-      if (node->is_variable())
-        return;
-      if (node->op() == Op::Get("_sg_onednn_selfatt_qk") ||
-          node->op() == Op::Get("_sg_onednn_selfatt_valatt") ||
-          node->op() == Op::Get("_sg_onednn_batch_dot")) {
-        interleaved_node = node;
-      } else if (node->op() == Op::Get("_contrib_requantize")) {
-        requantize_node = node;
-      } else if (node->op() == Op::Get("_contrib_dequantize")) {
-        dequantize_node = node;
-      }
-    });
-
-    CHECK_NOTNULL(interleaved_node);
-    CHECK_NOTNULL(requantize_node);
-    auto const& requantize_param = nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
-    CHECK(requantize_param.min_calib_range.has_value());
-    CHECK(requantize_param.max_calib_range.has_value());
-
-    // When only fusing quantized_interleaved_matmul and requantize, set min/max_calib_range,
-    // When fusing quantized_interleaved_matmul + requantize + dequantize,
-    // set dequantize flag to true.
-    if (dequantize_node != nullptr) {
-      interleaved_node->attrs.dict["enable_float_output"] = "True";
-    } else {
-      interleaved_node->attrs.dict["min_calib_range"] =
-          std::to_string(requantize_param.min_calib_range.value());
-      interleaved_node->attrs.dict["max_calib_range"] =
-          std::to_string(requantize_param.max_calib_range.value());
-    }
-    interleaved_node->op()->attr_parser(&(interleaved_node->attrs));
-    return interleaved_node;
-  }
-
-  SubgraphSelectorPtr CreateSubgraphSelector() const override {
-    auto selector =
-        std::make_shared<SgDNNLMatmulPostQuantizeSelector>(disable_fuse_all, disable_float_output);
-    return selector;
-  }
-
- private:
-  bool disable_fuse_all;
-  bool disable_float_output;
-};
-
-}  // namespace op
-}  // namespace mxnet
-
-#endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 662b792..cddf4b4 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -20,110 +20,161 @@
 #define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_QUANTIZE_PROPERTY_H_
 #if MXNET_USE_ONEDNN == 1
 
+#include <memory>
 #include <set>
 #include <string>
 #include <vector>
 
-#include "../../nn/dnnl/dnnl_convolution-inl.h"
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
+#include "operator/nn/dnnl/dnnl_convolution-inl.h"
+#include "operator/nn/fully_connected-inl.h"
+#include "operator/quantization/requantize-inl.h"
+#include "operator/tensor/elemwise_binary_op-inl.h"
+#include "operator/subgraph/common.h"
 #include "dnnl_conv-inl.h"
 #include "dnnl_subgraph_base-inl.h"
 
 namespace mxnet {
 namespace op {
-
-class SgDNNLPostQuantizeSelector : public SubgraphSelector {
- public:
+namespace {
+const std::set<std::string> support_req_fusion_op = {"_contrib_quantized_elemwise_add",
+                                                     "_contrib_quantized_elemwise_mul",
+                                                     "_contrib_quantized_npi_add",
+                                                     "_sg_onednn_conv",
+                                                     "_sg_onednn_fully_connected",
+                                                     "_sg_onednn_selfatt_qk",
+                                                     "_sg_onednn_selfatt_valatt",
+                                                     "_sg_onednn_batch_dot"};
+}  // namespace
+
+class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
+ private:
   /*! \brief pattern match status */
-  enum SelectStatus {
+  enum class SelectStatus {
     kFail = 0,
     kStart,
+    kRequantize,
     kSuccess,
   };
 
- private:
+  bool fuse_all;
+  bool float_output;
   SelectStatus status;
-  std::vector<const nnvm::Node*> matched_list;
+  std::vector<const BiDirectedNode*> matched_list;
   std::set<std::string> support_requantize_fusion_op_name;
 
  public:
-  SgDNNLPostQuantizeSelector() {
-    support_requantize_fusion_op_name.insert("_sg_onednn_conv");
-    support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add");
-    support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add");
+  explicit SgDNNLPostQuantizeSelector(const bool fuse_all, const bool float_output)
+      : fuse_all(fuse_all), float_output(float_output) {
+    support_requantize_fusion_op_name = support_req_fusion_op;
   }
 
-  bool Select(const nnvm::Node& n) override {
-    if (n.op() && support_requantize_fusion_op_name.count(n.op()->name)) {
-      if (n.op() == Op::Get("_sg_onednn_conv")) {
-        auto const& param = nnvm::get<DNNLConvFusionParam>(n.attrs.parsed);
-        if (param.full_conv_param.dnnl_param.quantized) {
-          status = kStart;
-          matched_list.clear();
-          matched_list.push_back(&n);
-          return true;
-        }
-      } else if (n.op()->name == "_contrib_quantized_elemwise_add" ||
-                 n.op()->name == "_contrib_quantized_npi_add") {
-        status = kStart;
-        matched_list.clear();
-        matched_list.push_back(&n);
-        return true;
-      }
+  bool Select(const BiDirectedNode& n) override {
+    const nnvm::Node* raw_node = n.node;
+    if (fuse_all && raw_node->op() &&
+        support_requantize_fusion_op_name.count(raw_node->op()->name)) {
+      status = SelectStatus::kStart;
+      matched_list.clear();
+      matched_list.emplace_back(&n);
+      return true;
     }
     return false;
   }
 
-  bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
     return false;
   }
 
-  bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
-    if (status == kFail || status == kSuccess || new_node.is_variable())
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node) override {
+    const nnvm::Node* raw_node     = n.node;
+    const nnvm::Node* raw_new_node = new_node.node;
+    if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
+        raw_new_node->is_variable())
       return false;
     // If n isn't the last matched node, then we encountered an internal
     // branch, we should pop out the node behind n and stop fusion.
     if (matched_list.back() != &n) {
-      status = kFail;
+      if (std::find(matched_list.begin(), matched_list.end(), &n) != matched_list.end()) {
+        while (matched_list.back() != &n) {
+          matched_list.pop_back();
+        }
+      }
+      status = SelectStatus::kSuccess;
       return false;
     }
-    if (new_node.op()->name == "_contrib_requantize") {
-      auto const& param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
-      if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
-        matched_list.push_back(&new_node);
-        status = kSuccess;
-        return true;
-      } else {
-        status = kFail;
-      }
+
+    switch (status) {
+      case SelectStatus::kStart:
+        if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
+          auto const& param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
+          if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
+            matched_list.emplace_back(&new_node);
+            status = SelectStatus::kRequantize;
+            if (raw_node->op() == Op::Get("_sg_onednn_conv")) {
+              status = SelectStatus::kSuccess;
+            }
+            return true;
+          }
+        }
+      case SelectStatus::kRequantize:
+        if (float_output && raw_new_node->op() == Op::Get("_contrib_dequantize")) {
+          CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
+          if (n.outputs.size() > 1) {
+            // check if requantize has other outputs than dequantize
+            // if it has, we can't fuse dequantize
+            for (const auto& kv : n.outputs) {
+              const auto& node = kv.first;
+              if (node->op() != Op::Get("_contrib_dequantize")) {
+                status = SelectStatus::kSuccess;
+                return false;
+              }
+            }
+          }
+          matched_list.emplace_back(&new_node);
+          status = SelectStatus::kSuccess;
+          return true;
+        }
+      default:
+        status = SelectStatus::kSuccess;
+        return false;
     }
-    return false;
   }
 
-  std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates) override {
-    if (status != kSuccess) {
-      return std::vector<nnvm::Node*>(0);
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
+    if (status != SelectStatus::kSuccess || (matched_list.size() <= 1)) {
+      return std::vector<BiDirectedNode*>(0);
     } else {
-      return candidates;
+      std::vector<BiDirectedNode*> ret;
+      for (auto i : matched_list) {
+        auto non_const_i = const_cast<BiDirectedNode*>(i);
+        if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) {
+          ret.push_back(non_const_i);
+        }
+      }
+      return ret;
     }
   }
 
   void Reset() override {
     CHECK_GE(matched_list.size(), 1);
-    auto new_selector = SgDNNLPostQuantizeSelector();
+    auto new_selector = SgDNNLPostQuantizeSelector(fuse_all, float_output);
     new_selector.Select(*matched_list[0]);
     *this = new_selector;
   }
 };
 
 class SgDNNLPostQuantizeProperty : public SubgraphProperty {
+ private:
+  bool fuse_all;
+  bool float_output;
+  std::set<std::string> support_requantize_fusion_op_name;
+
  public:
   SgDNNLPostQuantizeProperty() {
-    support_requantize_fusion_op_name.insert("_sg_onednn_conv");
-    support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add");
-    support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add");
+    fuse_all                          = dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
+    float_output                      = dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
+    support_requantize_fusion_op_name = support_req_fusion_op;
   }
+
   static SubgraphPropertyPtr Create() {
     static const std::string& name = "oneDNN post-quantization optimization pass";
     auto property                  = std::make_shared<SgDNNLPostQuantizeProperty>();
@@ -131,35 +182,47 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
     property->SetAttr<bool>("inference_only", true);
     return property;
   }
+
   nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
                                      const int subgraph_id = 0) const override {
     nnvm::ObjectPtr fuse_node       = nullptr;
     nnvm::ObjectPtr requantize_node = nullptr;
+    nnvm::ObjectPtr dequantize_node = nullptr;
+
     DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
       if (node->is_variable())
         return;
-      auto& op_name = node->op()->name;
-      if (support_requantize_fusion_op_name.count(op_name)) {
+      if (node->op() && support_requantize_fusion_op_name.count(node->op()->name)) {
         fuse_node = node;
-      } else if (op_name == "_contrib_requantize") {
+      } else if (node->op() == Op::Get("_contrib_requantize")) {
         requantize_node = node;
+      } else if (node->op() == Op::Get("_contrib_dequantize")) {
+        dequantize_node = node;
       }
     });
+
     CHECK_NOTNULL(fuse_node);
     CHECK_NOTNULL(requantize_node);
     auto const& requantize_param = nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
     CHECK(requantize_param.min_calib_range.has_value());
     CHECK(requantize_param.max_calib_range.has_value());
-    fuse_node->attrs.dict["min_calib_range"] =
-        std::to_string(requantize_param.min_calib_range.value());
-    fuse_node->attrs.dict["max_calib_range"] =
-        std::to_string(requantize_param.max_calib_range.value());
+
+    // When only the quantized operator and requantize are fused, set min/max_calib_range;
+    // when the quantized operator + requantize + dequantize are fused, set the dequantize flag to true.
+    if (dequantize_node != nullptr) {
+      fuse_node->attrs.dict["enable_float_output"] = "True";
+    } else {
+      fuse_node->attrs.dict["min_calib_range"] =
+          std::to_string(requantize_param.min_calib_range.value());
+      fuse_node->attrs.dict["max_calib_range"] =
+          std::to_string(requantize_param.max_calib_range.value());
+    }
     fuse_node->op()->attr_parser(&(fuse_node->attrs));
     return fuse_node;
   }
 
-  SubgraphSelectorPtr CreateSubgraphSelector() const override {
-    auto selector = std::make_shared<SgDNNLPostQuantizeSelector>();
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+    auto selector = std::make_shared<SgDNNLPostQuantizeSelector>(fuse_all, float_output);
     return selector;
   }
 
@@ -170,10 +233,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty {
       *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
     }
   }
-
- private:
-  std::set<std::string> support_requantize_fusion_op_name;
 };
+
 }  // namespace op
 }  // namespace mxnet
 
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 4a5f6a6..9727187 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -22,10 +22,7 @@
 #include "dnnl_batch_dot_property.h"
 #include "dnnl_bn_relu_property.h"
 #include "dnnl_conv_property.h"
-#include "dnnl_elemwisemul_post_quantize_property.h"
-#include "dnnl_fc_post_quantize_property.h"
 #include "dnnl_fc_property.h"
-#include "dnnl_matmul_post_quantize_property.h"
 #include "dnnl_post_quantize_align_scale_property.h"
 #include "dnnl_post_quantize_property.h"
 #include "dnnl_transformer_qk_property.h"
@@ -54,11 +51,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLTransformerValAttPropert
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
     .set_attr("quantize", true);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLFCPostQuantizeProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeAlignScaleProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLMatmulPostQuantizeProperty)
-    .set_attr("quantize", true);
 
 }  // namespace op
 }  // namespace mxnet
