This is an automated email from the ASF dual-hosted git repository.

ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new d2600f2  [1.x] Backport Faster pointwise fusion graph pass (#19269) (#19413)
d2600f2 is described below

commit d2600f21c2e9bb377bb3f77264a378d2b84236a7
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Tue Nov 17 15:26:17 2020 -0800

    [1.x] Backport Faster pointwise fusion graph pass (#19269) (#19413)
    
    * Faster pointwise fusion graph pass (#19269)
    
    * Faster pointwise fusion graph pass
    
    * Fix lint
    
    * Fix lint 2
    
    * Fixes
    
    * Fixing slice parameter handling in fusion
    
    * Fixing the slice fix
    
    * Fix the cycle bug
    
    * Added test
    
    * Fix lint
    
    * Fix merging of subgraphs
    
    * Fixes from review
    
    * Use std::tie instead of C++17 structured binding
    
    * More fixes for lack of c++17
    
    * Fix
---
 src/executor/exec_pass.h              |  16 +-
 src/executor/graph_executor.cc        |   5 +-
 src/executor/pointwise_fusion_pass.cc | 519 +++++++++++++++--------------
 src/executor/simple_partition_pass.cc | 265 +++++++++++++++
 src/executor/simple_partition_pass.h  | 599 +++++++++++-----------------------
 src/imperative/cached_op.h            |   5 +-
 src/operator/fusion/fused_op.cu       |  51 ++-
 tests/python/gpu/test_fusion.py       |  21 ++
 8 files changed, 809 insertions(+), 672 deletions(-)

diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 4552fa1..07ccc02 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -222,22 +222,14 @@ Graph DetectInplaceAddTo(Graph g);
 Graph EliminateCommonExpr(Graph && g);
 
 /*!
- * \brief Fuse pointwise operations in the forward pass.
+ * \brief Fuse pointwise operations in the graph.
  *
  * \param g input graph (needs to be entire graph, not just forward part)
+ * \param num_forward_outputs number of outputs in the graph produced by the forward pass
  *
- * \return graph with fused pointwise operations in the forward pass
+ * \return copy of the graph with fused pointwise operations
  */
-Graph FusePointwiseForward(Graph&& g);
-
-/*!
- * \brief Fuse pointwise operations in the backward pass.
- *
- * \param g input graph (needs to be entire graph, not just forward part)
- *
- * \return graph with fused pointwise operations in the backward pass
- */
-Graph FusePointwiseBackward(Graph&& g);
+Graph FusePointwise(const Graph& g, const size_t num_forward_outputs);
 
 /*!
  * \brief Issue a one-time warning that fusion is not possible for this platform or build.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 28b79ae..3f2c7c9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1009,10 +1009,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
     common::CopyGraph(&unoptimized_graph, g, false);
 
     if 
(common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
-      g.attrs["num_forward_outputs"] = 
std::make_shared<nnvm::any>(num_forward_outputs_);
-      g = FusePointwiseForward(std::move(g));
-      g.attrs["num_forward_outputs"] = 
std::make_shared<nnvm::any>(num_forward_outputs_);
-      g = FusePointwiseBackward(std::move(g));
+      g = exec::FusePointwise(std::move(g), num_forward_outputs_);
       // Check the topological order of inputs
       const auto &original_inputs = 
unoptimized_graph.indexed_graph().input_nodes();
       const auto &new_inputs = g.indexed_graph().input_nodes();
diff --git a/src/executor/pointwise_fusion_pass.cc b/src/executor/pointwise_fusion_pass.cc
index 3203f67..961dade 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -31,6 +31,7 @@
 #include <nnvm/pass_functions.h>
 #include <algorithm>
 #include <queue>
+#include <chrono>
 #include "./simple_partition_pass.h"
 #include "../operator/fusion/fused_op-inl.h"
 #include "../operator/fusion/fused_op.h"
@@ -57,281 +58,323 @@ void WarnFusionNotSupported() {
 #if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
 
 namespace {
-  bool IsFusionCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (ops_desc.count(op_name))
-      return true;
-    if (slice_ops.count(op_name))
-      return false;
-    if (std::find(variable_io_ops.begin(),
-                  variable_io_ops.end(),
-                  op_name) !=
-        variable_io_ops.end())
-      return true;
-    if (op_name == "LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
-    if (op_name == "_backward_LeakyReLU") {
-        std::string act_type = n->attrs.dict.at("act_type");
-        if (LeakyReLU_bwd_ops.count(act_type))
-          return true;
-        else
-          return false;
-    }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (ops_desc.count(op_name))
+    return true;
+  if (slice_ops.count(op_name))
     return false;
+  if (std::find(variable_io_ops.begin(),
+                variable_io_ops.end(),
+                op_name) !=
+      variable_io_ops.end())
+    return true;
+  if (op_name == "LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_ops.count(act_type))
+        return true;
+      else
+        return false;
   }
+  if (op_name == "_backward_LeakyReLU") {
+      std::string act_type = n->attrs.dict.at("act_type");
+      if (LeakyReLU_bwd_ops.count(act_type))
+        return true;
+      else
+        return false;
+  }
+  return false;
+}
 
-  bool IsInputsOnlyCompatible(nnvm::Node* n) {
-    using namespace mxnet::fusion;
-    if (n->op() == nullptr)
-      return false;
-    std::string op_name = n->op()->name;
-    if (slice_ops.count(op_name)) {
-      if (op_name == "slice") {
-        // slice with non-default step attribute is not supported
-        // currently
-        if (n->attrs.dict.count("step") &&
-            !(n->attrs.dict.at("step") == "()" ||
-              n->attrs.dict.at("step") == "[]")) {
-          return false;
-        }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+  using namespace mxnet::fusion;
+  if (n->op() == nullptr)
+    return false;
+  const std::string& op_name = n->op()->name;
+  if (slice_ops.count(op_name)) {
+    if (op_name == "slice") {
+      // slice with non-default step attribute is not supported
+      // currently
+      if (n->attrs.dict.count("step") &&
+          !(n->attrs.dict.at("step") == "()" ||
+            n->attrs.dict.at("step") == "[]")) {
+        return false;
       }
-      return true;
     }
-    return false;
+    return true;
   }
+  return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+                        size_t inputs_size,
+                        nnvm::Node* subgraph_node) {
+  static const Op* fused_op_ptr = Op::Get("_FusedOp");
+  
subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+  subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+  subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+  subgraph_node->attrs.dict["num_outputs"] = 
std::to_string(subgraph.outputs.size());
+  subgraph_node->attrs.op = fused_op_ptr;
+  subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+  int source_node;
+  int index;
+};
 
-  nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t 
inputs_size) {
-    nnvm::Symbol subgraph_sym;
-    auto node = nnvm::Node::Create();
-    subgraph_sym.outputs = subgraph.outputs;
-    
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
-    node->attrs.name = "FusedOp";
-    node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
-    node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
-    node->attrs.op = Op::Get("_FusedOp");
-    node->op()->attr_parser(&(node->attrs));
-    return node;
+inline int SetInsert(const EntryInfo& new_elem,
+                     std::vector<EntryInfo>* elements) {
+  for (size_t i = 0; i < elements->size(); ++i) {
+    if ((new_elem.source_node == elements->at(i).source_node) &&
+        (new_elem.index == elements->at(i).index)) {
+      return i;
+    }
   }
+  elements->emplace_back(new_elem);
+  return elements->size() - 1;
+}
+
 }  // namespace
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- *        This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) copy of the graph, replacing subgraphs with
+ *        FusedOps. If there are no subgraphs to be replaced, the
+ *        original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraphs to
+ *                            subgraphs. Values from -1 to num_subgraphs - 1
+ *                            are allowed, -1 means that the node is not in a
+ *                            subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
  */
 template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>& 
subgraph_sets,
-                                FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
-    }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the 
same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const 
nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+                              const std::vector<int>& subgraph_assignment,
+                              const int num_subgraphs,
+                              FCreateNode create_subgraph_node) {
+  if (num_subgraphs == 0) {
+    return g;
+  }
+
+  Graph ret;
+
+  const auto& idx = g.indexed_graph();
+
+  CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+    "Every node in the graph needs to be included in subgraph assignment.";
+
+  std::vector<nnvm::ObjectPtr> new_nodes;
+  new_nodes.reserve(idx.num_nodes());
+  struct SubgraphInfo {
+    nnvm::Graph graph;
+    nnvm::ObjectPtr subgraph_node;
+    std::vector<EntryInfo> outputs;
+    std::vector<EntryInfo> inputs;
+    std::vector<nnvm::ObjectPtr> input_nodes;
+  };
+
+  std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+  for (auto& info : subgraphs) {
+    info.subgraph_node = nnvm::Node::Create();
+  }
+
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // First copy the node, it will be used
+    // either in the new graph or inside a
+    // subgraph. Variables are not copied.
+    if (idx[i].source->op() != nullptr) {
+      new_nodes.emplace_back(nnvm::Node::Create());
+      auto& node_copy = new_nodes.back();
+      node_copy->attrs = idx[i].source->attrs;
+      node_copy->info = idx[i].source->info;
+    } else {
+      new_nodes.emplace_back(idx[i].weak_ref.lock());
+      continue;
     }
-    // move control dependencies between nodes of the subgraph and out of the 
subgraph
-    // to a dependencies between the subgraph node and the nodes out of the 
subgraph
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const 
nnvm::ObjectPtr& node) {
-      if (subgraph_set.count(node.get())) {
-        auto it = node->control_deps.begin();
-        static auto& is_fusion = 
Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
-        std::vector<nnvm::ObjectPtr> new_control_deps;
-        // Use the first control dependency to get the inferattr helper
-        if (it != node->control_deps.end()) {
-          if (subgraph_set.count(it->get())) {
-            new_control_deps.push_back(*it);
+    auto& node_copy = new_nodes.back();
+    const int subgraph_id = subgraph_assignment[i];
+    if (subgraph_id != -1) {
+      auto& info = subgraphs[subgraph_id];
+      for (const auto& input : idx[i].inputs) {
+        const int their_subgraph = subgraph_assignment[input.node_id];
+        if (their_subgraph == subgraph_id) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          int input_num;
+          int output_num;
+          if (their_subgraph == -1) {
+            input_num = SetInsert({static_cast<int>(input.node_id),
+                                   static_cast<int>(input.index)}, 
&(info.inputs));
           } else {
-            if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
-              uint32_t node_id = subgraph_node->control_deps.size();
-              subgraph_node->control_deps.push_back(*it);
-              auto helper_node = op::MakeNode("_FusedOpOutHelper",
-                                              "FusedOp_" + node->attrs.name + 
"_outhelper",
-                                              nullptr,
-                                              nullptr,
-                                              nullptr);
-              helper_node->attrs.parsed =
-                FusedOpHelperParamPtr(new FusedOpHelperParam(
-                      nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                      node_id));
-              new_control_deps.push_back(helper_node);
+            auto& their_subgraph_info = subgraphs[their_subgraph];
+            output_num = SetInsert({static_cast<int>(input.node_id),
+                                    static_cast<int>(input.index)},
+                                   &(their_subgraph_info.outputs));
+            input_num = SetInsert({static_cast<int>(idx.num_nodes() + 
their_subgraph),
+                                   output_num},
+                                  &(info.inputs));
+          }
+          if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+            info.input_nodes.emplace_back(nnvm::Node::Create());
+            info.input_nodes.back()->attrs.name = "input_" + 
std::to_string(input_num);
+            if (their_subgraph == -1) {
+              info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+                                                      input.index,
+                                                      input.version);
             } else {
-              new_control_deps.push_back(*it);
+              
info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+                                                      output_num,
+                                                      input.version);
             }
           }
-          ++it;
+          node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
         }
-        node->control_deps = new_control_deps;
       }
-    });
-
-    std::ostringstream name_oss;
-    // the name of the new node will be the concatenation of all the node 
names in the subgraph
-    DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
-      if (n->op() != nullptr) {
-        name_oss << n->op()->name << "_";
+    } else {
+      for (const auto& input : idx[i].inputs) {
+        const int subgraph_id = subgraph_assignment[input.node_id];
+        if (subgraph_id == -1) {
+          node_copy->inputs.emplace_back(new_nodes[input.node_id],
+                                         input.index,
+                                         input.version);
+        } else {
+          auto& info = subgraphs[subgraph_id];
+          const int output_num = SetInsert({static_cast<int>(input.node_id),
+                                            static_cast<int>(input.index)},
+                                           &(info.outputs));
+          node_copy->inputs.emplace_back(info.subgraph_node,
+                                         output_num,
+                                         input.version);
+        }
       }
-    });
-    auto subgraph_name = name_oss.str();
-    subgraph_name.pop_back();
-    subgraph_node->attrs.name = subgraph_name;
+    }
 
