This is an automated email from the ASF dual-hosted git repository.
zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 748e100 [FEATURE] Add BN+ReLU -> BatchNormWithRelu fuse to MKLDNN
backend (#19578)
748e100 is described below
commit 748e100929bd2e862c4a808531478e4ccf35d8ce
Author: Adam <[email protected]>
AuthorDate: Sun Nov 29 02:19:57 2020 +0100
[FEATURE] Add BN+ReLU -> BatchNormWithRelu fuse to MKLDNN backend (#19578)
* Add BN+ReLU fuse to MKLDNN backend
* Review and bug fixes
* Remove commented out code
* Review fixes v2
* Add tests
* Sanity check fixes
---
src/operator/contrib/batch_norm_relu.cc | 11 +-
.../subgraph/mkldnn/mkldnn_bn_relu_property.h | 142 +++++++++++++++++++++
.../subgraph/mkldnn/mkldnn_subgraph_property.cc | 21 +--
tests/python/mkl/test_subgraph.py | 18 +++
4 files changed, 173 insertions(+), 19 deletions(-)
diff --git a/src/operator/contrib/batch_norm_relu.cc
b/src/operator/contrib/batch_norm_relu.cc
index 51aa4c5..890239d 100644
--- a/src/operator/contrib/batch_norm_relu.cc
+++ b/src/operator/contrib/batch_norm_relu.cc
@@ -127,9 +127,14 @@ static bool BatchNormWithReLUType(const nnvm::NodeAttrs&
attrs,
#if MXNET_USE_MKLDNN == 1
static inline bool SupportMKLDNNBNReLU(const NDArray &input, const
BatchNormParam &param) {
- mxnet::TShape shape = input.shape();
- return SupportMKLDNN(input) && shape.ndim() == 4
- && param.axis == mxnet::op::batchnormrelu::DEFAULT_AXIS;
+ if (mxnet::op::batchnorm::disable_mkl) return false;
+ const mxnet::TShape shape = input.shape();
+ const int ndim = shape.ndim();
+ if (ndim == 0 || shape.Size() == 0) return false;
+ const int dtype = input.dtype();
+ return (dtype == mshadow::kFloat32 ||
+ dtype == mshadow::kBfloat16) &&
+ SupportStorageMKLDNN(input.storage_type());
}
void BatchNormWithReLUComputeExCPU(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h
b/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h
new file mode 100644
index 0000000..84eb18b
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_bn_relu_property.h
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <string>
+#include <vector>
+#include "../common.h"
+#include "mkldnn_subgraph_base-inl.h"
+#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
+
+namespace mxnet {
+namespace op {
+
+class SgMKLDNNBNReLUSelector : public SubgraphSelector {
+ public:
+ enum SelectStatus {
+ kStart,
+ kSuccess,
+ kFail
+ };
+
+ explicit SgMKLDNNBNReLUSelector(const bool disable_bn_relu) :
+ disable_bn_relu_(disable_bn_relu), status_(kStart) {}
+
+ bool Select(const nnvm::Node &n) override {
+ return n.op() && n.op()->name == "BatchNorm";
+ }
+
+ bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ return false;
+ }
+
+ bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+ if (n.op() && n.op()->name == "BatchNorm") {
+ if (new_node.op() && status_ == kStart &&
+ (new_node.op()->name == "relu" || (new_node.op()->name ==
"Activation" &&
+ nnvm::get<ActivationParam>(new_node.attrs.parsed).act_type ==
activation::kReLU))) {
+ status_ = kSuccess;
+ return true;
+ } else {
+ // Do not fuse if BatchNorm is connected to other nodes
+ // e.g: ->- BN --- ReLU --- elementwise_add ->-
+ // \ /
+ // \-------->--------/
+ status_ = kFail;
+ return false;
+ }
+ }
+ return false;
+ }
+
+ std::vector<nnvm::Node *> Filter(
+ const std::vector<nnvm::Node *> &candidates) override {
+ if (!disable_bn_relu_ && status_ == kSuccess)
+ return candidates;
+ else
+ return std::vector<nnvm::Node *>();
+ }
+
+ private:
+ bool disable_bn_relu_;
+ SelectStatus status_;
+};
+
+class SgMKLDNNBNReLUProperty : public SubgraphProperty {
+ public:
+ SgMKLDNNBNReLUProperty() {
+ disable_bn_relu_ = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_BN_RELU",
false);
+ }
+
+ void PrePartition(const nnvm::Graph& g,
+ const std::unordered_map<std::string, std::string>& options_map) override {
+ dedup_subgraph = true;
+ }
+
+ static SubgraphPropertyPtr Create() {
+ static const std::string &name = "MKLDNN BN + ReLU optimization pass";
+ auto property = std::make_shared<SgMKLDNNBNReLUProperty>();
+ property->SetAttr<std::string>("property_name", name);
+ property->SetAttr<bool>("inference_only", true);
+ if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_BN_RELU_OPT", 0)) {
+ property->SetAttr<bool>("disable", true);
+ }
+ return property;
+ }
+
+ nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol &sym,
+ const int subgraph_id = 0) const override {
+ nnvm::ObjectPtr n = nnvm::Node::Create();
+
+ std::ostringstream node_name;
+ node_name << "sg_mkldnn_batch_norm_relu_" << std::to_string(subgraph_id);
+
+ // Copy params from BatchNorm node into subgraph BatchNormReLU node
+ BatchNormParam param;
+ DFSVisit(sym.outputs, [&](const nnvm::ObjectPtr &node) {
+ if (node->op() && node->op()->name == "BatchNorm") {
+ param = nnvm::get<BatchNormParam>(node->attrs.parsed);
+ }
+ });
+
+ n->attrs.name = node_name.str();
+ n->attrs.op = Op::Get("_contrib_BatchNormWithReLU");
+ CHECK(n->attrs.op);
+ n->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(sym));
+ n->attrs.parsed = param;
+ return n;
+ }
+
+ SubgraphSelectorPtr CreateSubgraphSelector() const override {
+ auto selector = std::make_shared<SgMKLDNNBNReLUSelector>(disable_bn_relu_);
+ return selector;
+ }
+
+ private:
+ bool disable_bn_relu_;
+};
+
+} // namespace op
+} // namespace mxnet
+
+#endif // if MXNET_USE_MKLDNN == 1
+#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_BN_RELU_PROPERTY_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 18cd303..07f06cd 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -21,6 +21,7 @@
#include "mkldnn_conv_property.h"
#include "mkldnn_fc_property.h"
+#include "mkldnn_bn_relu_property.h"
#include "mkldnn_post_quantize_property.h"
#include "mkldnn_fc_post_quantize_property.h"
#include "mkldnn_elemwisemul_post_quantize_property.h"
@@ -34,35 +35,23 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
.set_attr("context", Context::CPU());
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
-
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNBNReLUProperty);
+
MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
.set_attr("context", Context::CPU());
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
.set_attr("quantize", true);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
.set_attr("quantize", true);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
-MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE,
SgMKLDNNPostQuantizeProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE,
SgMKLDNNPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE,
SgMKLDNNFCPostQuantizeProperty);
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE,
ElemwiseMulPostQuantizeProperty);
-
MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE,
SgMKLDNNPostQuantizeAlignScaleProperty);
-#endif // MXNET_USE_MKLDNN == 1
-#if MXNET_USE_MKLDNN == 1
+
} // namespace op
} // namespace mxnet
#endif // MXNET_USE_MKLDNN == 1
diff --git a/tests/python/mkl/test_subgraph.py
b/tests/python/mkl/test_subgraph.py
index 4a10cef..2f09854 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -22,6 +22,7 @@ import numpy as np
import unittest
import ctypes
import pytest
+from mxnet.test_utils import assert_almost_equal
def test_float64_fallback():
sym = mx.sym.FullyConnected(
@@ -37,3 +38,20 @@ def test_float64_fallback():
ex = sym._bind(mx.cpu(), args, args_grad=None, grad_req='write')
ex.forward()
ex.outputs[0].wait_to_read()
+
+
+@pytest.mark.parametrize('axis', [0, 1, 2, 3])
+def test_bn_relu_fusion(axis):
+ dummy_data = mx.nd.uniform(-1.0, 1.0, shape=(32, 3, 224, 224))
+
+ net = mx.gluon.nn.HybridSequential()
+ net.add(mx.gluon.nn.BatchNorm(axis=axis))
+ net.add(mx.gluon.nn.Activation('relu'))
+ net.initialize()
+
+ out1 = net(dummy_data)
+ out1.wait_to_read()
+ net.optimize_for(dummy_data, backend='MKLDNN')
+ out2 = net(dummy_data)
+
+ assert_almost_equal(out1, out2)