This is an automated email from the ASF dual-hosted git repository.

bgawrych pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new b322bee0e7 [FEATURE] Add property removing duplicate Cast operations (#21020)
b322bee0e7 is described below

commit b322bee0e7151366fee206f0bd8ce1cd8d75d4bf
Author: bartekkuncer <[email protected]>
AuthorDate: Mon Jun 20 13:06:55 2022 +0200

    [FEATURE] Add property removing duplicate Cast operations (#21020)
    
    * [FEATURE] Add property removing duplicate Cast operations to 'ONEDNN' subgraph backend
    
    * Fix CI
    
    * Fix sanity
    
    * Add checks for the amount of inputs and outputs
    
    * Fix review
    
    * Fix review
---
 .../subgraph/dnnl/dnnl_remove_casts_property.h     | 146 +++++++++++++++++++++
 .../subgraph/dnnl/dnnl_subgraph_property.cc        |   2 +
 2 files changed, 148 insertions(+)

diff --git a/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
new file mode 100644
index 0000000000..d796718949
--- /dev/null
+++ b/src/operator/subgraph/dnnl/dnnl_remove_casts_property.h
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_remove_casts_property.h
+ * \brief Graph property for removing two unnecessary Cast operations
+ *
+ * ... -> Cast(dtype) -> expand_dims -> Cast(dtype) -> Cast(dtype) -> ...
+ *                                  ||
+ *                                  \/
+ *                ... -> Cast(dtype) -> expand_dims -> ...
+ */
+
+#ifndef MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_REMOVE_CASTS_PROPERTY_H_
+#define MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_REMOVE_CASTS_PROPERTY_H_
+
+#if MXNET_USE_ONEDNN == 1
+
+#include <map>
+#include <string>
+#include <vector>
+
+#include "operator/subgraph/common.h"
+#include "dnnl_subgraph_base-inl.h"
+
+namespace mxnet {
+namespace op {
+
+/*!
+ * \brief Matches `Cast(dtype) -> expand_dims -> Cast(dtype) -> Cast(dtype)`.
+ *
+ * expand_dims is the seed. The Cast feeding it only supplies the reference dtype and is
+ * left outside the subgraph. The subgraph is {expand_dims, Cast, Cast}.
+ *
+ * Consumer counts are taken from BiDirectedNode::outputs (the set of consumer nodes).
+ * nnvm::Node::num_outputs() is not used for this: it reports how many output slots the
+ * operator has, which is always 1 for expand_dims and Cast. If any node before the last
+ * Cast had an extra consumer, its output would leave the subgraph. The single-output
+ * expand_dims that replaces the subgraph cannot provide such an output.
+ */
+class SgDNNLRemoveCastsSelector : public SubgraphSelectorV2 {
+ private:
+  enum CastStatus { kExpand, kCast, kSuccess, kFail };
+  CastStatus status_ = kFail;
+  // dtype of the Cast feeding expand_dims. Both trailing Casts must use the same dtype.
+  int cast_dtype_ = -1;
+
+ public:
+  // Seeds on an expand_dims with one input whose output goes to exactly one consumer.
+  bool Select(const BiDirectedNode& seed_node,
+              const std::shared_ptr<NodeAttr>& node_attr) override {
+    if (seed_node.node->op() == Op::Get("expand_dims") && seed_node.node->num_inputs() == 1 &&
+        seed_node.outputs.size() == 1) {
+      status_ = kExpand;
+      return true;
+    }
+    return false;
+  }
+
+  // Records the dtype of the Cast feeding expand_dims; any other producer aborts the match.
+  // The producer itself is never pulled into the subgraph.
+  bool SelectInput(const BiDirectedNode& n, const BiDirectedNode& input_node) override {
+    if (input_node.node->op() == Op::Get("Cast")) {
+      cast_dtype_ = nnvm::get<CastParam>(input_node.node->attrs.parsed).dtype;
+    } else {
+      status_ = kFail;
+    }
+    return false;
+  }
+
+  // Walks expand_dims -> Cast -> Cast. Both Casts must use cast_dtype_.
+  bool SelectOutput(const BiDirectedNode& n, const BiDirectedNode& output_node) override {
+    if (status_ == kFail || status_ == kSuccess || output_node.node->is_variable()) {
+      return false;
+    }
+    if (output_node.node->op() == Op::Get("Cast")) {
+      auto const& cast_param = nnvm::get<CastParam>(output_node.node->attrs.parsed);
+      if (cast_param.dtype == cast_dtype_) {
+        // First Cast: it is internal to the subgraph, so it must feed only the second Cast.
+        if (status_ == kExpand && output_node.outputs.size() == 1) {
+          status_ = kCast;
+          return true;
+        }
+        // Second Cast: it must hang off the first Cast, not be a sibling of it.
+        if (status_ == kCast && n.node->op() == Op::Get("Cast")) {
+          status_ = kSuccess;
+          return true;
+        }
+      }
+    }
+    status_ = kFail;
+    return false;
+  }
+
+  // Keeps the candidates only when the full three-node chain was matched.
+  std::vector<BiDirectedNode*> Filter(const std::vector<BiDirectedNode*>& candidates) override {
+    return status_ == kSuccess ? candidates : std::vector<BiDirectedNode*>();
+  }
+
+  void Reset() override {
+    status_     = kFail;
+    cast_dtype_ = -1;
+  }
+};
+
+/*!
+ * \brief Subgraph property that collapses the subgraph matched by SgDNNLRemoveCastsSelector
+ *        into a single expand_dims node.
+ */
+class SgDNNLRemoveCastsProperty : public SubgraphProperty {
+ public:
+  SgDNNLRemoveCastsProperty() {}
+
+  // Builds the property instance. It is marked inference-only and is disabled when
+  // MXNET_DISABLE_REMOVE_CASTS_PROPERTY is set to a non-zero value.
+  static SubgraphPropertyPtr Create() {
+    static const std::string name = "Remove casts optimization pass";
+    auto prop = std::make_shared<SgDNNLRemoveCastsProperty>();
+    prop->SetAttr<std::string>("property_name", name);
+    prop->SetAttr<bool>("inference_only", true);
+    const int disable_flag = dmlc::GetEnv("MXNET_DISABLE_REMOVE_CASTS_PROPERTY", 0);
+    if (disable_flag != 0) {
+      prop->SetAttr<bool>("disable", true);
+    }
+    return prop;
+  }
+
+  // Creates a fresh expand_dims node in place of the matched subgraph. The name and
+  // "axis" attribute are copied from the expand_dims found inside `sym`.
+  nnvm::ObjectPtr CreateSubgraphNode(const nnvm::Symbol& sym,
+                                     const int subgraph_id = 0) const override {
+    nnvm::ObjectPtr expand = nnvm::Node::Create();
+    expand->attrs.op       = Op::Get("expand_dims");
+    DFSVisit(sym.outputs, [&expand](const nnvm::ObjectPtr& visited) {
+      if (visited->attrs.op != Op::Get("expand_dims")) {
+        return;
+      }
+      expand->attrs.name         = visited->attrs.name;
+      expand->attrs.dict["axis"] = visited->attrs.dict["axis"];
+    });
+
+    // Re-parse so that attrs.parsed reflects the copied "axis" value.
+    if (expand->op()->attr_parser != nullptr) {
+      expand->op()->attr_parser(&(expand->attrs));
+    }
+    return expand;
+  }
+
+  SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
+    return std::make_shared<SgDNNLRemoveCastsSelector>();
+  }
+};
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // if MXNET_USE_ONEDNN == 1
+#endif  // MXNET_OPERATOR_SUBGRAPH_DNNL_DNNL_REMOVE_CASTS_PROPERTY_H_
diff --git a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
index 2930fffa67..a7a290f93f 100644
--- a/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
+++ b/src/operator/subgraph/dnnl/dnnl_subgraph_property.cc
@@ -30,6 +30,7 @@
 #include "dnnl_transformer_qk_property.h"
 #include "dnnl_transformer_valatt_property.h"
 #include "dnnl_fc_sum_fuse.h"
+#include "dnnl_remove_casts_property.h"
 
 namespace mxnet {
 namespace op {
@@ -42,6 +43,7 @@ MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLIdentityProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLConvProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLFCProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBNReLUProperty);
+// Registers the pass that drops redundant same-dtype Casts following expand_dims
+// (pattern defined in dnnl_remove_casts_property.h) for the ONEDNN backend.
+// NOTE(review): its position relative to the other ONEDNN properties may matter if
+// passes run in registration order — confirm before reordering.
+MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLRemoveCastsProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerQKProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLTransformerValAttProperty);
 MXNET_REGISTER_SUBGRAPH_PROPERTY(ONEDNN, SgDNNLBatchDotProperty);

Reply via email to