ZhennanQin commented on a change in pull request #15950: [MKLDNN]Support fullyconnected and element-wise ops fusion URL: https://github.com/apache/incubator-mxnet/pull/15950#discussion_r315916256
########## File path: src/operator/subgraph/mkldnn/mkldnn_fc_property.h ########## @@ -68,44 +73,70 @@ class SgMKLDNNFCSelector : public SubgraphSelector { } bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) + if (status_ == kFail || status_ == kSuccess || new_node.is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. - if (matched_list.back() != &n) { - if (std::find(matched_list.begin(), matched_list.end(), &n) != - matched_list.end()) { - while (matched_list.back() != &n) { - matched_list.pop_back(); + if (matched_list_.back() != &n) { + if (std::find(matched_list_.begin(), matched_list_.end(), &n) != + matched_list_.end()) { + while (matched_list_.back() != &n) { + matched_list_.pop_back(); } } - status = kSuccess; + status_ = kSuccess; return false; } - switch (status) { + switch (status_) { case kStart: - if (new_node.op() == Op::Get("Activation") && - new_node.attrs.dict.at("act_type") == "relu") { - matched_list.push_back(&new_node); - status = kSuccess; + // Currently, For INT8 FC fusion, only supports relu/bounded_relu(clip)/abs. 
+ if (new_node.op() == Op::Get("Activation")) { + const ActivationParam &param = nnvm::get<ActivationParam>(new_node.attrs.parsed); + if ((quantized_ && SupportQuantizedMKLDNNAct(param)) || + (!quantized_ && SupportMKLDNNAct(param))) { + matched_list_.push_back(&new_node); + status_ = kSuccess; + return true; + } + } + if (!quantized_ && (new_node.op() == Op::Get("square") || + new_node.op() == Op::Get("sqrt") || + new_node.op() == Op::Get("exp"))) { + matched_list_.push_back(&new_node); + status_ = kSuccess; + return true; + } + if (new_node.op() == Op::Get("abs")) { + matched_list_.push_back(&new_node); + status_ = kSuccess; return true; } + if (new_node.op() == Op::Get("clip")) { + const ClipParam &param = nnvm::get<ClipParam>(new_node.attrs.parsed); + if (param.a_min == 0.f && param.a_max == 1.0f) { Review comment: Why does a_max have to be 1.0f? I don't think that restriction is necessary. ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services