This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 748e100  [FEATURE] Add BN+ReLU -> BatchNormWithRelu fuse to MKLDNN backend (#19578)
748e100 is described below

commit 748e100929bd2e862c4a808531478e4ccf35d8ce
Author: Adam <[email protected]>
AuthorDate: Sun Nov 29 02:19:57 2020 +0100

    [FEATURE] Add BN+ReLU -> BatchNormWithRelu fuse to MKLDNN backend (#19578)
    
    * Add BN+ReLU fuse to MKLDNN backend
    
    * Review and bug fixes
    
    * Remove commented out code
    
    * Review fixes v2
    
    * Add tests
    
    * Sanity check fixes
---
 src/operator/contrib/batch_norm_relu.cc            |  11 +-
 .../subgraph/mkldnn/mkldnn_bn_relu_property.h      | 142 +++++++++++++++++++++
 .../subgraph/mkldnn/mkldnn_subgraph_property.cc    |  21 +--
 tests/python/mkl/test_subgraph.py                  |  18 +++
 4 files changed, 173 insertions(+), 19 deletions(-)

diff --git a/src/operator/contrib/batch_norm_relu.cc b/src/operator/contrib/batch_norm_relu.cc
index 51aa4c5..890239d 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -127,9 +127,14 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs& attrs,
 
 #if MXNET_USE_MKLDNN == 1
 static inline bool SupportMKLDNNBNReLU(const NDArray &input, const BatchNormParam &param) {
-  mxnet::TShape shape = input.shape();
-  return SupportMKLDNN(input) && shape.ndim() == 4
-      && param.axis == mxnet::op::batchnormrelu::DEFAULT_AXIS;
+  if (mxnet::op::batchnorm::disable_mkl) return false;
+  const mxnet::TShape shape = input.shape();
+  const int ndim = shape.ndim();
+  if (ndim == 0 || shape.Size() == 0) return false;
+  const int dtype = input.dtype();
+  return (dtype == mshadow::kFloat32 ||
+          dtype == mshadow::kBfloat16) &&
+          SupportStorageMKLDNN(input.storage_type());
 }
 
 void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h b/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h
new file mode 100644
index 0000000..84eb18b
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "mkldnn_subgraph_base-inl.h"
+#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNBNReLUSelector : public SubgraphSelector {
+ public:
+  enum SelectStatus {
+    kStart,
+    kSuccess,
+    kFail
+  };
+
+  explicit SgMKLDNNBNReLUSelector(const bool disable_bn_relu) :
+      disable_bn_relu_(disable_bn_relu), status_(kStart) {}
+
+  bool Select(const nnvm::Node &n) override {
+    return n.op() && n.op()->name == "BatchNorm";
+  }
+
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    if (n.op() && n.op()->name == "BatchNorm") {
+      if (new_node.op() && status_ == kStart &&
+          (new_node.op()->name == "relu" || (new_node.op()->name == "Activation" &&
+           nnvm::get<ActivationParam>(new_node.attrs.parsed).act_type == activation::kReLU))) {
+        status_ = kSuccess;
+        return true;
+      } else {
+        // Do not fuse if BatchNorm is connected to other nodes
+        // e.g: ->- BN --- ReLU --- elementwise_add ->-
+        //           \                   /
+        //            \-------->--------/
+        status_ = kFail;
+        return false;
+      }
+    }
+    return false;
+  }
+
+  std::vector<nnvm::Node *> Filter(
+      const std::vector<nnvm::Node *> &candidates) override {
+    if (!disable_bn_relu_ && status_ == kSuccess)
+      return candidates;
+    else
+      return std::vector<nnvm::Node *>();
+  }
+
+ private:
+  bool disable_bn_relu_;
+  SelectStatus status_;
+};
+
+class SgMKLDNNBNReLUProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNBNReLUProperty() {
+    disable_bn_relu_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_BN_RELU", false);
+  }
+
+  void PrePartition(const nnvm::Graph& g,
+    const std::unordered_map<std::string, std::string>& options_map) override {
+    dedup_subgraph = true;
+  }
+
+  static SubgraphPropertyPtr Create() {
+    static const std::string &name = "MKLDNN BN + ReLU optimization pass";
+    auto property = std::make_shared<SgMKLDNNBNReLUProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_BN_RELU_OPT", 0)) {
+      property->SetAttr<bool>("disable", true);
+    }
+    return property;
+  }
+
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                   const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+
+    std::ostringstream node_name;
+    node_name << "sg_mkldnn_batch_norm_relu_" << std::to_string(subgraph_id);
+
+    // Copy params from BatchNorm node into subgraph BatchNormReLU node
+    BatchNormParam param;
+    DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+      if (node->op() && node->op()->name == "BatchNorm") {
+        param = nnvm::get<BatchNormParam>(node->attrs.parsed);
+      }
+    });
+
+    n->attrs.name = node_name.str();
+    n->attrs.op = Op::Get("_contrib_BatchNormWithReLU");
+    CHECK(n->attrs.op);
+    n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
+    n->attrs.parsed = param;
+    return n;
+  }
+
+  SubgraphSelectorPtr CreateSubgraphSelector() const override {
+    auto selector = std::make_shared<SgMKLDNNBNReLUSelector>(disable_bn_relu_);
+    return selector;
+  }
+
+ private:
+  bool disable_bn_relu_;
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 18cd303..07f06cd 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -21,6 +21,7 @@
 
 #include "mkldnn_conv_property.h"
 #include "mkldnn_fc_property.h"
+#include "mkldnn_bn_relu_property.h"
 #include "mkldnn_post_quantize_property.h"
 #include "mkldnn_fc_post_quantize_property.h"
 #include "mkldnn_elemwisemul_post_quantize_property.h"
@@ -34,35 +35,23 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
 .set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNBNReLUProperty);
+
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
 .set_attr("context", Context::CPU());
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
 .set_attr("quantize", true);
 
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
 .set_attr("quantize", true);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
-#endif  // MXNET_USE_MKLDNN == 1
 
-#if MXNET_USE_MKLDNN == 1
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCPostQuantizeProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, ElemwiseMulPostQuantizeProperty);
-
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif  // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_MKLDNN == 1
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 4a10cef..2f09854 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -22,6 +22,7 @@ import numpy as np
 import unittest
 import ctypes
 import pytest
+from mxnet.test_utils import assert_almost_equal
 
 def test_float64_fallback():
     sym = mx.sym.FullyConnected(
@@ -37,3 +38,20 @@ def test_float64_fallback():
     ex = sym._bind(mx.cpu(), args, args_grad=None, grad_req='write')
     ex.forward()
     ex.outputs[0].wait_to_read()
+
+
+@pytest.mark.parametrize('axis', [0, 1, 2, 3])
+def test_bn_relu_fusion(axis):
+    dummy_data = mx.nd.uniform(-1.0, 1.0, shape=(32, 3, 224, 224))
+
+    net = mx.gluon.nn.HybridSequential()
+    net.add(mx.gluon.nn.BatchNorm(axis=axis))
+    net.add(mx.gluon.nn.Activation('relu'))
+    net.initialize()
+
+    out1 = net(dummy_data)
+    out1.wait_to_read()
+    net.optimize_for(dummy_data, backend='MKLDNN')
+    out2 = net(dummy_data)
+
+    assert_almost_equal(out1, out2)

Reply via email to