ZhennanQin commented on a change in pull request #15950: [MKLDNN]Support fullyconnected and element-wise ops fusion
URL: https://github.com/apache/incubator-mxnet/pull/15950#discussion_r315916256
 
 

 ##########
 File path: src/operator/subgraph/mkldnn/mkldnn_fc_property.h
 ##########
 @@ -68,44 +73,70 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
   }
 
   bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
-    if (status == kFail || status == kSuccess || new_node.is_variable())
+    if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
       return false;
 
     // If n isn't the last matched node, then we encoutered a internal
     // branch, we should pop out the node behind n and stop fusion.
-    if (matched_list.back() != &n) {
-      if (std::find(matched_list.begin(), matched_list.end(), &n) !=
-        matched_list.end()) {
-        while (matched_list.back() != &n) {
-          matched_list.pop_back();
+    if (matched_list_.back() != &n) {
+      if (std::find(matched_list_.begin(), matched_list_.end(), &n) !=
+        matched_list_.end()) {
+        while (matched_list_.back() != &n) {
+          matched_list_.pop_back();
         }
       }
 
-      status = kSuccess;
+      status_ = kSuccess;
       return false;
     }
 
-    switch (status) {
+    switch (status_) {
       case kStart:
-        if (new_node.op() == Op::Get("Activation") &&
-            new_node.attrs.dict.at("act_type") == "relu") {
-          matched_list.push_back(&new_node);
-          status = kSuccess;
+        // Currently, For INT8 FC fusion, only supports 
relu/bounded_relu(clip)/abs.
+        if (new_node.op() == Op::Get("Activation")) {
+          const ActivationParam &param = 
nnvm::get<ActivationParam>(new_node.attrs.parsed);
+          if ((quantized_ && SupportQuantizedMKLDNNAct(param)) ||
+              (!quantized_ && SupportMKLDNNAct(param))) {
+            matched_list_.push_back(&new_node);
+            status_ = kSuccess;
+            return true;
+          }
+        }
+        if (!quantized_ && (new_node.op() == Op::Get("square") ||
+            new_node.op() == Op::Get("sqrt") ||
+            new_node.op() == Op::Get("exp"))) {
+          matched_list_.push_back(&new_node);
+          status_ = kSuccess;
+          return true;
+        }
+        if (new_node.op() == Op::Get("abs")) {
+          matched_list_.push_back(&new_node);
+          status_ = kSuccess;
           return true;
         }
+        if (new_node.op() == Op::Get("clip")) {
+          const ClipParam &param = nnvm::get<ClipParam>(new_node.attrs.parsed);
+          if (param.a_min == 0.f && param.a_max == 1.0f) {
 
 Review comment:
   Why does a_max have to be 1.0f? I don't think that's necessary: with a_min == 0, clip is a bounded ReLU for any upper bound a_max.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to