This is an automated email from the ASF dual-hosted git repository.

jevans pushed a commit to branch v1.9.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.9.x by this push:
     new 453ccb8  identity fuse (#20884)
453ccb8 is described below

commit 453ccb8e2ea3bf0c883591e24b9c73c3809988ff
Author: bgawrych <[email protected]>
AuthorDate: Tue Feb 15 18:22:24 2022 +0100

    identity fuse (#20884)
    
    rewrite test
    
    fix sanity
    
    remove clang warning
    
    Co-authored-by: Bartlomiej Gawrych <[email protected]>
---
 .../subgraph/mkldnn/mkldnn_identity_property.h     | 173 +++++++++++++++++++++
 .../subgraph/mkldnn/mkldnn_subgraph_base-inl.h     |   2 +-
 .../subgraph/mkldnn/mkldnn_subgraph_property.cc    |   6 +
 tests/python/mkl/test_subgraph.py                  |  28 ++++
 4 files changed, 208 insertions(+), 1 deletion(-)

diff --git a/src/operator/subgraph/mkldnn/mkldnn_identity_property.h 
b/src/operator/subgraph/mkldnn/mkldnn_identity_property.h
new file mode 100644
index 0000000..00f9499
--- /dev/null
+++ b/src/operator/subgraph/mkldnn/mkldnn_identity_property.h
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file mkldnn_identity_property.h
+ * \brief Graph property for removing identity operators
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
+#if MXNET_USE_MKLDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "../common.h"
+#include "../../nn/dropout-inl.h"
+#include "mkldnn_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Selector that matches an identity-like seed node plus one operator input node.
+ *
+ * A seed is accepted when its operator is `_copy`, or `Dropout` with mode
+ * dropout::kTraining. After that, the first non-variable node offered to SelectInput is
+ * accepted and every later input is rejected, so a complete match holds two nodes.
+ */
+class SgMKLDNNIdentitySelector : public SubgraphSelectorV2 {
+ private:
+  // Nodes accepted so far: the seed first, then at most one input node.
+  std::vector<const BiDirectedNode*> matched_list_;
+  // Becomes true once an input node has been accepted; blocks further input selection.
+  bool pattern_found_ = false;
+
+ public:
+  /*!
+   * \brief Decide whether seed_node can start a match.
+   * \return true for `_copy`, or for `Dropout` whose mode equals dropout::kTraining
+   */
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    const auto* seed_op = seed_node.node->op();
+    bool is_identity    = (seed_op == Op::Get("_copy"));
+
+    if (seed_op == Op::Get("Dropout")) {
+      const auto& param = nnvm::get<DropoutParam>(seed_node.node->attrs.parsed);
+      is_identity       = is_identity || (param.mode == dropout::kTraining);
+    }
+
+    if (!is_identity) {
+      return false;
+    }
+    // Start a fresh match rooted at this seed.
+    matched_list_.clear();
+    matched_list_.emplace_back(&seed_node);
+    return true;
+  }
+
+  /*! \brief Accept the first operator (non-variable) input offered; reject everything after. */
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    const bool accept =
+        !pattern_found_ && !input_node.node->is_variable() && input_node.node->op() != nullptr;
+    if (accept) {
+      matched_list_.emplace_back(&input_node);
+      pattern_found_ = true;
+    }
+    return accept;
+  }
+
+  /*! \brief Output nodes are never added to the match. */
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    return false;
+  }
+
+  /*!
+   * \brief Keep the candidates only when the full pattern was matched.
+   * \return candidates unchanged when an input was accepted and the candidate count equals
+   *         the number of matched nodes; an empty vector otherwise
+   */
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
+    const bool complete = pattern_found_ && candidates.size() == matched_list_.size();
+    if (!complete) {
+      return std::vector<BiDirectedNode*>(0);
+    }
+    // candidates should contain only two nodes - custom node and identity node
+    CHECK_EQ(candidates.size(), 2);
+    return candidates;
+  }
+
+  /*! \brief Restart the selector so that it only remembers the original seed node. */
+  void Reset() override {
+    CHECK_GE(matched_list_.size(), 1);
+    SgMKLDNNIdentitySelector restarted;
+    restarted.Select(*matched_list_[0], nullptr);
+    *this = restarted;
+  }
+};
+
+/*!
+ * \brief Check whether a node's operator is `_copy` or `Dropout`.
+ * \param node node to inspect; it is dereferenced, so it must not be null
+ * \return true when the node has an operator and that operator is `_copy` or `Dropout`
+ * \note The Dropout mode is not inspected here.
+ */
+inline bool IsIdentityNode(const nnvm::ObjectPtr& node) {
+  // Taken by const reference so that a call does not copy the shared pointer.
+  const auto* op = node->op();
+  return op != nullptr && (op == Op::Get("_copy") || op == Op::Get("Dropout"));
+}
+
+/*!
+ * \brief Subgraph property that replaces a matched (operator, identity) pair with a copy of
+ *        the operator node and re-routes the subgraph outputs to that copy.
+ *
+ * The property is created with the "inference_only" attribute set to true.
+ */
+class SgMKLDNNIdentityProperty : public SubgraphProperty {
+ public:
+  SgMKLDNNIdentityProperty() {}
+
+  /*! \brief Create the property and set its "property_name" and "inference_only" attributes. */
+  static SubgraphPropertyPtr Create() {
+    static const std::string name = "MKLDNN Identity optimization pass";
+    auto property                 = std::make_shared<SgMKLDNNIdentityProperty>();
+    property->SetAttr<std::string>("property_name", name);
+    property->SetAttr<bool>("inference_only", true);
+    return property;
+  }
+
+  /*!
+   * \brief Build the node that replaces the matched subgraph.
+   * \param sym symbol describing the matched subgraph
+   * \param subgraph_id unused by this property
+   * \return a new node whose attributes are copied from the non-identity node of the subgraph
+   */
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+    // Find the identity node among the subgraph outputs; the last match wins.
+    nnvm::ObjectPtr identity_node;
+    for (const auto& entry : sym.outputs) {
+      if (IsIdentityNode(entry.node)) {
+        identity_node = entry.node;
+      }
+    }
+    CHECK(identity_node != nullptr) << "Identity subgraph has no _copy/Dropout output node";
+
+    nnvm::Symbol new_sym;
+    new_sym.outputs.emplace_back(identity_node);
+
+    // Keep the last non-identity node that DFSVisit reports when started from the identity
+    // node. NOTE(review): this relies on the visiting order of DFSVisit - confirm.
+    nnvm::ObjectPtr org_node;
+    DFSVisit(new_sym.outputs, [&](const nnvm::ObjectPtr& node) {
+      if (!IsIdentityNode(node)) {
+        org_node = node;
+      }
+    });
+    CHECK(org_node != nullptr) << "Identity subgraph has no node preceding the identity node";
+
+    // Create copy of original node
+    nnvm::ObjectPtr n = nnvm::Node::Create();
+    n->attrs          = org_node->attrs;
+    if (n->op() && n->op()->attr_parser) {
+      // Re-run the parser so that the copied attrs are parsed for the new node.
+      n->op()->attr_parser(&(n->attrs));
+    }
+
+    return n;
+  }
+
+  /*!
+   * \brief Point every subgraph output entry at the replacement node.
+   * \param n node created by CreateSubgraphNode
+   * \param output_entries entries that currently reference nodes of the matched subgraph
+   */
+  void ConnectSubgraphOutputs(const nnvm::ObjectPtr n,
+                              std::vector<nnvm::NodeEntry*>* output_entries) const override {
+    // output of identity must be connected as output of operator before identity
+    // e.g. for:        /--index 0--> custom_op
+    //         (n) slice
+    //                  \--index 1--> Dropout --index 0--> OUT_NODE
+    //  for OUT_NODE index 0 must be changed to index 1
+    for (nnvm::NodeEntry* entry : *output_entries) {
+      if (IsIdentityNode(entry->node)) {
+        // Read the index from the identity's input before the node pointer is replaced.
+        entry->index = entry->node->inputs[0].index;
+      }
+      entry->node = n;
+    }
+  }
+
+  /*! \brief Create the selector used to match identity patterns. */
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+    return std::make_shared<SgMKLDNNIdentitySelector>();
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_MKLDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_IDENTITY_PROPERTY_H_
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h 
b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
index 6436852..9e0f9bc 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_base-inl.h
@@ -31,7 +31,7 @@ static inline bool SupportMKLDNNAttr(const 
std::shared_ptr<NodeAttr>& node_attr)
     return (node_attr->dispatch_mode == DispatchMode::kFComputeEx) &&
            (node_attr->itype[0] == mshadow::kFloat32 ||
             node_attr->itype[0] == mshadow::kBfloat16) &&
