This is an automated email from the ASF dual-hosted git repository.
bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e36c9f075a Refactor fc_sum_fuse (#21077)
e36c9f075a is described below
commit e36c9f075aa7ee14e1a8ccf245df5d1e1648515b
Author: bartekkuncer <[email protected]>
AuthorDate: Thu Jul 7 16:35:29 2022 +0200
Refactor fc_sum_fuse (#21077)
* Refactor fc_sum_fuse
* Fix sanity
* Simplify Selector
* Restore ConnectSubgraphOutputs in fc_sum_fuse_property
* Fix node name
---
...l_fc_sum_fuse.h => dnnl_fc_sum_fuse_property.h} | 100 ++++++++-------------
.../subgraph/dnnl/dnnl_subgraph_property.cc | 2 +-
2 files changed, 38 insertions(+), 64 deletions(-)
diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
similarity index 71%
rename from src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
rename to src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
index c65711493c..2c19b7b68d 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
@@ -26,8 +26,8 @@
this output is scaled to the proper range.
*/
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1
#include <memory>
@@ -55,8 +55,7 @@ inline bool EndsWith(std::string const& value, std::string
const& ending) {
class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
private:
bool quantized_;
- SelectStatus status_ = kFail;
- std::vector<const BiDirectedNode*> matched_list_;
+ bool patternFound = false;
public:
explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
@@ -64,18 +63,13 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
bool Select(const BiDirectedNode& seed_node,
const std::shared_ptr<NodeAttr>& node_attr) override {
const auto n = seed_node.node;
- if (n->op() == Op::Get("_sg_onednn_fully_connected")) {
- if (SupportDNNLAttr(node_attr) && (seed_node.outputs.size() == 1)) {
- auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
- if ((!quantized_) || (fc_param.dnnl_param.quantized &&
!fc_param.dnnl_param.with_eltwise)) {
- // Start subgraph when fusing for floats (quantized_ is false for
ONEDNN backend) or
- // when FC is already quantized (second pass for ONEDNN_QUANTIZE)
but not already fuzed
- // with elemwise operator.
- status_ = kStart;
- matched_list_.clear();
- matched_list_.push_back(&seed_node);
- return true;
- }
+ if (n->op() == Op::Get("_sg_onednn_fully_connected") &&
seed_node.outputs.size() == 1) {
+ auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+ if (!quantized_ || (fc_param.dnnl_param.quantized &&
!fc_param.dnnl_param.with_eltwise)) {
+ // Start subgraph when fusing for floats (quantized_ is false for
ONEDNN backend) or
+ // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but
not already fused
+ // with elemwise operator.
+ return true;
}
}
return false;
@@ -88,46 +82,29 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode&
output_node) override {
const auto cur_n = cur_node.node;
const auto output_n = output_node.node;
- if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+ if (patternFound || output_n->is_variable()) {
return false;
}
- // If n isn't the last matched node, then we encoutered an internal
- // branch, we should pop out the node behind n and stop fusion.
- if (matched_list_.back() != &cur_node) {
- if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) !=
matched_list_.end()) {
- while (matched_list_.back() != &cur_node) {
- matched_list_.pop_back();
+
+ // Find _contrib_quantized_elemwise_add or elemwise_add
+ if (EndsWith(output_n->op()->name, "elemwise_add")) {
+ if (quantized_) {
+ auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+ if (!fc_param.dnnl_param.enable_float_output) {
+ // For quantized graph, when FC floating point output is not enabled
elementwise add must
+ // also be quantized (min and max value have to be already stored in
elementwise add).
+ CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
}
}
- status_ = kSuccess;
+ patternFound = true;
+ return true;
+ } else {
return false;
}
-
- switch (status_) {
- case kStart:
- // Find _contrib_quantized_elemwise_add or elemwise_add
- if (EndsWith(output_n->op()->name, "elemwise_add")) {
- if (quantized_) {
- auto const& fc_param =
nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
- if (!fc_param.dnnl_param.enable_float_output) {
- // For quantized graph, when FC floating point output is not
enabled
- // elementwise add must also be quantized (min and max value
have to be already stored
- // in elementwise add).
- CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
- }
- }
- matched_list_.push_back(&output_node);
- status_ = kSuccess;
- return true;
- }
- default:
- status_ = kFail;
- return false;
- }
}
std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>&
candidates) override {
- if (status_ == kSuccess) {
+ if (patternFound) {
return candidates;
} else {
return std::vector<BiDirectedNode*>(0);
@@ -135,10 +112,7 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
}
void Reset() override {
- CHECK_GE(matched_list_.size(), 1);
- auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
- new_selector.Select(*matched_list_[0], nullptr);
- *this = new_selector;
+ patternFound = false;
}
};
@@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
SgDNNLFCSumFuseProperty() {}
static SubgraphPropertyPtr Create() {
- static const std::string& name = "DNNL fuse FullyConnected with sum";
+ static const std::string& name = "oneDNN fuse FullyConnected with sum";
auto property =
std::make_shared<SgDNNLFCSumFuseProperty>();
property->SetAttr<std::string>("property_name", name);
property->SetAttr<bool>("inference_only", true);
- if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
+ if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0)) {
property->SetAttr<bool>("disable", true);
}
return property;
@@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
return selector;
}
- void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+ void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
std::vector<nnvm::NodeEntry*>* output_entries)
const override {
// Connect all extern output entries to output[0]
for (size_t i = 0; i < output_entries->size(); ++i) {
auto entry_ptr = output_entries->at(i);
- *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0};
+ *entry_ptr = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0};
}
}
- void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+ void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
std::vector<nnvm::NodeEntry*>* input_entries,
std::vector<nnvm::NodeEntry>* orig_input_entries)
const override {
- auto sym = n->attrs.subgraphs[0];
- auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
- std::unordered_set<const nnvm::Node*> node_sets;
+ auto sym = subgraph_node->attrs.subgraphs[0];
+ auto const& fc_param =
nnvm::get<DNNLFCFullParam>(subgraph_node->attrs.parsed);
+ std::unordered_set<const nnvm::Node*> node_set;
DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
if (node->is_variable()) {
return;
}
- node_sets.insert(node.get());
+ node_set.insert(node.get());
if (EndsWith(node->op()->name, "elemwise_add")) {
const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
// Make sure fc output is the left operand of the add operator, if not:
// - swap inputs of add operator
// - switch add operands sequence to ensure that
// the tensor (sum_tensor) to which FC output is added is the last
input.
- if (node_sets.count(node->inputs[1].node.get())) {
+ if (node_set.count(node->inputs[1].node.get())) {
// Example of input_entries reordering for channel-wise quantized
graph:
// sum_tensor.data --> fc.data
// fc.data --> fc.weight0
@@ -272,7 +246,7 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
}
}
});
- n->inputs = *orig_input_entries;
+ subgraph_node->inputs = *orig_input_entries;
}
};
@@ -280,4 +254,4 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
} // namespace mxnet
#endif // if MXNET_USE_ONEDNN == 1
-#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#endif // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 69fae1c97d..86e08020ee 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -30,7 +30,7 @@
#include "dnnl_pow_mul_scalar_property.h"
#include "dnnl_transformer_qk_property.h"
#include "dnnl_transformer_valatt_property.h"
-#include "dnnl_fc_sum_fuse.h"
+#include "dnnl_fc_sum_fuse_property.h"
#include "dnnl_remove_casts_property.h"
namespace mxnet {