This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 26f9fa6 Unifying oneDNN post-quantization properties (#20724)
26f9fa6 is described below
commit 26f9fa6cdfc08de600bf0202534e17e64e4b9134
Author: DominikaJedynak <[email protected]>
AuthorDate: Thu Nov 25 12:16:24 2021 +0100
Unifying oneDNN post-quantization properties (#20724)
* Unifying post-quantization properties
* Compatibility and review fixes
* Review changes
* Small fix
---
.../dnnl/dnnl_elemwisemul_post_quantize_property.h | 231 ---------------------
.../subgraph/dnnl/dnnl_fc_post_quantize_property.h | 230 --------------------
.../dnnl/dnnl_matmul_post_quantize_property.h | 202 ------------------
.../subgraph/dnnl/dnnl_post_quantize_property.h | 189 +++++++++++------
.../subgraph/dnnl/dnnl_subgraph_property.cc | 7 -
5 files changed, 125 insertions(+), 734 deletions(-)
diff --git
a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h
b/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h
deleted file mode 100644
index 5e015cb..0000000
--- a/src/operator/subgraph/dnnl/dnnl_elemwisemul_post_quantize_property.h
+++ /dev/null
@@ -1,231 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_elemwisemul_post_quantize_property.cc
- * \brief Partition gragph property for oneDNN Quantized ElemwiseMul operator
- * \author Xinyu Chen
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "../../quantization/requantize-inl.h"
-#include "../../tensor/elemwise_binary_op-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-#define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul"
-
-class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 {
- public:
- /*! \brief pattern match status */
- enum SelectStatus {
- kFail = 0,
- kStart,
- kRequantize,
- kSuccess,
- };
-
- private:
- bool disable_all;
- bool disable_float_output;
- SelectStatus status;
- std::vector<const BiDirectedNode*> matched_list;
-
- public:
- explicit ElemwiseMulPostQuantizeSelector(const bool dis_all, const bool
dis_float_output)
- : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
- bool Select(const BiDirectedNode& n) override {
- const auto rawnode = n.node;
- if ((!disable_all) && rawnode->op() ==
Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
- status = disable_all ? kSuccess : kStart;
- matched_list.clear();
- matched_list.push_back(&n);
- return true;
- }
- return false;
- }
-
- bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
- return false;
- }
-
- bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
- const auto raw_node = n.node;
- const auto raw_new_node = new_node.node;
- if (status == kFail || status == kSuccess || raw_new_node->is_variable())
- return false;
- // If n isn't the last matched node, then we encoutered a internal
- // branch, we should pop out the node behind n and stop fusion.
- if (matched_list.back() != &n) {
- if (std::find(matched_list.begin(), matched_list.end(), &n) !=
matched_list.end()) {
- while (matched_list.back() != &n) {
- matched_list.pop_back();
- }
- }
-
- status = kSuccess;
- return false;
- }
-
- switch (status) {
- case kStart:
- if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
- auto const& param =
nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
- if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
- matched_list.push_back(&new_node);
- status = kRequantize;
- return true;
- }
- }
- case kRequantize:
- if ((!disable_float_output) && (raw_new_node->op() ==
Op::Get("_contrib_dequantize"))) {
- CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
- if (n.outputs.size() > 1) {
- // check if requantize have other outputs than dequantize
- // if it has we can't fuse dequantize into elemwise_mul
- for (auto kv : n.outputs) {
- const auto& node = kv.first;
- if (node->op() != Op::Get("_contrib_dequantize")) {
- status = kSuccess;
- return false;
- }
- }
- }
-
- matched_list.push_back(&new_node);
- status = kSuccess;
- return true;
- }
- default:
- status = kSuccess;
- return false;
- }
- }
-
- std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>&
candidates) override {
- if ((status != kSuccess) || (matched_list.size() <= 1)) {
- return std::vector<BiDirectedNode*>(0);
- } else {
- std::vector<BiDirectedNode*> ret;
- for (auto i : matched_list) {
- auto non_const_i = const_cast<BiDirectedNode*>(i);
- if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
- ret.push_back(non_const_i);
- }
- }
- return ret;
- }
- }
-
- void Reset() override {
- CHECK_GE(matched_list.size(), 1);
- auto new_selector = ElemwiseMulPostQuantizeSelector(disable_all,
disable_float_output);
- new_selector.Select(*matched_list[0]);
- *this = new_selector;
- }
-};
-
-class ElemwiseMulPostQuantizeProperty : public SubgraphProperty {
- public:
- ElemwiseMulPostQuantizeProperty() {
- disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FUSE_ALL",
false);
- disable_float_output =
dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QEM_FLOAT_OUTPUT", false);
- }
-
- static SubgraphPropertyPtr Create() {
- static const std::string& name = "oneDNN EltwiseMul post-quantization
optimization pass";
- auto property =
std::make_shared<ElemwiseMulPostQuantizeProperty>();
- property->SetAttr<std::string>("property_name", name);
- property->SetAttr<bool>("inference_only", true);
- return property;
- }
-
- nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
- const int subgraph_id = 0) const override
{
- nnvm::ObjectPtr em_node = nullptr;
- nnvm::ObjectPtr requantize_node = nullptr;
- nnvm::ObjectPtr dequantize_node = nullptr;
-
- DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
- if (node->is_variable())
- return;
- if (node->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
- em_node = node;
- } else if (node->op() == Op::Get("_contrib_requantize")) {
- requantize_node = node;
- } else if (node->op() == Op::Get("_contrib_dequantize")) {
- dequantize_node = node;
- }
- });
-
- CHECK_NOTNULL(em_node);
- CHECK_NOTNULL(requantize_node);
- auto const& requantize_param =
nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
- CHECK(requantize_param.min_calib_range.has_value());
- CHECK(requantize_param.max_calib_range.has_value());
-
- // When only fused quantized_elemwise_mul and requantize, set
min/max_cablib_range,
- // When fused quantized_elemwise_mul + requantize + dequantize, set
dequantize flag to true.
- if (dequantize_node != nullptr) {
- em_node->attrs.dict["enable_float_output"] = "True";
- } else {
- em_node->attrs.dict["min_calib_range"] =
- std::to_string(requantize_param.min_calib_range.value());
- em_node->attrs.dict["max_calib_range"] =
- std::to_string(requantize_param.max_calib_range.value());
- }
- em_node->op()->attr_parser(&(em_node->attrs));
- return em_node;
- }
-
- SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
- auto selector =
- std::make_shared<ElemwiseMulPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
- return selector;
- }
-
- void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
- std::vector<nnvm::NodeEntry*>* output_entries)
const override {
- for (size_t i = 0; i < output_entries->size(); ++i) {
- auto entry_ptr = output_entries->at(i);
- *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
- }
- }
-
- private:
- bool disable_fuse_all;
- bool disable_float_output;
-};
-
-} // namespace op
-} // namespace mxnet
-
-#endif // if MXNET_USE_ONEDNN == 1
-#endif //
MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h
b/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h
deleted file mode 100644
index b1ae537..0000000
--- a/src/operator/subgraph/dnnl/dnnl_fc_post_quantize_property.h
+++ /dev/null
@@ -1,230 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file dnnl_fc_post_quantize_property.cc
- * \brief Partition gragph property for oneDNN Quantized FullyConnected
operator
- * \author Ciyong Chen
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "../../nn/fully_connected-inl.h"
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-#define QUANTIZED_FC_NAME "_sg_onednn_fully_connected"
-
-class SgDNNLFCPostQuantizeSelector : public SubgraphSelectorV2 {
- public:
- /*! \brief pattern match status */
- enum SelectStatus {
- kFail = 0,
- kStart,
- kRequantize,
- kSuccess,
- };
-
- private:
- bool disable_all;
- bool disable_float_output;
- SelectStatus status;
- std::vector<const BiDirectedNode*> matched_list;
-
- public:
- explicit SgDNNLFCPostQuantizeSelector(const bool dis_all, const bool
dis_float_output)
- : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
- bool Select(const BiDirectedNode& n) override {
- const auto rawnode = n.node;
- if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) {
- status = disable_all ? kSuccess : kStart;
- matched_list.clear();
- matched_list.push_back(&n);
- return true;
- }
- return false;
- }
-
- bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
- return false;
- }
-
- bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
- const auto raw_node = n.node;
- const auto raw_new_node = new_node.node;
- if (status == kFail || status == kSuccess || raw_new_node->is_variable())
- return false;
- // If n isn't the last matched node, then we encoutered a internal
- // branch, we should pop out the node behind n and stop fusion.
- if (matched_list.back() != &n) {
- if (std::find(matched_list.begin(), matched_list.end(), &n) !=
matched_list.end()) {
- while (matched_list.back() != &n) {
- matched_list.pop_back();
- }
- }
-
- status = kSuccess;
- return false;
- }
-
- switch (status) {
- case kStart:
- if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
- auto const& param =
nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
- if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
- matched_list.push_back(&new_node);
- status = kRequantize;
- return true;
- }
- }
- case kRequantize:
- if ((!disable_float_output) && (raw_new_node->op() ==
Op::Get("_contrib_dequantize"))) {
- CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
- if (n.outputs.size() > 1) {
- // check if requantize have other outputs than dequantize
- // if it has we can't fuse dequantize into FC
- for (auto kv : n.outputs) {
- const auto& node = kv.first;
- if (node->op() != Op::Get("_contrib_dequantize")) {
- status = kSuccess;
- return false;
- }
- }
- }
- matched_list.push_back(&new_node);
- status = kSuccess;
- return true;
- }
- default:
- status = kSuccess;
- return false;
- }
- }
-
- std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>&
candidates) override {
- if ((status != kSuccess) || (matched_list.size() <= 1)) {
- return std::vector<BiDirectedNode*>(0);
- } else {
- std::vector<BiDirectedNode*> ret;
- for (auto i : matched_list) {
- auto non_const_i = const_cast<BiDirectedNode*>(i);
- if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
- ret.push_back(non_const_i);
- }
- }
- return ret;
- }
- }
-
- void Reset() override {
- CHECK_GE(matched_list.size(), 1);
- auto new_selector = SgDNNLFCPostQuantizeSelector(disable_all,
disable_float_output);
- new_selector.Select(*matched_list[0]);
- *this = new_selector;
- }
-};
-
-class SgDNNLFCPostQuantizeProperty : public SubgraphProperty {
- public:
- SgDNNLFCPostQuantizeProperty() {
- disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FUSE_ALL",
false);
- disable_float_output =
dmlc::GetEnv("MXNET_DISABLE_ONEDNN_QFC_FLOAT_OUTPUT", false);
- }
-
- static SubgraphPropertyPtr Create() {
- static const std::string& name = "oneDNN FullyConected post-quantization
optimization pass";
- auto property =
std::make_shared<SgDNNLFCPostQuantizeProperty>();
- property->SetAttr<std::string>("property_name", name);
- property->SetAttr<bool>("inference_only", true);
- return property;
- }
-
- nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
- const int subgraph_id = 0) const override
{
- nnvm::ObjectPtr fc_node = nullptr;
- nnvm::ObjectPtr requantize_node = nullptr;
- nnvm::ObjectPtr dequantize_node = nullptr;
-
- DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
- if (node->is_variable())
- return;
- if (node->op() == Op::Get(QUANTIZED_FC_NAME)) {
- fc_node = node;
- } else if (node->op() == Op::Get("_contrib_requantize")) {
- requantize_node = node;
- } else if (node->op() == Op::Get("_contrib_dequantize")) {
- dequantize_node = node;
- }
- });
-
- CHECK_NOTNULL(fc_node);
- CHECK_NOTNULL(requantize_node);
- auto const& requantize_param =
nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
- CHECK(requantize_param.min_calib_range.has_value());
- CHECK(requantize_param.max_calib_range.has_value());
-
- // When only fused quantized_fullyconnected and requantize, set
min/max_cablib_range,
- // When fused quantized_fullyconnected + requantize + dequantize, set
dequantize flag to true.
- if (dequantize_node != nullptr) {
- fc_node->attrs.dict["enable_float_output"] = "True";
- } else {
- fc_node->attrs.dict["min_calib_range"] =
- std::to_string(requantize_param.min_calib_range.value());
- fc_node->attrs.dict["max_calib_range"] =
- std::to_string(requantize_param.max_calib_range.value());
- }
- fc_node->op()->attr_parser(&(fc_node->attrs));
- return fc_node;
- }
-
- SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
- auto selector =
- std::make_shared<SgDNNLFCPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
- return selector;
- }
-
- void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
- std::vector<nnvm::NodeEntry*>* output_entries)
const override {
- for (size_t i = 0; i < output_entries->size(); ++i) {
- auto entry_ptr = output_entries->at(i);
- *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
- }
- }
-
- private:
- bool disable_fuse_all;
- bool disable_float_output;
-};
-
-} // namespace op
-} // namespace mxnet
-
-#endif // if MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
b/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
deleted file mode 100644
index 6c384a1..0000000
--- a/src/operator/subgraph/dnnl/dnnl_matmul_post_quantize_property.h
+++ /dev/null
@@ -1,202 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
-#if MXNET_USE_ONEDNN == 1
-
-#include <string>
-#include <vector>
-
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
-#include "dnnl_subgraph_base-inl.h"
-
-namespace mxnet {
-namespace op {
-
-class SgDNNLMatmulPostQuantizeSelector : public SubgraphSelector {
- public:
- /*! \brief pattern match status */
- enum SelectStatus {
- kFail = 0,
- kStart,
- kRequantize,
- kSuccess,
- };
-
- private:
- bool disable_all;
- bool disable_float_output;
- SelectStatus status;
- std::vector<const nnvm::Node*> matched_list;
-
- public:
- explicit SgDNNLMatmulPostQuantizeSelector(const bool dis_all, const bool
dis_float_output)
- : disable_all(dis_all), disable_float_output(dis_float_output) {}
-
- bool Select(const nnvm::Node& n) override {
- if ((!disable_all) && (n.op() == Op::Get("_sg_onednn_selfatt_qk") ||
- n.op() == Op::Get("_sg_onednn_selfatt_valatt") ||
- n.op() == Op::Get("_sg_onednn_batch_dot"))) {
- status = disable_all ? kSuccess : kStart;
- matched_list.clear();
- matched_list.push_back(&n);
- return true;
- }
- return false;
- }
-
- bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
- return false;
- }
-
- bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
- if (status == kFail || status == kSuccess || new_node.is_variable())
- return false;
- // If n isn't the last matched node, then we encoutered a internal
- // branch, we should pop out the node behind n and stop fusion.
- if (matched_list.back() != &n) {
- if (std::find(matched_list.begin(), matched_list.end(), &n) !=
matched_list.end()) {
- while (matched_list.back() != &n) {
- matched_list.pop_back();
- }
- }
-
- status = kSuccess;
- return false;
- }
-
- switch (status) {
- case kStart:
- if (new_node.op() == Op::Get("_contrib_requantize")) {
- auto const& param =
nnvm::get<RequantizeParam>(new_node.attrs.parsed);
- if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
- matched_list.push_back(&new_node);
- status = kRequantize;
- return true;
- }
- }
- case kRequantize:
- if ((!disable_float_output) && (new_node.op() ==
Op::Get("_contrib_dequantize"))) {
- matched_list.push_back(&new_node);
- status = kSuccess;
- return true;
- }
- default:
- status = kSuccess;
- return false;
- }
- }
-
- std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates)
override {
- if ((status != kSuccess) || (matched_list.size() <= 1)) {
- return std::vector<nnvm::Node*>(0);
- } else {
- std::vector<nnvm::Node*> ret;
- for (auto i : matched_list) {
- auto non_const_i = const_cast<nnvm::Node*>(i);
- if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
- ret.push_back(non_const_i);
- }
- }
- return ret;
- }
- }
-
- void Reset() override {
- CHECK_GE(matched_list.size(), 1);
- auto new_selector = SgDNNLMatmulPostQuantizeSelector(disable_all,
disable_float_output);
- new_selector.Select(*matched_list[0]);
- *this = new_selector;
- }
-};
-
-class SgDNNLMatmulPostQuantizeProperty : public SubgraphProperty {
- public:
- SgDNNLMatmulPostQuantizeProperty() {
- disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FUSE_ALL",
false);
- disable_float_output =
dmlc::GetEnv("MXNET_DISABLE_DNNL_QMATMUL_FLOAT_OUTPUT", false);
- }
-
- static SubgraphPropertyPtr Create() {
- static const std::string& name = "oneDNN Matmul post-quantization
optimization pass";
- auto property =
std::make_shared<SgDNNLMatmulPostQuantizeProperty>();
- property->SetAttr<std::string>("property_name", name);
- property->SetAttr<bool>("inference_only", true);
- return property;
- }
-
- nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
- const int subgraph_id = 0) const override
{
- nnvm::ObjectPtr interleaved_node = nullptr;
- nnvm::ObjectPtr requantize_node = nullptr;
- nnvm::ObjectPtr dequantize_node = nullptr;
-
- DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
- if (node->is_variable())
- return;
- if (node->op() == Op::Get("_sg_onednn_selfatt_qk") ||
- node->op() == Op::Get("_sg_onednn_selfatt_valatt") ||
- node->op() == Op::Get("_sg_onednn_batch_dot")) {
- interleaved_node = node;
- } else if (node->op() == Op::Get("_contrib_requantize")) {
- requantize_node = node;
- } else if (node->op() == Op::Get("_contrib_dequantize")) {
- dequantize_node = node;
- }
- });
-
- CHECK_NOTNULL(interleaved_node);
- CHECK_NOTNULL(requantize_node);
- auto const& requantize_param =
nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
- CHECK(requantize_param.min_calib_range.has_value());
- CHECK(requantize_param.max_calib_range.has_value());
-
- // When only fusing quantized_interleaved_matmul and requantize, set
min/max_cablib_range,
- // When fusing quantized_interleaved_matmul + requantize + dequantize,
- // set dequantize flag to true.
- if (dequantize_node != nullptr) {
- interleaved_node->attrs.dict["enable_float_output"] = "True";
- } else {
- interleaved_node->attrs.dict["min_calib_range"] =
- std::to_string(requantize_param.min_calib_range.value());
- interleaved_node->attrs.dict["max_calib_range"] =
- std::to_string(requantize_param.max_calib_range.value());
- }
- interleaved_node->op()->attr_parser(&(interleaved_node->attrs));
- return interleaved_node;
- }
-
- SubgraphSelectorPtr CreateSubgraphSelector() const override {
- auto selector =
- std::make_shared<SgDNNLMatmulPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
- return selector;
- }
-
- private:
- bool disable_fuse_all;
- bool disable_float_output;
-};
-
-} // namespace op
-} // namespace mxnet
-
-#endif // if MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_MATMUL_POST_QUANTIZE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
index 662b792..cddf4b4 100644
--- a/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
+++ b/src/operator/subgraph/dnnl/dnnl_post_quantize_property.h
@@ -20,110 +20,161 @@
#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1
+#include <memory>
#include <set>
#include <string>
#include <vector>
-#include "../../nn/dnnl/dnnl_convolution-inl.h"
-#include "../../quantization/requantize-inl.h"
-#include "../common.h"
+#include "operator/nn/dnnl/dnnl_convolution-inl.h"
+#include "operator/nn/fully_connected-inl.h"
+#include "operator/quantization/requantize-inl.h"
+#include "operator/tensor/elemwise_binary_op-inl.h"
+#include "operator/subgraph/common.h"
#include "dnnl_conv-inl.h"
#include "dnnl_subgraph_base-inl.h"
namespace mxnet {
namespace op {
-
-class SgDNNLPostQuantizeSelector : public SubgraphSelector {
- public:
+namespace {
+const std::set<std::string> support_req_fusion_op =
{"_contrib_quantized_elemwise_add",
+
"_contrib_quantized_elemwise_mul",
+
"_contrib_quantized_npi_add",
+ "_sg_onednn_conv",
+
"_sg_onednn_fully_connected",
+ "_sg_onednn_selfatt_qk",
+
"_sg_onednn_selfatt_valatt",
+ "_sg_onednn_batch_dot"};
+} // namespace
+
+class SgDNNLPostQuantizeSelector : public SubgraphSelectorV2 {
+ private:
/*! \brief pattern match status */
- enum SelectStatus {
+ enum class SelectStatus {
kFail = 0,
kStart,
+ kRequantize,
kSuccess,
};
- private:
+ bool fuse_all;
+ bool float_output;
SelectStatus status;
- std::vector<const nnvm::Node*> matched_list;
+ std::vector<const BiDirectedNode*> matched_list;
std::set<std::string> support_requantize_fusion_op_name;
public:
- SgDNNLPostQuantizeSelector() {
- support_requantize_fusion_op_name.insert("_sg_onednn_conv");
-
support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add");
- support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add");
+ explicit SgDNNLPostQuantizeSelector(const bool fuse_all, const bool
float_output)
+ : fuse_all(fuse_all), float_output(float_output) {
+ support_requantize_fusion_op_name = support_req_fusion_op;
}
- bool Select(const nnvm::Node& n) override {
- if (n.op() && support_requantize_fusion_op_name.count(n.op()->name)) {
- if (n.op() == Op::Get("_sg_onednn_conv")) {
- auto const& param = nnvm::get<DNNLConvFusionParam>(n.attrs.parsed);
- if (param.full_conv_param.dnnl_param.quantized) {
- status = kStart;
- matched_list.clear();
- matched_list.push_back(&n);
- return true;
- }
- } else if (n.op()->name == "_contrib_quantized_elemwise_add" ||
- n.op()->name == "_contrib_quantized_npi_add") {
- status = kStart;
- matched_list.clear();
- matched_list.push_back(&n);
- return true;
- }
+ bool Select(const BiDirectedNode& n) override {
+ const nnvm::Node* raw_node = n.node;
+ if (fuse_all && raw_node->op() &&
+ support_requantize_fusion_op_name.count(raw_node->op()->name)) {
+ status = SelectStatus::kStart;
+ matched_list.clear();
+ matched_list.emplace_back(&n);
+ return true;
}
return false;
}
- bool SelectInput(const nnvm::Node& n, const nnvm::Node& new_node) override {
+ bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
return false;
}
- bool SelectOutput(const nnvm::Node& n, const nnvm::Node& new_node) override {
- if (status == kFail || status == kSuccess || new_node.is_variable())
+ bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& new_node)
override {
+ const nnvm::Node* raw_node = n.node;
+ const nnvm::Node* raw_new_node = new_node.node;
+ if (status == SelectStatus::kFail || status == SelectStatus::kSuccess ||
+ raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encoutered a internal
// branch, we should pop out the node behind n and stop fusion.
if (matched_list.back() != &n) {
- status = kFail;
+ if (std::find(matched_list.begin(), matched_list.end(), &n) !=
matched_list.end()) {
+ while (matched_list.back() != &n) {
+ matched_list.pop_back();
+ }
+ }
+ status = SelectStatus::kSuccess;
return false;
}
- if (new_node.op()->name == "_contrib_requantize") {
- auto const& param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
- if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
- matched_list.push_back(&new_node);
- status = kSuccess;
- return true;
- } else {
- status = kFail;
- }
+
+ switch (status) {
+ case SelectStatus::kStart:
+ if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
+ auto const& param =
nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
+ if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
+ matched_list.emplace_back(&new_node);
+ status = SelectStatus::kRequantize;
+ if (raw_node->op() == Op::Get("_sg_onednn_conv")) {
+ status = SelectStatus::kSuccess;
+ }
+ return true;
+ }
+ }
+ case SelectStatus::kRequantize:
+ if (float_output && raw_new_node->op() ==
Op::Get("_contrib_dequantize")) {
+ CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
+ if (n.outputs.size() > 1) {
+ // check if requantize have other outputs than dequantize
+ // if it has we can't fuse dequantize
+ for (const auto& kv : n.outputs) {
+ const auto& node = kv.first;
+ if (node->op() != Op::Get("_contrib_dequantize")) {
+ status = SelectStatus::kSuccess;
+ return false;
+ }
+ }
+ }
+ matched_list.emplace_back(&new_node);
+ status = SelectStatus::kSuccess;
+ return true;
+ }
+ default:
+ status = SelectStatus::kSuccess;
+ return false;
}
- return false;
}
- std::vector<nnvm::Node*> Filter(const std::vector<nnvm::Node*>& candidates)
override {
- if (status != kSuccess) {
- return std::vector<nnvm::Node*>(0);
+ std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>&
candidates) override {
+ if (status != SelectStatus::kSuccess || (matched_list.size() <= 1)) {
+ return std::vector<BiDirectedNode*>(0);
} else {
- return candidates;
+ std::vector<BiDirectedNode*> ret;
+ for (auto i : matched_list) {
+ auto non_const_i = const_cast<BiDirectedNode*>(i);
+ if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
+ ret.push_back(non_const_i);
+ }
+ }
+ return ret;
}
}
void Reset() override {
CHECK_GE(matched_list.size(), 1);
- auto new_selector = SgDNNLPostQuantizeSelector();
+ auto new_selector = SgDNNLPostQuantizeSelector(fuse_all, float_output);
new_selector.Select(*matched_list[0]);
*this = new_selector;
}
};
class SgDNNLPostQuantizeProperty : public SubgraphProperty {
+ private:
+ bool fuse_all;
+ bool float_output;
+ std::set<std::string> support_requantize_fusion_op_name;
+
public:
SgDNNLPostQuantizeProperty() {
- support_requantize_fusion_op_name.insert("_sg_onednn_conv");
-
support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add");
- support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add");
+ fuse_all =
dmlc::GetEnv("MXNET_ONEDNN_FUSE_REQUANTIZE", true);
+ float_output =
dmlc::GetEnv("MXNET_ONEDNN_FUSE_DEQUANTIZE", true);
+ support_requantize_fusion_op_name = support_req_fusion_op;
}
+
static SubgraphPropertyPtr Create() {
static const std::string& name = "oneDNN post-quantization optimization
pass";
auto property =
std::make_shared<SgDNNLPostQuantizeProperty>();
@@ -131,35 +182,47 @@ class SgDNNLPostQuantizeProperty : public
SubgraphProperty {
property->SetAttr<bool>("inference_only", true);
return property;
}
+
nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
const int subgraph_id = 0) const override
{
nnvm::ObjectPtr fuse_node = nullptr;
nnvm::ObjectPtr requantize_node = nullptr;
+ nnvm::ObjectPtr dequantize_node = nullptr;
+
DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr& node) {
if (node->is_variable())
return;
- auto& op_name = node->op()->name;
- if (support_requantize_fusion_op_name.count(op_name)) {
+ if (node->op() &&
support_requantize_fusion_op_name.count(node->op()->name)) {
fuse_node = node;
- } else if (op_name == "_contrib_requantize") {
+ } else if (node->op() == Op::Get("_contrib_requantize")) {
requantize_node = node;
+ } else if (node->op() == Op::Get("_contrib_dequantize")) {
+ dequantize_node = node;
}
});
+
CHECK_NOTNULL(fuse_node);
CHECK_NOTNULL(requantize_node);
auto const& requantize_param =
nnvm::get<RequantizeParam>(requantize_node->attrs.parsed);
CHECK(requantize_param.min_calib_range.has_value());
CHECK(requantize_param.max_calib_range.has_value());
- fuse_node->attrs.dict["min_calib_range"] =
- std::to_string(requantize_param.min_calib_range.value());
- fuse_node->attrs.dict["max_calib_range"] =
- std::to_string(requantize_param.max_calib_range.value());
+
+ // When only fused quantized operator and requantize, set
min/max_cablib_range,
+ // When fused quantized operator + requantize + dequantize, set dequantize
flag to true.
+ if (dequantize_node != nullptr) {
+ fuse_node->attrs.dict["enable_float_output"] = "True";
+ } else {
+ fuse_node->attrs.dict["min_calib_range"] =
+ std::to_string(requantize_param.min_calib_range.value());
+ fuse_node->attrs.dict["max_calib_range"] =
+ std::to_string(requantize_param.max_calib_range.value());
+ }
fuse_node->op()->attr_parser(&(fuse_node->attrs));
return fuse_node;
}
- SubgraphSelectorPtr CreateSubgraphSelector() const override {
- auto selector = std::make_shared<SgDNNLPostQuantizeSelector>();
+ SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+ auto selector = std::make_shared<SgDNNLPostQuantizeSelector>(fuse_all,
float_output);
return selector;
}
@@ -170,10 +233,8 @@ class SgDNNLPostQuantizeProperty : public SubgraphProperty
{
*entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
}
}
-
- private:
- std::set<std::string> support_requantize_fusion_op_name;
};
+
} // namespace op
} // namespace mxnet
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 4a5f6a6..9727187 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -22,10 +22,7 @@
#include "dnnl_batch_dot_property.h"
#include "dnnl_bn_relu_property.h"
#include "dnnl_conv_property.h"
-#include "dnnl_elemwisemul_post_quantize_property.h"
-#include "dnnl_fc_post_quantize_property.h"
#include "dnnl_fc_property.h"
-#include "dnnl_matmul_post_quantize_property.h"
#include "dnnl_post_quantize_align_scale_property.h"
#include "dnnl_post_quantize_property.h"
#include "dnnl_transformer_qk_property.h"
@@ -54,11 +51,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLTransformerValAttPropert
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLBatchDotProperty)
.set_attr("quantize", true);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE, SgDNNLPostQuantizeProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLFCPostQuantizeProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
ElemwiseMulPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLPostQuantizeAlignScaleProperty);
-MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN_QUANTIZE,
SgDNNLMatmulPostQuantizeProperty)
- .set_attr("quantize", true);
} // namespace op
} // namespace mxnet