This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e36c9f075a Refactor fc_sum_fuse (#21077)
e36c9f075a is described below

commit e36c9f075aa7ee14e1a8ccf245df5d1e1648515b
Author: bartekkuncer <[email protected]>
AuthorDate: Thu Jul 7 16:35:29 2022 +0200

    Refactor fc_sum_fuse (#21077)
    
    * Refactor fc_sum_fuse
    
    * Fix sanity
    
    * Simplify Selector
    
    * Restore ConnectSubgraphOutputs in fc_sum_fuse_property
    
    * Fix node name
---
 ...l_fc_sum_fuse.h => dnnl_fc_sum_fuse_property.h} | 100 ++++++++-------------
 .../subgraph/dnnl/dnnl_subgraph_property.cc        |   2 +-
 2 files changed, 38 insertions(+), 64 deletions(-)
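
The core of the refactor: SgDNNLFCSumFuseSelector no longer needs the
kStart/kSuccess/kFail state machine and the matched_list_ bookkeeping, because
the fused pattern is exactly one fully connected node followed by one
elemwise_add, so a single boolean flag is enough. Below is a minimal
stand-alone sketch of that shape. Node, FcSumSelectorSketch, and main() are
illustrative stand-ins only; EndsWith and the op names are taken from the
patch itself (the real class works on BiDirectedNode through the
SubgraphSelectorV2 interface):

    #include <algorithm>
    #include <iostream>
    #include <string>

    // Stand-in for MXNet's node type; the real selector sees BiDirectedNode.
    struct Node {
      std::string op;
    };

    // Same helper the patch relies on: matches both "elemwise_add" and
    // "_contrib_quantized_elemwise_add".
    inline bool EndsWith(std::string const& value, std::string const& ending) {
      if (ending.size() > value.size()) return false;
      return std::equal(ending.rbegin(), ending.rend(), value.rbegin());
    }

    class FcSumSelectorSketch {
     public:
      bool Select(const Node& seed) {
        // Seed only on the fused FC op; the quantization checks are elided.
        return seed.op == "_sg_onednn_fully_connected";
      }
      bool SelectOutput(const Node& /*cur*/, const Node& out) {
        if (pattern_found_) return false;  // at most one add is absorbed
        if (EndsWith(out.op, "elemwise_add")) {
          pattern_found_ = true;
          return true;
        }
        return false;
      }
      bool Matched() const { return pattern_found_; }  // what Filter() keys on
      void Reset() { pattern_found_ = false; }  // was: rebuild the selector
     private:
      bool pattern_found_ = false;
    };

    int main() {
      FcSumSelectorSketch s;
      Node fc{"_sg_onednn_fully_connected"};
      Node add{"_contrib_quantized_elemwise_add"};
      std::cout << std::boolalpha
                << (s.Select(fc) && s.SelectOutput(fc, add)) << '\n';  // true
    }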

diff --git a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
similarity index 71%
rename from src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
rename to src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
index c65711493c..2c19b7b68d 100644
--- a/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse.h
+++ b/src/operator/subgraph/dnnl/dnnl_fc_sum_fuse_property.h
@@ -26,8 +26,8 @@
   this output is scaled to the proper range.
 */
 
-#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
-#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
 #if MXNET_USE_ONEDNN == 1
 
 #include <memory>
@@ -55,8 +55,7 @@ inline bool EndsWith(std::string const& value, std::string const& ending) {
 class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
  private:
   bool quantized_;
-  SelectStatus status_ = kFail;
-  std::vector<const BiDirectedNode*> matched_list_;
+  bool patternFound = false;
 
  public:
   explicit SgDNNLFCSumFuseSelector(bool quantized) : quantized_(quantized) {}
@@ -64,18 +63,13 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
   bool Select(const BiDirectedNode& seed_node,
               const std::shared_ptr<NodeAttr>& node_attr) override {
     const auto n = seed_node.node;
-    if (n->op() == Op::Get("_sg_onednn_fully_connected")) {
-      if (SupportDNNLAttr(node_attr) && (seed_node.outputs.size() == 1)) {
-        auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
-        if ((!quantized_) || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
-          // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
-          // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fuzed
-          // with elemwise operator.
-          status_ = kStart;
-          matched_list_.clear();
-          matched_list_.push_back(&seed_node);
-          return true;
-        }
+    if (n->op() == Op::Get("_sg_onednn_fully_connected") && seed_node.outputs.size() == 1) {
+      auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
+      if (!quantized_ || (fc_param.dnnl_param.quantized && !fc_param.dnnl_param.with_eltwise)) {
+        // Start subgraph when fusing for floats (quantized_ is false for ONEDNN backend) or
+        // when FC is already quantized (second pass for ONEDNN_QUANTIZE) but not already fused
+        // with elemwise operator.
+        return true;
       }
     }
     return false;
@@ -88,46 +82,29 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
   bool SelectOutput(const BiDirectedNode& cur_node, const BiDirectedNode& output_node) override {
     const auto cur_n    = cur_node.node;
     const auto output_n = output_node.node;
-    if (status_ == kFail || status_ == kSuccess || output_n->is_variable()) {
+    if (patternFound || output_n->is_variable()) {
       return false;
     }
-    // If n isn't the last matched node, then we encoutered an internal
-    // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list_.back() != &cur_node) {
-      if (std::find(matched_list_.begin(), matched_list_.end(), &cur_node) != matched_list_.end()) {
-        while (matched_list_.back() != &cur_node) {
-          matched_list_.pop_back();
+
+    // Find _contrib_quantized_elemwise_add or elemwise_add
+    if (EndsWith(output_n->op()->name, "elemwise_add")) {
+      if (quantized_) {
+        auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
+        if (!fc_param.dnnl_param.enable_float_output) {
+          // For quantized graph, when FC floating point output is not enabled elementwise add must
+          // also be quantized (min and max value have to be already stored in elementwise add).
+          CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
         }
       }
-      status_ = kSuccess;
+      patternFound = true;
+      return true;
+    } else {
       return false;
     }
-
-    switch (status_) {
-      case kStart:
-        // Find _contrib_quantized_elemwise_add or elemwise_add
-        if (EndsWith(output_n->op()->name, "elemwise_add")) {
-          if (quantized_) {
-            auto const& fc_param = nnvm::get<DNNLFCFullParam>(cur_n->attrs.parsed);
-            if (!fc_param.dnnl_param.enable_float_output) {
-              // For quantized graph, when FC floating point output is not enabled
-              // elementwise add must also be quantized (min and max value have to be already stored
-              // in elementwise add).
-              CHECK_EQ(output_n->attrs.dict.count("min_calib_range"), 1);
-            }
-          }
-          matched_list_.push_back(&output_node);
-          status_ = kSuccess;
-          return true;
-        }
-      default:
-        status_ = kFail;
-        return false;
-    }
   }
 
   std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
-    if (status_ == kSuccess) {
+    if (patternFound) {
       return candidates;
     } else {
       return std::vector<BiDirectedNode*>(0);
@@ -135,10 +112,7 @@ class SgDNNLFCSumFuseSelector : public SubgraphSelectorV2 {
   }
 
   void Reset() override {
-    CHECK_GE(matched_list_.size(), 1);
-    auto new_selector = SgDNNLFCSumFuseSelector(quantized_);
-    new_selector.Select(*matched_list_[0], nullptr);
-    *this = new_selector;
+    patternFound = false;
   }
 };
 
@@ -147,11 +121,11 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
   SgDNNLFCSumFuseProperty() {}
 
   static SubgraphPropertyPtr Create() {
-    static const std::string& name = "DNNL fuse FullyConnected with sum";
+    static const std::string& name = "oneDNN fuse FullyConnected with sum";
     auto property                  = std::make_shared<SgDNNLFCSumFuseProperty>();
     property->SetAttr<std::string>("property_name", name);
     property->SetAttr<bool>("inference_only", true);
-    if (dmlc::GetEnv("MXNET_DISABLE_DNNL_FC_SUM", 0)) {
+    if (dmlc::GetEnv("MXNET_DISABLE_ONEDNN_FC_SUM", 0)) {
       property->SetAttr<bool>("disable", true);
     }
     return property;
@@ -207,33 +181,33 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
     return selector;
   }
 
-  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr subgraph_node,
                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
     // Connect all extern output entries to output[0]
     for (size_t i = 0; i < output_entries->size(); ++i) {
       auto entry_ptr = output_entries->at(i);
-      *entry_ptr     = nnvm::NodeEntry{n, entry_ptr->index, 0};
+      *entry_ptr     = nnvm::NodeEntry{subgraph_node, entry_ptr->index, 0};
     }
   }
 
-  void ConnectSubgraphInputs(const nnvm::ObjectPtr n,
+  void ConnectSubgraphInputs(const nnvm::ObjectPtr subgraph_node,
                              std::vector<nnvm::NodeEntry*>* input_entries,
                             std::vector<nnvm::NodeEntry>* orig_input_entries) const override {
-    auto sym             = n->attrs.subgraphs[0];
-    auto const& fc_param = nnvm::get<DNNLFCFullParam>(n->attrs.parsed);
-    std::unordered_set<const nnvm::Node*> node_sets;
+    auto sym             = subgraph_node->attrs.subgraphs[0];
+    auto const& fc_param = nnvm::get<DNNLFCFullParam>(subgraph_node->attrs.parsed);
+    std::unordered_set<const nnvm::Node*> node_set;
     DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr& node) {
       if (node->is_variable()) {
         return;
       }
-      node_sets.insert(node.get());
+      node_set.insert(node.get());
       if (EndsWith(node->op()->name, "elemwise_add")) {
         const size_t base_inputs = fc_param.default_param.no_bias ? 3 : 4;
         // Make sure fc output is the left operand of the add operator, if not:
         // - swap inputs of add operator
         // - switch add operands sequence to ensure that
         // the tensor (sum_tensor) to which FC output is added is the last input.
-        if (node_sets.count(node->inputs[1].node.get())) {
+        if (node_set.count(node->inputs[1].node.get())) {
           // Example of input_entries reordering for channel-wise quantized graph:
           // sum_tensor.data    -->   fc.data
           // fc.data            -->   fc.weight0
@@ -272,7 +246,7 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
         }
       }
     });
-    n->inputs = *orig_input_entries;
+    subgraph_node->inputs = *orig_input_entries;
   }
 };
 
@@ -280,4 +254,4 @@ class SgDNNLFCSumFuseProperty : public SubgraphProperty {
 }  // namespace mxnet
 
 #endif  // if MXNET_USE_ONEDNN == 1
-#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_H_
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_FC_SUM_FUSE_PROPERTY_H_
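
The subtlest part of the property is ConnectSubgraphInputs: the external
tensor being summed (sum_tensor) must end up as the last input of the fused
node, so when the FC output arrives as the right operand of the add, the two
operands are swapped. A rough stand-alone model of that decision follows;
integer ids stand in for the nnvm::Node pointers collected by DFSVisit, and
the real code additionally reorders the weight, bias, and min/max calibration
entries:

    #include <iostream>
    #include <unordered_set>
    #include <utility>

    int main() {
      // Ids of nodes known to be inside the subgraph (the patch collects
      // real nnvm::Node pointers into node_set the same way).
      std::unordered_set<int> node_set{1, 2};

      // The add operator's two inputs: here the FC output (id 2, internal)
      // arrived as the right operand.
      std::pair<int, int> add_inputs{7, 2};  // {sum_tensor, fc_output}

      // Mirrors `if (node_set.count(node->inputs[1].node.get()))`: when the
      // right operand is internal, swap so sum_tensor becomes the last input.
      if (node_set.count(add_inputs.second)) {
        std::swap(add_inputs.first, add_inputs.second);
      }
      std::cout << "left=" << add_inputs.first
                << " right=" << add_inputs.second << '\n';  // left=2 right=7
    }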
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 69fae1c97d..86e08020ee 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -30,7 +30,7 @@
 #include "dnnl_pow_mul_scalar_property.h"
 #include "dnnl_transformer_qk_property.h"
 #include "dnnl_transformer_valatt_property.h"
-#include "dnnl_fc_sum_fuse.h"
+#include "dnnl_fc_sum_fuse_property.h"
 #include "dnnl_remove_casts_property.h"
 
 namespace mxnet {