-    const auto& index = subgraph.indexed_graph();
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const 
nnvm::ObjectPtr& node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get())) {
-          uint32_t node_id = index.node_id(e.get());
-          auto helper_node = op::MakeNode("_FusedOpHelper",
-                                          subgraph_node->attrs.name + "_"
-                                          + node->attrs.name + "_helper",
-                                          nullptr,
-                                          nullptr,
-                                          nullptr);
-          helper_node->attrs.parsed =
-            FusedOpHelperParamPtr(new FusedOpHelperParam(
-                  nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
-                  node_id));
-          e = helper_node;
-        }
+    // Control deps
+    for (const auto& dep : idx[i].control_deps) {
+      if (subgraph_id == subgraph_assignment[dep]) {
+        node_copy->control_deps.emplace_back(new_nodes[dep]);
       }
-    });
+    }
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- *        which are only compatible when they are the first nodes in the
- *        subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
-                             std::vector<std::unordered_set<nnvm::Node*> >* 
subsets,
-                             IsCompatible is_compatible) {
-  std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
-  size_t subgraphs_fullsize = 0;
-  for (auto& s : *subsets) {
-    subgraphs_fullsize += s.size();
-  }
-  node2setidx.reserve(subgraphs_fullsize);
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    for (auto& n : (*subsets)[i]) {
-      node2setidx.insert({n, i});
+  ret.outputs.reserve(idx.outputs().size());
+  for (const auto& output : idx.outputs()) {
+    const int subgraph_id = subgraph_assignment[output.node_id];
+    if (subgraph_id == -1) {
+      ret.outputs.emplace_back(new_nodes[output.node_id],
+                               output.index,
+                               output.version);
+    } else {
+      const int output_num = SetInsert({static_cast<int>(output.node_id),
+                                        static_cast<int>(output.index)},
+                                       &(subgraphs[subgraph_id].outputs));
+      ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+                               output_num,
+                               output.version);
     }
   }
-  std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
-  DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const 
nnvm::ObjectPtr& n) {
-    const auto& it = node2setidx.find(n.get());
-    if (it != node2setidx.end()) {
-      for (auto& e : n->inputs) {
-        if (is_compatible(e.node.get()))
-          to_add[it->second].push_back(e.node.get());
-      }
+
+  for (auto& info : subgraphs) {
+    info.graph.outputs.reserve(info.outputs.size());
+    for (const auto& entry_info : info.outputs) {
+      info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+                                      entry_info.index,
+                                      0);
     }
-  });
+    create_subgraph_node(info.graph, info.inputs.size(), 
info.subgraph_node.get());
+  }
 
