ptrendx commented on a change in pull request #19269:
URL: https://github.com/apache/incubator-mxnet/pull/19269#discussion_r506688373



##########
File path: src/imperative/pointwise_fusion_pass.cc
##########
@@ -57,281 +58,321 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  
subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = 
std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t 
inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& 
subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the 
same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const 
nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the 
subgraph
-    // to a dependencies between the subgraph node and the nodes out of the 
subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const 
nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = 
Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, 
&(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + 
"_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + 
their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + 
std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              
info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node 
names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const 
nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* 
subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const 
nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), 
info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_assignment[i] != -1 && subgraph_assignment[dep] == -1) {

Review comment:
       :+1:




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to