ciyongch commented on a change in pull request #15950: [MKLDNN]Support
fullyconnected and element-wise ops fusion
URL: https://github.com/apache/incubator-mxnet/pull/15950#discussion_r316002034
##########
File path: src/operator/subgraph/mkldnn/mkldnn_fc_property.h
##########
@@ -68,44 +73,70 @@ class SgMKLDNNFCSelector : public SubgraphSelector {
}
bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
- if (status == kFail || status == kSuccess || new_node.is_variable())
+ if (status_ == kFail || status_ == kSuccess || new_node.is_variable())
return false;
// If n isn't the last matched node, then we encoutered a internal
// branch, we should pop out the node behind n and stop fusion.
- if (matched_list.back() != &n) {
- if (std::find(matched_list.begin(), matched_list.end(), &n) !=
- matched_list.end()) {
- while (matched_list.back() != &n) {
- matched_list.pop_back();
+ if (matched_list_.back() != &n) {
+ if (std::find(matched_list_.begin(), matched_list_.end(), &n) !=
+ matched_list_.end()) {
+ while (matched_list_.back() != &n) {
+ matched_list_.pop_back();
}
}
- status = kSuccess;
+ status_ = kSuccess;
return false;
}
- switch (status) {
+ switch (status_) {
case kStart:
- if (new_node.op() == Op::Get("Activation") &&
- new_node.attrs.dict.at("act_type") == "relu") {
- matched_list.push_back(&new_node);
- status = kSuccess;
+ // Currently, For INT8 FC fusion, only supports
relu/bounded_relu(clip)/abs.
+ if (new_node.op() == Op::Get("Activation")) {
+ const ActivationParam ¶m =
nnvm::get<ActivationParam>(new_node.attrs.parsed);
+ if ((quantized_ && SupportQuantizedMKLDNNAct(param)) ||
+ (!quantized_ && SupportMKLDNNAct(param))) {
+ matched_list_.push_back(&new_node);
+ status_ = kSuccess;
+ return true;
+ }
+ }
+ if (!quantized_ && (new_node.op() == Op::Get("square") ||
+ new_node.op() == Op::Get("sqrt") ||
+ new_node.op() == Op::Get("exp"))) {
+ matched_list_.push_back(&new_node);
+ status_ = kSuccess;
+ return true;
+ }
+ if (new_node.op() == Op::Get("abs")) {
+ matched_list_.push_back(&new_node);
+ status_ = kSuccess;
return true;
}
+ if (new_node.op() == Op::Get("clip")) {
+ const ClipParam ¶m = nnvm::get<ClipParam>(new_node.attrs.parsed);
+ if (param.a_min == 0.f && param.a_max == 1.0f) {
Review comment:
Good catch, I will remove the `a_max` check for `bounded_relu`.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services