-  // Avoid duplicating the node that is input of two subsets
-  std::unordered_set<nnvm::Node*> added;
-  for (size_t i = 0; i < subsets->size(); ++i) {
-    std::vector<nnvm::NodeEntry> heads;
-    for (auto n : subsets->at(i)) {
-      for (auto e : n->inputs) {
-        if (!subsets->at(i).count(e.node.get()))
-          heads.push_back(e);
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    // Add _FusedOpHelper nodes
+    const int subgraph_id = subgraph_assignment[i];
+    for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+      const auto& dep = idx[i].control_deps[dep_num];
+      const int their_subgraph_id = subgraph_assignment[dep];
+      if (subgraph_id != -1 && their_subgraph_id == -1) {
+        // Not in any subgraph, use FusedOpOutHelper
+        auto& info = subgraphs[subgraph_id];
+        size_t node_id = info.subgraph_node->control_deps.size();
+        info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+        auto helper_node = op::MakeNode("_FusedOpOutHelper",
+                                        "FusedOp_" + new_nodes[i]->attrs.name 
+ "_outhelper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + 
dep_num,
+                                          std::move(helper_node));
+      } else if (their_subgraph_id != subgraph_id &&
+                 their_subgraph_id != -1) {
+        auto& info = subgraphs[their_subgraph_id];
+        const auto& subgraph_idx = info.graph.indexed_graph();
+        uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get());
+        auto helper_node = op::MakeNode("_FusedOpHelper",
+                                        info.subgraph_node->attrs.name + "_"
+                                        + idx[i].source->attrs.name + 
"_helper",
+                                        nullptr,
+                                        nullptr,
+                                        nullptr);
+        helper_node->attrs.parsed =
+          FusedOpHelperParamPtr(new FusedOpHelperParam(
+                nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+                node_id));
+        new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() + 
dep_num,
+                                          std::move(helper_node));
       }
     }
-    for (size_t j = 0; j < to_add[i].size(); ++j) {
-      if (!added.count(to_add[i][j])) {
-        bool make_cycle = false;
-        const auto& node = to_add[i][j];
-        std::vector<nnvm::NodeEntry> _heads;
-        std::copy_if(heads.begin(), heads.end(), std::back_inserter(_heads),
-                     [&node](const nnvm::NodeEntry& n) {
-                       return n.node.get() != node;
-                     });
-        DFSVisit(_heads, [&make_cycle, &node](const nnvm::ObjectPtr& n) {
-          if (n.get() == node)
-            make_cycle = true;
-        });
-        if (!make_cycle) {
-          (*subsets)[i].insert(to_add[i][j]);
-          added.insert(to_add[i][j]);
+  }
+  for (auto& info : subgraphs) {
+    const auto& idx = info.graph.indexed_graph();
+    const auto& input_nodes = idx.input_nodes();
+    std::vector<nnvm::NodeEntry> subgraph_inputs;
+    subgraph_inputs.reserve(info.subgraph_node->inputs.size());
+    for (const int input : input_nodes) {
+      for (size_t i = 0; i < info.input_nodes.size(); ++i) {
+        const auto& input_ptr = info.input_nodes[i].get();
+        if (input_ptr == idx[input].source) {
+          subgraph_inputs.emplace_back(info.subgraph_node->inputs[i]);
         }
       }
     }
+    info.subgraph_node->inputs.swap(subgraph_inputs);
+    std::string name;
+    for (size_t i = 0; i < idx.num_nodes(); ++i) {
+      if (idx[i].source->op() != nullptr) {
+        name += idx[i].source->op()->name + "_";
+      }
+    }
+    info.subgraph_node->attrs.name = name;
   }
-}
-
-Graph FusePointwiseForward(Graph &&g) {
-  Graph ret;
-  g.indexed_graph();
-  const auto& num_forward_outputs = g.GetAttr<size_t>("num_forward_outputs");
-  Graph fg;
-  fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(),
-                    g.outputs.begin() + num_forward_outputs);
-  auto subsets = GetCompatibleSubsets(fg, IsFusionCompatible);
-  AddInputsOnlyCompatible(fg, &subsets, IsInputsOnlyCompatible);
-  g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode);
-  ret.outputs = g.outputs;
   return ret;
 }
 
-Graph FusePointwiseBackward(Graph &&g) {
-  Graph ret;
-  g.indexed_graph();
-  const auto& num_forward_outputs = g.GetAttr<size_t>("num_forward_outputs");
-  Graph fg;
-  fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(),
-                    g.outputs.begin() + num_forward_outputs);
-  std::unordered_set<nnvm::Node*> exclusion_set;
-  DFSVisit(fg.outputs, [&exclusion_set](const nnvm::ObjectPtr& n) {
-    exclusion_set.insert(n.get());
-  });
-  auto subsets = GetCompatibleSubsets(g, [&exclusion_set](nnvm::Node* n) {
-    if (exclusion_set.count(n))
-      return false;
-    return IsFusionCompatible(n);
-  });
-  g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode);
-  ret.outputs = g.outputs;
+Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) {
+  auto start = std::chrono::steady_clock::now();
+  std::vector<int> subset_assignment;
+  int num_subsets;
+  std::tie(subset_assignment, num_subsets) = GetCompatibleSubsets(g, 
num_forward_outputs,  // NOLINT(*)
+                                                                  
IsFusionCompatible,
+                                                                  
IsInputsOnlyCompatible);
+  Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets,
+                                      CreateSubgraphNode);
+  auto end = std::chrono::steady_clock::now();
+  if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
+    auto diff = end - start;
+    LOG(INFO) << "Pointwise fusion graph pass took: "
+              << std::chrono::duration<double, std::milli>(diff).count()
+              << "ms.";
+  }
   return ret;
 }
 #endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