-           (ndim == 1 || ndim == 2 || ndim == 4 || ndim == 5);
+           (ndim >= 1 && ndim <= 5);
   } else {
     return true;
   }
diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc 
b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
index 9190ba4..e41bf7d 100644
--- a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
+++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc
@@ -21,6 +21,7 @@
 
 #include "mkldnn_conv_property.h"
 #include "mkldnn_fc_property.h"
+#include "mkldnn_identity_property.h"
 #include "mkldnn_post_quantize_property.h"
 #include "mkldnn_fc_post_quantize_property.h"
 #include "mkldnn_elemwisemul_post_quantize_property.h"
@@ -35,6 +36,8 @@ MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN)
 .set_attr("enable", MKLDNNEnvSet())
 .set_attr("context", Context::CPU());
 
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNIdentityProperty);
+
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty);
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty);
@@ -44,12 +47,15 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, 
SgMKLDNNTransformerProperty);
 MXNET_REGISTER_SUBGRAPH_BACKEND(MKLDNN_QUANTIZE)
 .set_attr("context", Context::CPU());
 
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNIdentityProperty);
+
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNConvProperty)
 .set_attr("quantize", true);
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNFCProperty)
 .set_attr("quantize", true);
 
+// NOTE(review): SgMKLDNNIdentityProperty is already registered for MKLDNN_QUANTIZE a few
+// lines above (right after the backend registration) - confirm this second registration is
+// intentional.
+MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNIdentityProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, SgMKLDNNTransformerProperty);
 
 MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_QUANTIZE, 
