vandanavk commented on a change in pull request #14491: Enable slice embedding 
concat split fuse
URL: https://github.com/apache/incubator-mxnet/pull/14491#discussion_r282766923
 
 

 ##########
 File path: src/operator/subgraph/wide_and_deep_input_fuse_property.cc
 ##########
 @@ -0,0 +1,223 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one
+* or more contributor license agreements.  See the NOTICE file
+* distributed with this work for additional information
+* regarding copyright ownership.  The ASF licenses this file
+* to you under the Apache License, Version 2.0 (the
+* "License"); you may not use this file except in compliance
+* with the License.  You may obtain a copy of the License at
+*
+*   http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing,
+* software distributed under the License is distributed on an
+* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+* KIND, either express or implied.  See the License for the
+* specific language governing permissions and limitations
+* under the License.
+*/
+#include <sstream>
+#include "common.h"
+#include "subgraph_property.h"
+#include "../nn/fully_connected-inl.h"
+#include "../nn/activation-inl.h"
+#include "../nn/concat-inl.h"
+#include "../tensor/slice_split_embedding.h"
+#include "../tensor/indexing_op.h"
+#include "../tensor/matrix_op-inl.h"
+#include "../slice_channel-inl.h"
+namespace mxnet {
+namespace op {
+#define EMBEDDING_NODE_NAME "Embedding"  // op name compared against node->op()->name
+#define CONCAT_NODE_NAME "Concat"  // op name compared against n.op()->name in Select()
+/*!
+ * \brief Subgraph selector used by the Wide & Deep input fusion pass.
+ *
+ * Select() accepts only nodes whose operator is named CONCAT_NODE_NAME,
+ * SelectInput() accepts every non-variable node and SelectOutput() rejects
+ * everything. A selector constructed with a non-zero flag accepts nothing.
+ */
+class SgWideAndDeepInputFuseSelector : public SubgraphSelector {
+ public:
+  /*! \brief pattern match status */
+  enum SelectStatus {
+    kFail = 0,
+    kSuccess,
+  };
+
+  /*! \param dis_all non-zero switches the selector off completely */
+  explicit SgWideAndDeepInputFuseSelector(int dis_all)
+      : disable_all_(dis_all), status_(kFail) {}
+
+  /*! \brief Accept \p n only if it is a Concat operator; records the match. */
+  bool Select(const nnvm::Node &n) override {
+    const bool is_concat =
+        !disable_all_ && n.op() && n.op()->name == CONCAT_NODE_NAME;
+    if (is_concat) {
+      status_ = kSuccess;
+    }
+    return is_concat;
+  }
+
+  /*! \brief Accept any input node that is not a variable. */
+  bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return !disable_all_ && !new_node.is_variable();
+  }
+
+  /*! \brief Output nodes are never accepted. */
+  bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
+    return false;
+  }
+
+  /*!
+   * \brief Return \p candidates unchanged once Select() has accepted a
+   *        Concat node; otherwise return an empty list.
+   */
+  std::vector<nnvm::Node *> Filter(
+      const std::vector<nnvm::Node *> &candidates) override {
+    if (status_ == kSuccess) {
+      return candidates;
+    }
+    return {};
+  }
+
+ private:
+  bool disable_all_;
+  SelectStatus status_;
+};
+/*!
+ * \brief Format an indexable container as a bracketed, comma separated list,
+ *        e.g. "[1,2,3]".
+ * \tparam T container type providing size() and operator[]
+ * \param v container whose elements are streamed with operator<<
+ * \return the formatted string; "[]" when \p v is empty
+ */
+template <typename T>
+static std::string int_vector_to_attr(const T &v) {
+  std::stringstream ss;
+  ss << "[";
+  // Emit the separator before every element but the first. This avoids the
+  // unsigned wrap-around of `size() - 1` and the out-of-bounds read of v[0]
+  // that an empty container would otherwise trigger.
+  for (size_t i = 0; i < v.size(); ++i) {
+    if (i != 0) {
+      ss << ",";
+    }
+    ss << v[i];
+  }
+  ss << "]";
+  return ss.str();
+}
+/*!
+ * \brief Format a tuple of optional ints as a parenthesised, comma separated
+ *        list, e.g. "(1,2,3)".
+ * \param v tuple whose elements are streamed with operator<<
+ * \return the formatted string; "()" when \p v has no elements
+ */
+static std::string int_vector_to_tuple_attr(
+    const mxnet::Tuple<dmlc::optional<int>> &v) {
+  std::stringstream ss;
+  ss << "(";
+  // Take the element count once as a signed value so that a tuple with no
+  // elements simply skips the loop instead of indexing v[0] out of bounds.
+  const index_t num_elems = static_cast<index_t>(v.ndim());
+  for (index_t i = 0; i < num_elems; ++i) {
+    if (i != 0) {
+      ss << ",";
+    }
+    ss << v[i];
+  }
+  ss << ")";
+  return ss.str();
+}
+/*!
+ * \brief Look up \p key in an operator attribute dictionary.
+ * \param op_dict attribute name -> attribute value map
+ * \param key attribute name to search for
+ * \return the stored value, or an empty string when \p key is absent
+ */
+static std::string get_value_from_op_prop(
+    const std::unordered_map<std::string, std::string> &op_dict,
+    const std::string &key) {
+  // Both arguments are taken by const reference so a lookup no longer copies
+  // the whole dictionary.
+  const auto got = op_dict.find(key);
+  return got == op_dict.end() ? std::string() : got->second;
+}
+class SgWideAndDeepInputFuseProperty : public SubgraphProperty {
+ public:
+  SgWideAndDeepInputFuseProperty() {
+    disable_all = dmlc::GetEnv("MXNET_DISABLE_WIDE_DEEP_OPT", 0);
+    if (disable_all) {
+      LOG(INFO) << "Wide And Deep Input Fuse is disabled.";
+    } else {
+      LOG(INFO) << "Start to execute Wide And Deep Input Fuse optimization 
pass.";
+    }
+  }
+  static SubgraphPropertyPtr Create() {
+    return std::make_shared<SgWideAndDeepInputFuseProperty>();
+  }
+  nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
+                                  const int subgraph_id = 0) const override {
+    std::vector<nnvm::NodePtr> emb_nodes;
+    std::vector<nnvm::NodePtr> slice_nodes;
+    nnvm::NodePtr split_nodes;
+    nnvm::NodePtr concat_node = nullptr;
+    DFSVisit(sym.outputs, [&](const nnvm::NodePtr &node) {
+      if (node->is_variable()) return;
+      auto &op_name = node->op()->name;
+
+      // The Assumption is only base on W&D which all
+      // embedding occur at the beginning and output to 1 concat node
+      if (op_name == EMBEDDING_NODE_NAME) {
+          emb_nodes.push_back(node);
 
 Review comment:
   Can `push_back` be replaced with `emplace_back` wherever applicable?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to