diff --git a/src/executor/simple_partition_pass.cc b/src/executor/simple_partition_pass.cc
new file mode 100644
index 0000000..941959d
--- /dev/null
+++ b/src/executor/simple_partition_pass.cc
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+                             const IntervalVec* const second) noexcept {
+  const IntervalVec* ret = nullptr;
+  auto first_iter = first->begin();
+  auto second_iter = second->begin();
+  while (first_iter != first->end() &&
+         second_iter != second->end()) {
+    if (*first_iter == *second_iter) {
+      ++first_iter;
+      ++second_iter;
+    } else {
+      // Entry in first set not seen in the second set
+      if (first_iter->second < second_iter->first) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set not seen in the first set
+      if (second_iter->second < first_iter->first) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in first set fully encloses the entry in the second set
+      if (first_iter->first <= second_iter->first &&
+          first_iter->second >= second_iter->second) {
+        if (ret == first || ret == nullptr) {
+          ret = first;
+          ++second_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entry in second set fully encloses the entry in the first set
+      if (second_iter->first <= first_iter->first &&
+          second_iter->second >= first_iter->second) {
+        if (ret == second || ret == nullptr) {
+          ret = second;
+          ++first_iter;
+        } else {
+          return nullptr;
+        }
+        continue;
+      }
+      // Entries intersect but one is not fully enclosed in the other
+      return nullptr;
+    }
+  }
+  if (ret == nullptr) {
+    // The common part is the same
+    return second_iter == second->end() ? first : second;
+  } else {
+    if ((ret == first && second_iter == second->end()) ||
+        (ret == second && first_iter == first->end())) {
+      return ret;
+    }
+  }
+  return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) 
noexcept {
+  if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+    *my_set = other_set;
+    return;
+  }
+  if (other_set == nullptr || other_set->size() == 0) {
+    return;
+  }
+  auto* larger_set = LargerSet(*my_set, other_set);
+  if (larger_set != nullptr) {
+    *my_set = larger_set;
+    return;
+  }
+  auto my_iter = (*my_set)->cbegin();
+  auto other_iter = other_set->cbegin();
+  auto new_set = IntervalVec();
+  int last_end = -10;  // less than -1
+  while (my_iter != (*my_set)->cend() &&
+         other_iter != other_set->cend()) {
+    const auto& mine = *my_iter;
+    const auto& other = *other_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours
+      if (last_end >= other.first - 1) {
+        new_set.back().second = other.second;
+      } else {
+        new_set.emplace_back(other);
+      }
+      last_end = other.second;
+      ++other_iter;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      if (last_end >= mine.first - 1) {
+        new_set.back().second = mine.second;
+      } else {
+        new_set.emplace_back(mine);
+      }
+      last_end = mine.second;
+      ++my_iter;
+    } else {
+      // Intervals can be merged together
+      Interval n(std::min(mine.first, other.first),
+                 std::max(mine.second, other.second));
+      if (last_end >= n.first - 1) {
+        new_set.back().second = n.second;
+      } else {
+        new_set.emplace_back(n);
+      }
+      last_end = n.second;
+      if (other.second >= mine.second) {
+        ++my_iter;
+      }
+      if (mine.second >= other.second) {
+        ++other_iter;
+      }
+    }
+  }
+  auto remaining_iter = my_iter == (*my_set)->cend() ? other_iter : my_iter;
+  auto remaining_end = my_iter == (*my_set)->cend() ? other_set->cend() : 
(*my_set)->cend();
+  // Add the rest of entries
+  for (; remaining_iter != remaining_end; ++remaining_iter) {
+    auto& mine = new_set.back();
+    const auto& other = *remaining_iter;
+    if (other.second < mine.first - 1) {
+      // other interval is before ours, should never happen
+      continue;
+    } else if (other.first > mine.second + 1) {
+      // other interval is after ours
+      new_set.emplace_back(other);
+    } else {
+      // Intervals can be merged together
+      mine.first = std::min(mine.first, other.first);
+      mine.second = std::max(mine.second, other.second);
+    }
+  }
+  storage->emplace_back(std::make_unique<IntervalVec>(std::move(new_set)));
+  *my_set = storage->back().get();
+}
+
+bool Intersect(const IntervalVec& checked_sets,
+               const IntervalVec& excluded_sets) noexcept {
+  size_t current_interval = 0, current_other_interval = 0;
+  while (current_interval < checked_sets.size() &&
+         current_other_interval < excluded_sets.size()) {
+    const auto& mine = checked_sets[current_interval];
+    const auto& other = excluded_sets[current_other_interval];
+    if (other.second < mine.first) {
+      // other interval is before ours
+      ++current_other_interval;
+    } else if (other.first > mine.second) {
+      // other interval is after ours
+      ++current_interval;
+    } else {
+      // Intervals intersect
+      return true;
+    }
+  }
+  return false;
+}
+
+void AddSet(const IntervalVec** const sets, const int set_to_add,
+            std::vector<std::unique_ptr<const IntervalVec>>* const storage) 
noexcept {
+  if (*sets != nullptr && (*sets)->size() != 0) {
+    for (auto& interval : (**sets)) {
+      if (set_to_add >= interval.first &&
+          set_to_add <= interval.second) {
+        return;
+      }
+    }
+  }
+  storage->emplace_back(
+      std::make_unique<IntervalVec>(1, std::make_pair(set_to_add, 
set_to_add)));
+  MergeSets(sets, storage->back().get(), storage);
+}
+
+int GetSetMapping(const int set, std::vector<int>* const set_mapping) noexcept 
{
+  if (set == -1) return -1;
+  int temp = set;
+  while ((*set_mapping)[temp] != temp) {
+    temp = (*set_mapping)[temp];
+  }
+  (*set_mapping)[set] = temp;
+  return temp;
+}
+
+void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const 
combined_excluded_sets_ptr,
+                                        const IntervalVec* const 
new_excluded_sets,
+                                        std::vector<const IntervalVec*>* const 
excluded_sets_ptr,
+                                        const int set_id,
+                                        const int first_node_in_set,
+                                        const size_t new_node_id,
+                                        const std::vector<int>& set_assignment,
+                                        std::vector<int>* const 
set_mapping_ptr,
+                                        const IntervalVec& inverse_set_mapping,
+                                        std::vector<std::unique_ptr<const 
IntervalVec>>* const
+                                          storage) noexcept {
+  const auto* previous_excluded_sets = *combined_excluded_sets_ptr;
+  MergeSets(combined_excluded_sets_ptr, new_excluded_sets, storage);
+  if (new_excluded_sets != nullptr) {
+    if (previous_excluded_sets == nullptr ||
+        *previous_excluded_sets != **(combined_excluded_sets_ptr)) {
+      // Their set's excluded sets list got larger, need to update the 
descendants
+      // of their set
+      auto& excluded_sets = *excluded_sets_ptr;
+      for (size_t j = first_node_in_set; j < new_node_id; ++j) {
+        if (GetSetMapping(set_assignment[j], set_mapping_ptr) == set_id ||
+            (excluded_sets[j] != nullptr &&
+             Intersect(inverse_set_mapping, *excluded_sets[j]))) {
+          MergeSets(&excluded_sets[j], *combined_excluded_sets_ptr, storage);
+        }
+      }
+    }
+  }
+}
+
+}  // namespace detail
+
+}  // namespace exec
+}  // namespace mxnet
diff --git a/src/executor/simple_partition_pass.h 
b/src/executor/simple_partition_pass.h
index 1ca0086..2a135b4 100644
--- a/src/executor/simple_partition_pass.h
+++ b/src/executor/simple_partition_pass.h
@@ -18,10 +18,10 @@
  */
 
 /*!
- * Copyright (c) 2019 by Contributors
+ * Copyright (c) 2019-2020 by Contributors
  * \file simple_partition_pass.h
  * \brief Simple pass for partitioning a graph.
- * \author Clement Fuji Tsang
+ * \author Clement Fuji Tsang, Przemyslaw Tredak
  */
 #ifndef MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_
 #define MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_
@@ -34,440 +34,237 @@
 #include <deque>
 #include <algorithm>
 #include <vector>