SgMKLDNNTransformerPostQuantizeProperty);
diff --git a/tests/python/mkl/test_subgraph.py 
b/tests/python/mkl/test_subgraph.py
index 811b006..c1b73fc 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -693,6 +693,28 @@ def fc_eltwise(no_bias, data_shape, flatten=True, 
alg='relu'):
 
   return sym, attr
 
+def fc_identity_eltwise(data_shape, identity_node):
+  """Build two FullyConnected -> identity -> relu stages and return (symbol, attrs).
+
+  When identity_node is 'copy' the identity stage is mx.symbol.identity; for any other
+  value it is mx.symbol.Dropout.
+  """
+  def add_identity(net):
+    if identity_node == 'copy':
+      return mx.symbol.identity(net)
+    return mx.symbol.Dropout(net)
+
+  data, fc1_weight = head_symbol(data_shape)
+  fc2_weight = mx.symbol.Variable('fc2_weight', dtype='float32')
+
+  net = data
+  for layer_name, layer_weight in (('fc1', fc1_weight), ('fc2', fc2_weight)):
+    net = mx.symbol.FullyConnected(name=layer_name, data=net, weight=layer_weight,
+                                   num_hidden=64, no_bias=True, flatten=True)
+    net = add_identity(net)
+    net = mx.symbol.Activation(net, act_type='relu')
+
+  expected_attrs = {'sg_mkldnn_fully_connected_eltwise_0' : {'with_eltwise': 'true'},
+                    'sg_mkldnn_fully_connected_eltwise_1' : {'with_eltwise': 'true'}}
+  return net, expected_attrs
+
+
 def single_selfatt_qk(data_shape, nheads=16):
   attr = {'selfatt_qk': {}}
   data = mx.symbol.Variable('data', shape=data_shape, dtype='float32')
@@ -877,6 +899,12 @@ def test_single_fc():
       check_fusion(syms, dshape, attrs, check_quantization=False)
 
 @with_seed()
+def test_fc_eltwise_identity():
+  # Run the fusion check for every data shape with both kinds of identity node.
+  identity_kinds = ['copy', 'dropout']
+  for shape in DATA_SHAPE:
+    for kind in identity_kinds:
+      net, expected_attrs = fc_identity_eltwise(shape, kind)
+      check_fusion(net, shape, expected_attrs, check_quantization=False)
+
+@with_seed()
 def test_fc_eltwise():
   for dshape, no_bias, flatten, alg in itertools.product(DATA_SHAPE,
                                                         [True, False],

Reply via email to