+#include <tuple>
 
 #include "exec_pass.h"
 
 namespace mxnet {
 namespace exec {
 
+namespace detail {
 
-/*!
- * \brief Custom graph class, which contains bi-directional nodes
- * required for traversing in both directions (from outputs to inputs
- * and vice versa). It is a non-owning layer on top of NNVM graph, since
- * NNVM graph enables traversing only in 1 direction (from outputs to inputs).
+using Interval = std::pair<int, int>;
+using IntervalVec = std::vector<Interval>;
+
+/* \brief Return the set that fully contains the other set, or nullptr
+ *        if neither set is a subset of the other.
  */
-class BidirectionalGraph {
- public:
-  struct Node {
-    nnvm::Node* nnvmptr;
-    std::vector<Node*> inputs;
-    std::vector<Node*> outputs;
-  };
+const IntervalVec*  LargerSet(const IntervalVec* const first,
+                              const IntervalVec* const second) noexcept;
 
-  explicit BidirectionalGraph(const Graph &g) {
-    auto& idx = g.indexed_graph();
-    auto num_nodes = idx.num_nodes();
-    nodes.reserve(num_nodes);
-    nnvm2nid.reserve(num_nodes);
-    outputs.reserve(idx.outputs().size());
-    // Create all the nodes in a new graph from
-    // nodes in the NNVM graph and store them
-    // in nodes array
-    DFSVisit(g.outputs, [this](const nnvm::ObjectPtr& n) {
-      Node new_node;
-      new_node.nnvmptr = n.get();
-      nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
-      nodes.emplace_back(std::move(new_node));
-    });
-    // Create all connections between nodes in
-    // the graph (both directions)
-    for (const auto& it : nnvm2nid) {
-      nnvm::Node* nnvmnode = it.first;
-      uint32_t nid = it.second;
-      for (auto& n : nnvmnode->inputs) {
-        uint32_t input_nid = nnvm2nid[n.node.get()];
-        nodes[input_nid].outputs.emplace_back(&nodes[nid]);
-        nodes[nid].inputs.emplace_back(&nodes[input_nid]);
-      }
-    }
-    // Create output connections from the graph
-    for (auto& e : g.outputs) {
-      uint32_t nid = nnvm2nid[e.node.get()];
-      outputs.emplace_back(&nodes[nid]);
-    }
-  }
+/* \brief Compute the union of the 2 sets and store a pointer to the result in my_set.
+ */
+void MergeSets(const IntervalVec** const my_set,
+               const IntervalVec* const other_set,
+               std::vector<std::unique_ptr<const IntervalVec>>* const storage) 
noexcept;
 
-  /* \brief Get all subsets of nodes, where:
-   *  - graph constructed from nodes in each subset is a connected graph
-   *  - every node fulfills a predicate is_compatible
-   *  - if nodes u and v are part of a subset, then for each path between
-   *    u and v in the original directed graph, all nodes on those paths
-   *    are also part of the subset
-   * \param is_compatible A function taking nnvm::Node* and returning bool
-   *                      which identifies which nodes should be included in
-   *                      subsets.
-   */
-  template<typename FCompatible>
-  std::vector<std::unordered_set<Node*>> get_subsets(FCompatible 
is_compatible) {
-    std::vector<std::unordered_set<Node*>> subgraphs;
-    std::unordered_set<Node*> incomp_set;
-    std::vector<std::pair<bool, PairSet>> separation_sets;
-    // Check each node for compatibility
-    // and, if it is incompatible, mark nodes
-    // on each side of it as not possible to be
-    // in the same subset
-    for (Node& node : nodes) {
-      if (!is_compatible(node.nnvmptr)) {
-        incomp_set.insert(&node);
-      }
-    }
-    for (Node& node : nodes) {
-      if (incomp_set.count(&node) != 0) {
-        // Check if all your inputs are incompatible too.
-        // If so, then your separation set does not matter,
-        // because it will covered by the sets of your inputs
-        bool inside_node = true;
-        for (Node* input : node.inputs) {
-          if (incomp_set.count(input) == 0) {
-            inside_node = false;
-          }
-        }
-        if (!inside_node) {
-          std::unordered_set<Node*> in_graph;
-          std::unordered_set<Node*> out_graph;
-          std::vector<Node*> dummy_head;
-          dummy_head.emplace_back(&node);
-          DFS(dummy_head, false, [&out_graph](Node* node) {
-              out_graph.insert(node);
-          });
-          DFS(dummy_head, true, [&in_graph](Node* node) {
-              in_graph.insert(node);
-          });
-            separation_sets.push_back(std::make_pair(true,
-                                                     std::make_pair(in_graph, 
out_graph)));
-        } else {
-          separation_sets.push_back(std::make_pair(false, PairSet()));
-        }
-      } else {
-        separation_sets.push_back(std::make_pair(false, PairSet()));
-      }
-    }
-    IncompMap incomp_map;
-    // For each node construct the map of nodes that cannot be in
-    // the same subset
-    index_t num_nodes = nodes.size();
-    for (index_t i = 0; i < num_nodes; ++i) {
-      const auto n = &(nodes[i]);
-      if (incomp_set.count(n) == 0) {
-        for (index_t j = i + 1; j < num_nodes; ++j) {
-          const auto& sep_set_pair = separation_sets[j];
-          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
-            const auto& p = sep_set_pair.second;
-            if (p.first.count(n)) {
-              incomp_map[n].insert(p.second.begin(), p.second.end());
-            } else if (p.second.count(n)) {
-              incomp_map[n].insert(p.first.begin(), p.first.end());
-            }
-          }
-        }
-        for (index_t j = i - 1; j >= 0; --j) {
-          const auto& sep_set_pair = separation_sets[j];
-          if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
-            const auto& p = sep_set_pair.second;
-            if (p.first.count(n)) {
-              incomp_map[n].insert(p.second.begin(), p.second.end());
-            } else if (p.second.count(n)) {
-              incomp_map[n].insert(p.first.begin(), p.first.end());
-            }
-          }
-        }
-        for (Node* incomp_n : incomp_set) {
-          incomp_map[n].erase(incomp_n);
-        }
-      }
-    }
-    std::unordered_set<Node*> unused_set;
+/* \brief Returns true if there is a non-empty intersection
+ *        between the 2 sets.
+ */
+bool Intersect(const IntervalVec& checked_sets,
+               const IntervalVec& excluded_sets) noexcept;
 
-    for (auto& n : nodes) {
-      if (incomp_set.count(&n) == 0) {
-        unused_set.insert(&n);
-      }
-    }
-    std::unordered_set<Node*> visited;
-    std::deque<Node*> stack(outputs.begin(), outputs.end());
-    // Create subsets
-    while (!stack.empty()) {
-      Node* vertex = stack.front();
-      stack.pop_front();
-      if (!visited.count(vertex)) {
-        visited.insert(vertex);
-        if (unused_set.count(vertex)) {
-          subgraphs.emplace_back(naive_grow_subgraph(vertex, &unused_set, 
&incomp_map));
-        }
-        for (Node* input : vertex->inputs) {
-          stack.emplace_back(input);
-        }
-      }
-    }
-    return subgraphs;
-  }
+/* \brief Add a single set id to the sets, unless it is already contained in them.
+ */
+void AddSet(const IntervalVec** const sets, const int set_to_add,
+            std::vector<std::unique_ptr<const IntervalVec>>* const storage) 
noexcept;
 
- private:
-  using PairSet = std::pair<std::unordered_set<Node*>, 
std::unordered_set<Node*>>;
-  using PairVec = std::pair<std::vector<Node*>, std::vector<Node*>>;
-  using IncompMap = std::unordered_map<Node*, std::unordered_set<Node*>>;
+/* \brief Get the true mapping of the set (which could change
+ *        due to merging of multiple sets).
+ */
+int GetSetMapping(const int set, std::vector<int>* const set_mapping) noexcept;
 
-  /* \brief Traverse the graph using DFS in either direction.
-   * \param heads Starting nodes for the DFS algorithm.
-   * \param reverse If true, DFS will traverse the graph from
-   *                outputs to inputs. Otherwise, it will
-   *                traverse the graph from inputs to outputs.
-   * \param fvisit Function to call on each visisted node.
-   */
-  template <typename FVisit>
-  void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
-    std::unordered_set<Node*> visited;
-    std::vector<Node*> vec(heads.begin(), heads.end());
-    visited.reserve(heads.size());
-    while (!vec.empty()) {
-      Node* vertex = vec.back();
-      vec.pop_back();
-      if (visited.count(vertex) == 0) {
-        visited.insert(vertex);
-        fvisit(vertex);
-        std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
-        for (Node* node : nexts) {
-          if (visited.count(node) == 0) {
-            vec.emplace_back(node);
-          }
-        }
-      }
-    }
-  }
+/* \brief Check if 2 ids are on the same side of the cutoff
+ *        (so either both on the FWD side or the BWD side).
+ */
+inline bool IsSamePass(const int my_id, const int their_id, const int cutoff) 
noexcept {
+  return (my_id > cutoff && their_id > cutoff) ||
+         (my_id <= cutoff && their_id <= cutoff);
+}
 
-  /* \brief Get the connected subgraph that contains the head node,
-   *        only previously unused nodes, according to the rules
-   *        from incompatibility map.
-   * \param head Node which needs to be part of the returned subgraph.
-   * \param unused_set Only nodes from this set will be considered when
-   *                   adding to the growing subgraph.
-   * \param incomp_map Map containing data on which nodes are incompatible
-   *                   to be in the same subgraph.
-   */
-  std::unordered_set<Node*> naive_grow_subgraph(Node* head,
-                                                std::unordered_set<Node*>* 
unused_set,
-                                                IncompMap* incomp_map) {
-    std::unordered_set<Node*> subgraph;
-    std::unordered_set<Node*> incomp_set;
-    std::deque<Node*> stack;
-    stack.emplace_back(head);
-    while (!stack.empty()) {
-      Node* vertex = stack.back();
-      stack.pop_back();
-      if (unused_set->count(vertex) && !incomp_set.count(vertex)) {
-        unused_set->erase(vertex);
-        subgraph.insert(vertex);
-        incomp_set.insert((*incomp_map)[vertex].begin(), 
(*incomp_map)[vertex].end());
-        // Traverse the grpah in both directions
-        for (Node* input : vertex->inputs) {
-          if (unused_set->count(input) && !incomp_set.count(input)) {
-            stack.emplace_back(input);
-          }
-        }
-        for (Node* output : vertex->outputs) {
-          if (unused_set->count(output) && !incomp_set.count(output)) {
-            stack.emplace_back(output);
-          }
-        }
-      }
-    }
-    return subgraph;
-  }
+/* \brief Check if adding a new node to the set changes the excluded set of the future
+ *        fused node. If so, update all descendants of the fused node.
+ *
+ * \param combined_excluded_sets_ptr pointer to the set's list of excluded sets
+ *                                   before adding the new node
+ * \param new_excluded_sets list of excluded sets of the new node
+ * \param excluded_sets_ptr pointer to the lists of excluded sets of all the nodes
+ * \param set_id number of the set, to which the new node is added
+ * \param first_node_in_set id of the first node in the set, according to topological ordering
+ * \param new_node_id id of the node added to the set
+ * \param set_assignment assignment of sets
+ * \param set_mapping_ptr pointer to the mappings of sets
+ * \param inverse_set_mapping inverse mapping of the set
+ * \param storage memory storage
+ */
+void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const 
combined_excluded_sets_ptr,
+                                        const IntervalVec* const 
new_excluded_sets,
+                                        std::vector<const IntervalVec*>* const 
excluded_sets_ptr,
+                                        const int set_id,
+                                        const int first_node_in_set,
+                                        const size_t new_node_id,
+                                        const std::vector<int>& set_assignment,
+                                        std::vector<int>* const 
set_mapping_ptr,
+                                        const IntervalVec& inverse_set_mapping,
+                                        std::vector<std::unique_ptr<const 
IntervalVec>>* const
+                                          storage) noexcept;
 
-  friend class Graph;
+}  // namespace detail
 
-  std::vector<Node> nodes;
-  std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
-  std::vector<Node*> outputs;
-};  // class BidirectionalGraph
 
-using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
-                                        nnvm::NodeEntryEqual>;
-using NodeRawPtrSet = std::unordered_set<nnvm::Node*>;
+/* \brief Get all subsets of nodes, where:
+ *  - graph constructed from nodes in each subset is a connected graph
+ *  - every node fulfills a predicate is_compatible
+ *  - if nodes u and v are part of a subset, then for each path between
+ *    u and v in the original directed graph, all nodes on those paths
+ *    are also part of the subset
+ * \param g NNVM graph
+ * \param num_forward_outputs Number of outputs from the graph that come
+ *                            from the forward pass
+ * \param is_compatible A function taking nnvm::Node* and returning bool
+ *                      which identifies which nodes could be included in
+ *                      subsets.
+ * \param is_input_only_compatible A function taking nnvm::Node* and
+ *                                 returning bool which identifies which
+ *                                 nodes could be included in subsets only
+ *                                 as the first operations (their inputs
+ *                                 need to be excluded).
+ * \return tuple (subset assignment, number of found subsets)
+ */
+template<typename FCompatible, typename FInputOnlyCompatible>
+std::tuple<std::vector<int>, int> GetCompatibleSubsets(
+    const Graph& g,
+    const size_t num_forward_outputs,
+    FCompatible is_compatible,
+    FInputOnlyCompatible is_input_only_compatible) {
 
-/*!
- * \brief Get the output nodes of the subgraph in the main graph.
- * \return a map between the node in the main graph and the output index of 
the subgraph node
-*/
-nnvm::NodeEntryMap<uint32_t> GetSubgraphOutputs(Graph g, NodeRawPtrSet 
subgraph_set) {
-  nnvm::NodeEntryMap<uint32_t> outputs;
-  uint32_t count = 0;
-  for (auto& e : g.outputs) {
-    if (subgraph_set.count(e.node.get()) && !outputs.count(e)) {
-      outputs.insert({e, count++});
+  using namespace detail;
+  const auto& idx = g.indexed_graph();
+  std::vector<int> set_assignment(idx.num_nodes(), -1);
+  std::vector<const std::vector<Interval>*> excluded_sets(idx.num_nodes());
+  std::vector<int> set_mapping;
+  std::vector<const std::vector<Interval>*> combined_excluded_sets;
+  std::vector<int> first_node_in_set;
+  std::vector<const std::vector<Interval>*> inverse_set_mapping;
+  std::vector<std::unique_ptr<const std::vector<Interval>>> storage;
+
+  int last_forward_node = -1;
+  for (size_t i = 0; i < num_forward_outputs; ++i) {
+    const int output_id = idx.outputs()[i].node_id;
+    if (last_forward_node < output_id) {
+      last_forward_node = output_id;
     }
   }
-  DFSVisit(g.outputs, [&subgraph_set, &outputs, &count](const nnvm::ObjectPtr 
&node){
-    if (!subgraph_set.count(node.get())) {
-      for (auto& e : node->inputs) {
-        if (subgraph_set.count(e.node.get()) && !outputs.count(e)) {
-          outputs.insert({e, count++});
-        }
-      }
-    }
-  });
-  return outputs;
-}
 
-/*!
- * \brief Create new input nodes of the subgraph and plug them.
- * \return the inputs of the subgraph node in the main graph
-*/
-std::vector<nnvm::NodeEntry> GetSubgraphInputs(Graph g, NodeRawPtrSet 
subgraph_set) {
-  std::vector<nnvm::NodeEntry> inputs;
-  nnvm::NodeEntryMap<nnvm::NodeEntry> entry_map;
-  DFSVisit(g.outputs, [&subgraph_set, &inputs, &entry_map](const 
nnvm::ObjectPtr &node){
-    if (subgraph_set.count(node.get())) {
-      for (auto &e : node->inputs) {
-        if (!subgraph_set.count(e.node.get())) {
-          if (entry_map.count(e)) {
-            e = entry_map[e];
+  int num_sets = 0;
+  for (size_t i = 0; i < idx.num_nodes(); ++i) {
+    const auto& node = idx[i];
+    auto& my_excluded_sets = excluded_sets[i];
+    for (const auto& input : node.inputs) {
+      MergeSets(&my_excluded_sets, excluded_sets[input.node_id], &storage);
+    }
+    if (is_compatible(node.source)) {
+      int my_set = -1;
+      for (const auto& input : node.inputs) {
+        int their_set = GetSetMapping(set_assignment[input.node_id], 
&set_mapping);
+        if (their_set != -1 &&
+            their_set != my_set &&
+            IsSamePass(i, input.node_id, last_forward_node) &&
+            (my_excluded_sets == nullptr ||
+            !Intersect(*inverse_set_mapping[their_set], *my_excluded_sets))) {
+          if (my_set == -1) {
+            my_set = their_set;
+            
CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]),
+                                               my_excluded_sets,
+                                               &excluded_sets,
+                                               their_set,
+                                               first_node_in_set[their_set],
+                                               i,
+                                               set_assignment,
+                                               &set_mapping,
+                                               
*(inverse_set_mapping[their_set]),
+                                               &storage);
           } else {
-            auto new_node = nnvm::Node::Create();
-            new_node->attrs.name = "input_" + std::to_string(inputs.size());
-            entry_map.insert({e, nnvm::NodeEntry{new_node, 0, 0}});
-            inputs.push_back(e);
-            e.node = new_node;
-            e.index = 0;
+            MergeSets(&inverse_set_mapping[my_set],
+                      inverse_set_mapping[their_set],
+                      &storage);
+            set_mapping[their_set] = my_set;
+            first_node_in_set[my_set] = std::min(first_node_in_set[my_set],
+                                                 first_node_in_set[their_set]);
+            
CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]),
+                                               combined_excluded_sets[my_set],
+                                               &excluded_sets,
+                                               my_set,
+                                               first_node_in_set[my_set],
+                                               i,
+                                               set_assignment,
+                                               &set_mapping,
+                                               *(inverse_set_mapping[my_set]),
+                                               &storage);
           }
         }
       }
+      if (my_set == -1) {
+        set_mapping.emplace_back(num_sets);
+        combined_excluded_sets.emplace_back(my_excluded_sets);
+        first_node_in_set.emplace_back(i);
+        storage.emplace_back(std::make_unique<std::vector<Interval>>(
+                               1, std::make_pair(num_sets,
+                                                 num_sets)));
+        inverse_set_mapping.emplace_back(storage.back().get());
+        my_set = num_sets++;
+      }
+      set_assignment[i] = my_set;
+    } else {
+      for (const auto& input : node.inputs) {
+        int their_set = GetSetMapping(set_assignment[input.node_id], 
&set_mapping);
+        if (their_set != -1) {
+          AddSet(&my_excluded_sets, their_set, &storage);
+        }
+      }
+      if ((is_input_only_compatible != nullptr) &&
+          is_input_only_compatible(node.source)) {
+        set_mapping.emplace_back(num_sets);
+        combined_excluded_sets.emplace_back(my_excluded_sets);
+        first_node_in_set.emplace_back(i);
+        storage.emplace_back(std::make_unique<std::vector<Interval>>(
+                               1, std::make_pair(num_sets,
+                                                 num_sets)));
+        inverse_set_mapping.emplace_back(storage.back().get());
+        set_assignment[i] = num_sets++;
+      }
     }
-  });
-  // Fix ordering of w.r.t to topology
-  Graph _g;
-  _g.outputs = g.outputs;
-  const auto &idx = _g.indexed_graph();
-  std::sort(inputs.begin(), inputs.end(),
-      [&idx, &entry_map](const nnvm::NodeEntry lhs, const nnvm::NodeEntry rhs) 
{
-        return idx.entry_id(entry_map.at(lhs)) < 
idx.entry_id(entry_map.at(rhs));
-      });
-  return inputs;
-}
-
-std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
-  std::unordered_map<uint32_t, uint32_t> outputs;
-  auto& idx = g.indexed_graph();
-  outputs.reserve(idx.num_nodes());
-  std::vector<uint32_t> input_nodes = idx.input_nodes();
-  for (size_t i = 0; i < input_nodes.size(); ++i) {
-    outputs[input_nodes[i]] = static_cast<uint32_t>(i);
   }
-  return outputs;
-}
 
-/*!
- * \brief Helper function to display what nodes are in a specific subset.
- */
-void dispNodesSet(Graph g, NodeRawPtrSet s) {
-  DFSVisit(g.outputs, [&s](const nnvm::ObjectPtr n){
-    if (s.count(n.get())) {
-      std::cout << "  Y " << n->attrs.name << std::endl;
-    } else {
-      std::cout << "  N " << n->attrs.name << std::endl;
-    }
-  });
-}
+  for (int& set : set_assignment) {
+    set = GetSetMapping(set, &set_mapping);
+  }
 
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- */
-template<typename FCreateNode>
-Graph ReplaceSubgraphs(Graph&& g, const std::vector<NodeRawPtrSet>& 
subgraph_sets,
-                       FCreateNode create_subgraph_node) {
-  for (auto subgraph_set : subgraph_sets) {
-    // Create MXNet subgraph
-    Graph subgraph;
-    const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
-    subgraph.outputs.resize(sub_outputs_in_main.size());
-    for (auto p : sub_outputs_in_main) {
-      subgraph.outputs[p.second] = p.first;
+  std::vector<int> set_reorder(num_sets, 0);
+  // First count the number of elements in each set.
+  for (int& set : set_assignment) {
+    if (set != -1) {
+      ++set_reorder[set];
     }
-    // To generate a subgraph an input has to be replaced by data node (no op)
-    // and it has to be agnostic to the node from which it's an output
-    // (For example, even if two inputs are two different outputs from the 
same node,
-    // they need to be replaced by two completely separate data nodes)
-    auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
-    auto subgraph_node = create_subgraph_node(subgraph);
-    subgraph_node->inputs = inputs;
-    // replug inputs of node out of subgraph to be output of the subgraph node
-    // if it was a node in the subgraph
-    DFSVisit(g.outputs,
-        [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const 
nnvm::ObjectPtr node) {
-      if (!subgraph_set.count(node.get())) {
-        for (auto &e : node->inputs) {
-          auto it = sub_outputs_in_main.find(e);
-          if (it != sub_outputs_in_main.end()) {
-            e.node = subgraph_node;
-            e.index = it->second;
-          }
-        }
-      }
-    });
-    // replug outputs of the graph to be output of the subgraph node
-    // if it was a node in the subgraph
-    for (auto &e : g.outputs) {
-      auto it = sub_outputs_in_main.find(e);
-      if (it != sub_outputs_in_main.end()) {
-        e.node = subgraph_node;
-        e.index = it->second;
-      }
+  }
+  // Then reorder them, removing sets that have
+  // only a single element.
+  int final_num_sets = 0;
+  for (int& set : set_reorder) {
+    if (set > 1) {
+      set = final_num_sets++;
+    } else {
+      set = -1;
     }
-    // move control dependencies between nodes of the subgraph and out of the 
subgraph
-    // to a dependencies between the subgraph node and the nodes out of the 
subgraph
-    DFSVisit(g.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr& 
node) {
-      for (auto &e : node->control_deps) {
-        if (subgraph_set.count(e.get()))
-          e = subgraph_node;
-      }
-    });
-    DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const 
nnvm::ObjectPtr& node) {
-      auto it = node->control_deps.begin();
-      while (it != node->control_deps.end()) {
-        if (subgraph_set.count(it->get())) {
-          ++it;
-        } else {
-          subgraph_node->control_deps.push_back(*it);
-          it = node->control_deps.erase(it);
-        }
-      }
-    });
   }
-  Graph new_graph;
-  new_graph.outputs = g.outputs;
-  return new_graph;
-}
 
-/* \brief Get all subsets of nodes, where:
- *  - graph constructed from nodes in each subset is a connected graph
- *  - every node fulfills a predicate is_compatible
- *  - if nodes u and v are part of a subset, then for each path between
- *    u and v in the original directed graph, all nodes on those paths
- *    are also part of the subset
- * \param g NNVM graph
- * \param is_compatible A function taking nnvm::Node* and returning bool
- *                      which identifies which nodes should be included in
- *                      subsets.
- */
-template<typename FCompatible>
-std::vector<NodeRawPtrSet> GetCompatibleSubsets(const Graph& g, FCompatible 
is_compatible) {
-  BidirectionalGraph biG = BidirectionalGraph(g);
-  std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets =
-    biG.get_subsets(is_compatible);
-  std::vector<NodeRawPtrSet> nnvm_subsets;
-  nnvm_subsets.reserve(subsets.size());
-  for (auto& subset : subsets) {
-    if (subset.size() > 1) {
-      NodeRawPtrSet node_set;
-      node_set.reserve(subset.size());
-      for (auto& n : subset) {
-        node_set.insert(n->nnvmptr);
-      }
-      nnvm_subsets.push_back(node_set);
+  for (int& set : set_assignment) {
+    if (set != -1) {
+      set = set_reorder[set];
     }
   }
-  return nnvm_subsets;
+
+  return std::make_tuple(std::move(set_assignment), final_num_sets);
 }
 
 }  // namespace exec
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index c56d8cf..23ab4a0 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -238,10 +238,7 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph * 
fwd_graph, nnvm::Grap
     common::CopyGraph(&unoptimized_graph, *full_graph, false);
 
     if 
(common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
-      full_graph->attrs["num_forward_outputs"] = 
std::make_shared<nnvm::any>(num_forward_outputs);
-      *full_graph = exec::FusePointwiseForward(std::move(*full_graph));
-      full_graph->attrs["num_forward_outputs"] = 
std::make_shared<nnvm::any>(num_forward_outputs);
-      *full_graph = exec::FusePointwiseBackward(std::move(*full_graph));
+      *full_graph = exec::FusePointwise(*full_graph, num_forward_outputs);
       // Check the topological order of inputs
       const auto &original_inputs = 
unoptimized_graph.indexed_graph().input_nodes();
       const auto &new_inputs = full_graph->indexed_graph().input_nodes();
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index cb13dbf..1b914cc 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -261,12 +261,40 @@ std::string FusedOp::GenerateCode(const 
std::vector<OpReqType> &req,
           const auto& var_name = g[node_id].source->attrs.name;
           const auto vec_name = "vec_" + var_name + "_" + std::to_string(i);
           load_index[node_id] = 0;
-          auto parse_tuple = [](const std::string& input, const std::string 
def) {
+          auto parse_tuple = [ndim](const std::string& input, const 
std::string& def) {
             std::string out = input;
-            replaceString(&out, "(", "{");
-            replaceString(&out, ")", "}");
+            replaceString(&out, " ", "");
+            if (out[0] == '(') {
+              replaceString(&out, "(", "{");
+              replaceString(&out, ")", "}");
+              // First check if out is ()
+              int n_entries = out.size() != 2;
+              for (size_t i = 1; i < out.size() - 1; ++i) {
+                if (out[i] == ',') {
+                  ++n_entries;
+                }
+              }
+              if (n_entries != ndim) {
+                out.pop_back();
+                for (int i = n_entries; i < ndim; ++i) {
+                  out += "," + def;
+                }
+                out += "}";
+              }
+            } else {
+              out = "{" + std::move(out);
+              for (int i = 1; i < ndim; ++i) {
+                out += "," + def;
+              }
+              out += "}";
+            }
             replaceString(&out, "None", def);
+            return out;
+          };
+          auto parse_int = [](const std::string& input, const std::string& 
def) {
+            std::string out = input;
             replaceString(&out, " ", "");
+            replaceString(&out, "None", def);
             return out;
           };
           auto build_tuple = [ndim](int axis, const std::string str, const 
std::string def) {
@@ -279,11 +307,11 @@ std::string FusedOp::GenerateCode(const 
std::vector<OpReqType> &req,
             }
             std::string tuple = "{";
             for (int i = 0; i < axis; i++) {
-                tuple = tuple + def + ",";
+                tuple += def + ",";
             }
             tuple += str;
             for (int i = axis + 1; i < ndim; i++) {
-                tuple = tuple + "," + def;
+                tuple += "," + def;
             }
             tuple += "}";
             return tuple;
@@ -295,12 +323,6 @@ std::string FusedOp::GenerateCode(const 
std::vector<OpReqType> &req,
             }
             return false;
           };
-          auto build_string_axis = [ndim](int axis) {
-            if (axis < 0) {
-                axis = ndim + axis;
-            }
-            return std::to_string(axis);
-          };
           auto build_string_end = [i, ndim, var_name](std::string* code) {
             std::string end_var_name = var_name + "_" + std::to_string(i) + 
"_end";
             *code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name 
+ ";\n";
@@ -323,12 +345,15 @@ std::string FusedOp::GenerateCode(const 
std::vector<OpReqType> &req,
             }
             end = extra_var_name;
           } else {
-            begin = parse_tuple(source->attrs.dict.at("begin"), "0");
-            end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
             if (op_name == "slice_axis") {
+              begin = parse_int(source->attrs.dict.at("begin"), "0");
+              end = parse_int(source->attrs.dict.at("end"), "INT_MAX");
               int axis = std::stoi(source->attrs.dict.at("axis"));
               begin = build_tuple(axis, begin, "0");
               end = build_tuple(axis, end, "INT_MAX");
+            } else {
+              begin = parse_tuple(source->attrs.dict.at("begin"), "0");
+              end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
             }
             if (check_shapes) {
               if (check_tuple(begin) && check_tuple(end)) {
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1febf8d..e81c794 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -340,6 +340,27 @@ def test_fusion_reshape_executor():
     out = f.forward(is_train=False, data1=data, data2=data)
     assert out[0].sum().asscalar() == 150
 
+@with_seed()
+def test_fusion_cycle():
+    from mxnet.gluon import HybridBlock
+    class Test(HybridBlock):
+        def __init__(self, **kwargs):
+            super(Test, self).__init__(**kwargs)
+
+        def hybrid_forward(self, F, x, y):
+            x = F.relu(x)
+            y = F.relu(y)
+            z1 = F.expand_dims(F.sum_axis(x, axis=1), axis=1)
+            z2 = F.expand_dims(F.sum_axis(y, axis=1), axis=1)
+            return x + z2, y + z1
+
+    t = Test()
+    a = mx.nd.zeros(shape=(10,1), ctx=mx.gpu())
+    b = mx.nd.zeros(shape=(10,1), ctx=mx.gpu())
+    t.hybridize(static_alloc=True, static_shape=True)
+    out = t(a, b)
+    mx.nd.waitall()
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

Reply via email to