This is an automated email from the ASF dual-hosted git repository.
ptrendx pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/v1.x by this push:
new d2600f2 [1.x] Backport Faster pointwise fusion graph pass (#19269)
(#19413)
d2600f2 is described below
commit d2600f21c2e9bb377bb3f77264a378d2b84236a7
Author: Przemyslaw Tredak <[email protected]>
AuthorDate: Tue Nov 17 15:26:17 2020 -0800
[1.x] Backport Faster pointwise fusion graph pass (#19269) (#19413)
* Faster pointwise fusion graph pass (#19269)
* Faster pointwise fusion graph pass
* Fix lint
* Fix lint 2
* Fixes
* Fixing slice parameter handling in fusion
* Fixing the slice fix
* Fix the cycle bug
* Added test
* Fix lint
* Fix merging of subgraphs
* Fixes from review
* Use std::tie instead of C++17 structured binding
* More fixes for lack of c++17
* Fix
---
src/executor/exec_pass.h | 16 +-
src/executor/graph_executor.cc | 5 +-
src/executor/pointwise_fusion_pass.cc | 519 +++++++++++++++--------------
src/executor/simple_partition_pass.cc | 265 +++++++++++++++
src/executor/simple_partition_pass.h | 599 +++++++++++-----------------------
src/imperative/cached_op.h | 5 +-
src/operator/fusion/fused_op.cu | 51 ++-
tests/python/gpu/test_fusion.py | 21 ++
8 files changed, 809 insertions(+), 672 deletions(-)
diff --git a/src/executor/exec_pass.h b/src/executor/exec_pass.h
index 4552fa1..07ccc02 100644
--- a/src/executor/exec_pass.h
+++ b/src/executor/exec_pass.h
@@ -222,22 +222,14 @@ Graph DetectInplaceAddTo(Graph g);
Graph EliminateCommonExpr(Graph && g);
/*!
- * \brief Fuse pointwise operations in the forward pass.
+ * \brief Fuse pointwise operations in the graph.
*
* \param g input graph (needs to be entire graph, not just forward part)
+ * \param num_forward_outputs number of outputs in the graph produced by the forward pass
*
- * \return graph with fused pointwise operations in the forward pass
+ * \return copy of the graph with fused pointwise operations
*/
-Graph FusePointwiseForward(Graph&& g);
-
-/*!
- * \brief Fuse pointwise operations in the backward pass.
- *
- * \param g input graph (needs to be entire graph, not just forward part)
- *
- * \return graph with fused pointwise operations in the backward pass
- */
-Graph FusePointwiseBackward(Graph&& g);
+Graph FusePointwise(const Graph& g, const size_t num_forward_outputs);
/*!
* \brief Issue a one-time warning that fusion is not possible for this platform or build.
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 28b79ae..3f2c7c9 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1009,10 +1009,7 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
common::CopyGraph(&unoptimized_graph, g, false);
if
(common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
- g.attrs["num_forward_outputs"] =
std::make_shared<nnvm::any>(num_forward_outputs_);
- g = FusePointwiseForward(std::move(g));
- g.attrs["num_forward_outputs"] =
std::make_shared<nnvm::any>(num_forward_outputs_);
- g = FusePointwiseBackward(std::move(g));
+ g = exec::FusePointwise(std::move(g), num_forward_outputs_);
// Check the topological order of inputs
const auto &original_inputs =
unoptimized_graph.indexed_graph().input_nodes();
const auto &new_inputs = g.indexed_graph().input_nodes();
diff --git a/src/executor/pointwise_fusion_pass.cc
b/src/executor/pointwise_fusion_pass.cc
index 3203f67..961dade 100644
--- a/src/executor/pointwise_fusion_pass.cc
+++ b/src/executor/pointwise_fusion_pass.cc
@@ -31,6 +31,7 @@
#include <nnvm/pass_functions.h>
#include <algorithm>
#include <queue>
+#include <chrono>
#include "./simple_partition_pass.h"
#include "../operator/fusion/fused_op-inl.h"
#include "../operator/fusion/fused_op.h"
@@ -57,281 +58,323 @@ void WarnFusionNotSupported() {
#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
namespace {
- bool IsFusionCompatible(nnvm::Node* n) {
- using namespace mxnet::fusion;
- if (n->op() == nullptr)
- return false;
- std::string op_name = n->op()->name;
- if (ops_desc.count(op_name))
- return true;
- if (slice_ops.count(op_name))
- return false;
- if (std::find(variable_io_ops.begin(),
- variable_io_ops.end(),
- op_name) !=
- variable_io_ops.end())
- return true;
- if (op_name == "LeakyReLU") {
- std::string act_type = n->attrs.dict.at("act_type");
- if (LeakyReLU_ops.count(act_type))
- return true;
- else
- return false;
- }
- if (op_name == "_backward_LeakyReLU") {
- std::string act_type = n->attrs.dict.at("act_type");
- if (LeakyReLU_bwd_ops.count(act_type))
- return true;
- else
- return false;
- }
+
+bool IsFusionCompatible(const nnvm::Node* n) {
+ using namespace mxnet::fusion;
+ if (n->op() == nullptr)
+ return false;
+ const std::string& op_name = n->op()->name;
+ if (ops_desc.count(op_name))
+ return true;
+ if (slice_ops.count(op_name))
return false;
+ if (std::find(variable_io_ops.begin(),
+ variable_io_ops.end(),
+ op_name) !=
+ variable_io_ops.end())
+ return true;
+ if (op_name == "LeakyReLU") {
+ std::string act_type = n->attrs.dict.at("act_type");
+ if (LeakyReLU_ops.count(act_type))
+ return true;
+ else
+ return false;
}
+ if (op_name == "_backward_LeakyReLU") {
+ std::string act_type = n->attrs.dict.at("act_type");
+ if (LeakyReLU_bwd_ops.count(act_type))
+ return true;
+ else
+ return false;
+ }
+ return false;
+}
- bool IsInputsOnlyCompatible(nnvm::Node* n) {
- using namespace mxnet::fusion;
- if (n->op() == nullptr)
- return false;
- std::string op_name = n->op()->name;
- if (slice_ops.count(op_name)) {
- if (op_name == "slice") {
- // slice with non-default step attribute is not supported
- // currently
- if (n->attrs.dict.count("step") &&
- !(n->attrs.dict.at("step") == "()" ||
- n->attrs.dict.at("step") == "[]")) {
- return false;
- }
+bool IsInputsOnlyCompatible(const nnvm::Node* n) {
+ using namespace mxnet::fusion;
+ if (n->op() == nullptr)
+ return false;
+ const std::string& op_name = n->op()->name;
+ if (slice_ops.count(op_name)) {
+ if (op_name == "slice") {
+ // slice with non-default step attribute is not supported
+ // currently
+ if (n->attrs.dict.count("step") &&
+ !(n->attrs.dict.at("step") == "()" ||
+ n->attrs.dict.at("step") == "[]")) {
+ return false;
}
- return true;
}
- return false;
+ return true;
}
+ return false;
+}
+
+void CreateSubgraphNode(const nnvm::Graph& subgraph,
+ size_t inputs_size,
+ nnvm::Node* subgraph_node) {
+ static const Op* fused_op_ptr = Op::Get("_FusedOp");
+
subgraph_node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>());
+ subgraph_node->attrs.subgraphs.back()->outputs = subgraph.outputs;
+ subgraph_node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
+ subgraph_node->attrs.dict["num_outputs"] =
std::to_string(subgraph.outputs.size());
+ subgraph_node->attrs.op = fused_op_ptr;
+ subgraph_node->op()->attr_parser(&(subgraph_node->attrs));
+}
+
+struct EntryInfo {
+ int source_node;
+ int index;
+};
- nnvm::ObjectPtr CreateSubgraphNode(const Graph& subgraph, size_t
inputs_size) {
- nnvm::Symbol subgraph_sym;
- auto node = nnvm::Node::Create();
- subgraph_sym.outputs = subgraph.outputs;
-
node->attrs.subgraphs.emplace_back(std::make_shared<nnvm::Symbol>(subgraph_sym));
- node->attrs.name = "FusedOp";
- node->attrs.dict["num_inputs"] = std::to_string(inputs_size);
- node->attrs.dict["num_outputs"] = std::to_string(subgraph.outputs.size());
- node->attrs.op = Op::Get("_FusedOp");
- node->op()->attr_parser(&(node->attrs));
- return node;
+inline int SetInsert(const EntryInfo& new_elem,
+ std::vector<EntryInfo>* elements) {
+ for (size_t i = 0; i < elements->size(); ++i) {
+ if ((new_elem.source_node == elements->at(i).source_node) &&
+ (new_elem.index == elements->at(i).index)) {
+ return i;
+ }
}
+ elements->emplace_back(new_elem);
+ return elements->size() - 1;
+}
+
} // namespace
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- * This function is used specifically in pointwise fusion.
+/* \brief Create (if necessary) a copy of the graph, replacing subgraphs with
+ * FusedOps. If there are no subgraphs to be replaced, the
+ * original graph is returned.
+ * \param g original graph.
+ * \param subgraph_assignment assignment of nodes in g's IndexedGraph to
+ * subgraphs. Values from -1 to num_subgraphs - 1
+ * are allowed, -1 means that the node is not in a
+ * subgraph.
+ * \param num_subgraphs number of subgraphs.
+ * \param create_subgraph_node function used to prepare the subgraph node.
*/
template<typename FCreateNode>
-Graph ReplaceSubgraphsPointwise(Graph&& g, const std::vector<NodeRawPtrSet>&
subgraph_sets,
- FCreateNode create_subgraph_node) {
- for (auto subgraph_set : subgraph_sets) {
- // Create MXNet subgraph
- Graph subgraph;
- const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
- subgraph.outputs.resize(sub_outputs_in_main.size());
- for (auto p : sub_outputs_in_main) {
- subgraph.outputs[p.second] = p.first;
- }
- // To generate a subgraph an input has to be replaced by data node (no op)
- // and it has to be agnostic to the node from which it's an output
- // (For example, even if two inputs are two different outputs from the
same node,
- // they need to be replaced by two completely separate data nodes)
- auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
- auto subgraph_node = create_subgraph_node(subgraph, inputs.size());
- subgraph_node->inputs = inputs;
- // replug inputs of node out of subgraph to be output of the subgraph node
- // if it was a node in the subgraph
- DFSVisit(g.outputs,
- [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const
nnvm::ObjectPtr node) {
- if (!subgraph_set.count(node.get())) {
- for (auto &e : node->inputs) {
- auto it = sub_outputs_in_main.find(e);
- if (it != sub_outputs_in_main.end()) {
- e.node = subgraph_node;
- e.index = it->second;
- }
- }
- }
- });
- // replug outputs of the graph to be output of the subgraph node
- // if it was a node in the subgraph
- for (auto &e : g.outputs) {
- auto it = sub_outputs_in_main.find(e);
- if (it != sub_outputs_in_main.end()) {
- e.node = subgraph_node;
- e.index = it->second;
- }
+Graph CopyAndReplaceSubgraphs(const Graph& g,
+ const std::vector<int>& subgraph_assignment,
+ const int num_subgraphs,
+ FCreateNode create_subgraph_node) {
+ if (num_subgraphs == 0) {
+ return g;
+ }
+
+ Graph ret;
+
+ const auto& idx = g.indexed_graph();
+
+ CHECK_EQ(idx.num_nodes(), subgraph_assignment.size()) <<
+ "Every node in the graph needs to be included in subgraph assignment.";
+
+ std::vector<nnvm::ObjectPtr> new_nodes;
+ new_nodes.reserve(idx.num_nodes());
+ struct SubgraphInfo {
+ nnvm::Graph graph;
+ nnvm::ObjectPtr subgraph_node;
+ std::vector<EntryInfo> outputs;
+ std::vector<EntryInfo> inputs;
+ std::vector<nnvm::ObjectPtr> input_nodes;
+ };
+
+ std::vector<SubgraphInfo> subgraphs(num_subgraphs);
+
+ for (auto& info : subgraphs) {
+ info.subgraph_node = nnvm::Node::Create();
+ }
+
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ // First copy the node, it will be used
+ // either in the new graph or inside a
+ // subgraph. Variables are not copied.
+ if (idx[i].source->op() != nullptr) {
+ new_nodes.emplace_back(nnvm::Node::Create());
+ auto& node_copy = new_nodes.back();
+ node_copy->attrs = idx[i].source->attrs;
+ node_copy->info = idx[i].source->info;
+ } else {
+ new_nodes.emplace_back(idx[i].weak_ref.lock());
+ continue;
}
- // move control dependencies between nodes of the subgraph and out of the
subgraph
- // to a dependencies between the subgraph node and the nodes out of the
subgraph
- DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const
nnvm::ObjectPtr& node) {
- if (subgraph_set.count(node.get())) {
- auto it = node->control_deps.begin();
- static auto& is_fusion =
Op::GetAttr<exec::TIsFusionHelper>("TIsFusionHelper");
- std::vector<nnvm::ObjectPtr> new_control_deps;
- // Use the first control dependency to get the inferattr helper
- if (it != node->control_deps.end()) {
- if (subgraph_set.count(it->get())) {
- new_control_deps.push_back(*it);
+ auto& node_copy = new_nodes.back();
+ const int subgraph_id = subgraph_assignment[i];
+ if (subgraph_id != -1) {
+ auto& info = subgraphs[subgraph_id];
+ for (const auto& input : idx[i].inputs) {
+ const int their_subgraph = subgraph_assignment[input.node_id];
+ if (their_subgraph == subgraph_id) {
+ node_copy->inputs.emplace_back(new_nodes[input.node_id],
+ input.index,
+ input.version);
+ } else {
+ int input_num;
+ int output_num;
+ if (their_subgraph == -1) {
+ input_num = SetInsert({static_cast<int>(input.node_id),
+ static_cast<int>(input.index)},
&(info.inputs));
} else {
- if ((*it)->is_variable() || !is_fusion.get((*it)->op(), false)) {
- uint32_t node_id = subgraph_node->control_deps.size();
- subgraph_node->control_deps.push_back(*it);
- auto helper_node = op::MakeNode("_FusedOpOutHelper",
- "FusedOp_" + node->attrs.name +
"_outhelper",
- nullptr,
- nullptr,
- nullptr);
- helper_node->attrs.parsed =
- FusedOpHelperParamPtr(new FusedOpHelperParam(
- nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
- node_id));
- new_control_deps.push_back(helper_node);
+ auto& their_subgraph_info = subgraphs[their_subgraph];
+ output_num = SetInsert({static_cast<int>(input.node_id),
+ static_cast<int>(input.index)},
+ &(their_subgraph_info.outputs));
+ input_num = SetInsert({static_cast<int>(idx.num_nodes() +
their_subgraph),
+ output_num},
+ &(info.inputs));
+ }
+ if (static_cast<size_t>(input_num) == info.input_nodes.size()) {
+ info.input_nodes.emplace_back(nnvm::Node::Create());
+ info.input_nodes.back()->attrs.name = "input_" +
std::to_string(input_num);
+ if (their_subgraph == -1) {
+ info.subgraph_node->inputs.emplace_back(new_nodes[input.node_id],
+ input.index,
+ input.version);
} else {
- new_control_deps.push_back(*it);
+
info.subgraph_node->inputs.emplace_back(subgraphs[their_subgraph].subgraph_node,
+ output_num,
+ input.version);
}
}
- ++it;
+ node_copy->inputs.emplace_back(info.input_nodes[input_num], 0, 0);
}
- node->control_deps = new_control_deps;
}
- });
-
- std::ostringstream name_oss;
- // the name of the new node will be the concatenation of all the node
names in the subgraph
- DFSVisit(subgraph.outputs, [&name_oss](const nnvm::ObjectPtr n) {
- if (n->op() != nullptr) {
- name_oss << n->op()->name << "_";
+ } else {
+ for (const auto& input : idx[i].inputs) {
+ const int subgraph_id = subgraph_assignment[input.node_id];
+ if (subgraph_id == -1) {
+ node_copy->inputs.emplace_back(new_nodes[input.node_id],
+ input.index,
+ input.version);
+ } else {
+ auto& info = subgraphs[subgraph_id];
+ const int output_num = SetInsert({static_cast<int>(input.node_id),
+ static_cast<int>(input.index)},
+ &(info.outputs));
+ node_copy->inputs.emplace_back(info.subgraph_node,
+ output_num,
+ input.version);
+ }
}
- });
- auto subgraph_name = name_oss.str();
- subgraph_name.pop_back();
- subgraph_node->attrs.name = subgraph_name;
+ }
- const auto& index = subgraph.indexed_graph();
- DFSVisit(g.outputs, [&subgraph_node, &subgraph_set, &index](const
nnvm::ObjectPtr& node) {
- for (auto &e : node->control_deps) {
- if (subgraph_set.count(e.get())) {
- uint32_t node_id = index.node_id(e.get());
- auto helper_node = op::MakeNode("_FusedOpHelper",
- subgraph_node->attrs.name + "_"
- + node->attrs.name + "_helper",
- nullptr,
- nullptr,
- nullptr);
- helper_node->attrs.parsed =
- FusedOpHelperParamPtr(new FusedOpHelperParam(
- nnvm::get<FusedOpPtr>(subgraph_node->attrs.parsed),
- node_id));
- e = helper_node;
- }
+ // Control deps
+ for (const auto& dep : idx[i].control_deps) {
+ if (subgraph_id == subgraph_assignment[dep]) {
+ node_copy->control_deps.emplace_back(new_nodes[dep]);
}
- });
+ }
}
- Graph new_graph;
- new_graph.outputs = g.outputs;
- return new_graph;
-}
-/* \brief Add nodes as inputs to the subgraph. This is used for operations
- * which are only compatible when they are the first nodes in the
- * subgraph.
- */
-template <typename IsCompatible>
-void AddInputsOnlyCompatible(const Graph &g,
- std::vector<std::unordered_set<nnvm::Node*> >*
subsets,
- IsCompatible is_compatible) {
- std::unordered_map<nnvm::Node*, uint32_t> node2setidx;
- size_t subgraphs_fullsize = 0;
- for (auto& s : *subsets) {
- subgraphs_fullsize += s.size();
- }
- node2setidx.reserve(subgraphs_fullsize);
- for (size_t i = 0; i < subsets->size(); ++i) {
- for (auto& n : (*subsets)[i]) {
- node2setidx.insert({n, i});
+ ret.outputs.reserve(idx.outputs().size());
+ for (const auto& output : idx.outputs()) {
+ const int subgraph_id = subgraph_assignment[output.node_id];
+ if (subgraph_id == -1) {
+ ret.outputs.emplace_back(new_nodes[output.node_id],
+ output.index,
+ output.version);
+ } else {
+ const int output_num = SetInsert({static_cast<int>(output.node_id),
+ static_cast<int>(output.index)},
+ &(subgraphs[subgraph_id].outputs));
+ ret.outputs.emplace_back(subgraphs[subgraph_id].subgraph_node,
+ output_num,
+ output.version);
}
}
- std::vector<std::vector<nnvm::Node*> > to_add(subsets->size());
- DFSVisit(g.outputs, [&is_compatible, &node2setidx, &to_add](const
nnvm::ObjectPtr& n) {
- const auto& it = node2setidx.find(n.get());
- if (it != node2setidx.end()) {
- for (auto& e : n->inputs) {
- if (is_compatible(e.node.get()))
- to_add[it->second].push_back(e.node.get());
- }
+
+ for (auto& info : subgraphs) {
+ info.graph.outputs.reserve(info.outputs.size());
+ for (const auto& entry_info : info.outputs) {
+ info.graph.outputs.emplace_back(new_nodes[entry_info.source_node],
+ entry_info.index,
+ 0);
}
- });
+ create_subgraph_node(info.graph, info.inputs.size(),
info.subgraph_node.get());
+ }
- // Avoid duplicating the node that is input of two subsets
- std::unordered_set<nnvm::Node*> added;
- for (size_t i = 0; i < subsets->size(); ++i) {
- std::vector<nnvm::NodeEntry> heads;
- for (auto n : subsets->at(i)) {
- for (auto e : n->inputs) {
- if (!subsets->at(i).count(e.node.get()))
- heads.push_back(e);
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ // Add _FusedOpHelper nodes
+ const int subgraph_id = subgraph_assignment[i];
+ for (size_t dep_num = 0; dep_num < idx[i].control_deps.size(); ++dep_num) {
+ const auto& dep = idx[i].control_deps[dep_num];
+ const int their_subgraph_id = subgraph_assignment[dep];
+ if (subgraph_id != -1 && their_subgraph_id == -1) {
+ // Not in any subgraph, use FusedOpOutHelper
+ auto& info = subgraphs[subgraph_id];
+ size_t node_id = info.subgraph_node->control_deps.size();
+ info.subgraph_node->control_deps.emplace_back(new_nodes[dep]);
+ auto helper_node = op::MakeNode("_FusedOpOutHelper",
+ "FusedOp_" + new_nodes[i]->attrs.name
+ "_outhelper",
+ nullptr,
+ nullptr,
+ nullptr);
+ helper_node->attrs.parsed =
+ FusedOpHelperParamPtr(new FusedOpHelperParam(
+ nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+ node_id));
+ new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() +
dep_num,
+ std::move(helper_node));
+ } else if (their_subgraph_id != subgraph_id &&
+ their_subgraph_id != -1) {
+ auto& info = subgraphs[their_subgraph_id];
+ const auto& subgraph_idx = info.graph.indexed_graph();
+ uint32_t node_id = subgraph_idx.node_id(new_nodes[dep].get());
+ auto helper_node = op::MakeNode("_FusedOpHelper",
+ info.subgraph_node->attrs.name + "_"
+ + idx[i].source->attrs.name +
"_helper",
+ nullptr,
+ nullptr,
+ nullptr);
+ helper_node->attrs.parsed =
+ FusedOpHelperParamPtr(new FusedOpHelperParam(
+ nnvm::get<FusedOpPtr>(info.subgraph_node->attrs.parsed),
+ node_id));
+ new_nodes[i]->control_deps.insert(new_nodes[i]->control_deps.begin() +
dep_num,
+ std::move(helper_node));
}
}
- for (size_t j = 0; j < to_add[i].size(); ++j) {
- if (!added.count(to_add[i][j])) {
- bool make_cycle = false;
- const auto& node = to_add[i][j];
- std::vector<nnvm::NodeEntry> _heads;
- std::copy_if(heads.begin(), heads.end(), std::back_inserter(_heads),
- [&node](const nnvm::NodeEntry& n) {
- return n.node.get() != node;
- });
- DFSVisit(_heads, [&make_cycle, &node](const nnvm::ObjectPtr& n) {
- if (n.get() == node)
- make_cycle = true;
- });
- if (!make_cycle) {
- (*subsets)[i].insert(to_add[i][j]);
- added.insert(to_add[i][j]);
+ }
+ for (auto& info : subgraphs) {
+ const auto& idx = info.graph.indexed_graph();
+ const auto& input_nodes = idx.input_nodes();
+ std::vector<nnvm::NodeEntry> subgraph_inputs;
+ subgraph_inputs.reserve(info.subgraph_node->inputs.size());
+ for (const int input : input_nodes) {
+ for (size_t i = 0; i < info.input_nodes.size(); ++i) {
+ const auto& input_ptr = info.input_nodes[i].get();
+ if (input_ptr == idx[input].source) {
+ subgraph_inputs.emplace_back(info.subgraph_node->inputs[i]);
}
}
}
+ info.subgraph_node->inputs.swap(subgraph_inputs);
+ std::string name;
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ if (idx[i].source->op() != nullptr) {
+ name += idx[i].source->op()->name + "_";
+ }
+ }
+ info.subgraph_node->attrs.name = name;
}
-}
-
-Graph FusePointwiseForward(Graph &&g) {
- Graph ret;
- g.indexed_graph();
- const auto& num_forward_outputs = g.GetAttr<size_t>("num_forward_outputs");
- Graph fg;
- fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(),
- g.outputs.begin() + num_forward_outputs);
- auto subsets = GetCompatibleSubsets(fg, IsFusionCompatible);
- AddInputsOnlyCompatible(fg, &subsets, IsInputsOnlyCompatible);
- g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode);
- ret.outputs = g.outputs;
return ret;
}
-Graph FusePointwiseBackward(Graph &&g) {
- Graph ret;
- g.indexed_graph();
- const auto& num_forward_outputs = g.GetAttr<size_t>("num_forward_outputs");
- Graph fg;
- fg.outputs.insert(fg.outputs.begin(), g.outputs.begin(),
- g.outputs.begin() + num_forward_outputs);
- std::unordered_set<nnvm::Node*> exclusion_set;
- DFSVisit(fg.outputs, [&exclusion_set](const nnvm::ObjectPtr& n) {
- exclusion_set.insert(n.get());
- });
- auto subsets = GetCompatibleSubsets(g, [&exclusion_set](nnvm::Node* n) {
- if (exclusion_set.count(n))
- return false;
- return IsFusionCompatible(n);
- });
- g = ReplaceSubgraphsPointwise(std::move(g), subsets, CreateSubgraphNode);
- ret.outputs = g.outputs;
+Graph FusePointwise(const Graph &g, const size_t num_forward_outputs) {
+ auto start = std::chrono::steady_clock::now();
+ std::vector<int> subset_assignment;
+ int num_subsets;
+ std::tie(subset_assignment, num_subsets) = GetCompatibleSubsets(g,
num_forward_outputs, // NOLINT(*)
+
IsFusionCompatible,
+
IsInputsOnlyCompatible);
+ Graph ret = CopyAndReplaceSubgraphs(g, subset_assignment, num_subsets,
+ CreateSubgraphNode);
+ auto end = std::chrono::steady_clock::now();
+ if (dmlc::GetEnv("MXNET_RTC_VERBOSE", false)) {
+ auto diff = end - start;
+ LOG(INFO) << "Pointwise fusion graph pass took: "
+ << std::chrono::duration<double, std::milli>(diff).count()
+ << "ms.";
+ }
return ret;
}
#endif // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC
diff --git a/src/executor/simple_partition_pass.cc
b/src/executor/simple_partition_pass.cc
new file mode 100644
index 0000000..941959d
--- /dev/null
+++ b/src/executor/simple_partition_pass.cc
@@ -0,0 +1,265 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file simple_partition_pass.cc
+ * \brief Utilities used in simple partition pass
+ * \author Przemyslaw Tredak
+ */
+
+#include "./simple_partition_pass.h"
+#include <memory>
+#include <utility>
+
+namespace mxnet {
+namespace exec {
+
+namespace detail {
+
+const IntervalVec* LargerSet(const IntervalVec* const first,
+ const IntervalVec* const second) noexcept {
+ const IntervalVec* ret = nullptr;
+ auto first_iter = first->begin();
+ auto second_iter = second->begin();
+ while (first_iter != first->end() &&
+ second_iter != second->end()) {
+ if (*first_iter == *second_iter) {
+ ++first_iter;
+ ++second_iter;
+ } else {
+ // Entry in first set not seen in the second set
+ if (first_iter->second < second_iter->first) {
+ if (ret == first || ret == nullptr) {
+ ret = first;
+ ++first_iter;
+ } else {
+ return nullptr;
+ }
+ continue;
+ }
+ // Entry in second set not seen in the first set
+ if (second_iter->second < first_iter->first) {
+ if (ret == second || ret == nullptr) {
+ ret = second;
+ ++second_iter;
+ } else {
+ return nullptr;
+ }
+ continue;
+ }
+ // Entry in first set fully encloses the entry in the second set
+ if (first_iter->first <= second_iter->first &&
+ first_iter->second >= second_iter->second) {
+ if (ret == first || ret == nullptr) {
+ ret = first;
+ ++second_iter;
+ } else {
+ return nullptr;
+ }
+ continue;
+ }
+ // Entry in second set fully encloses the entry in the first set
+ if (second_iter->first <= first_iter->first &&
+ second_iter->second >= first_iter->second) {
+ if (ret == second || ret == nullptr) {
+ ret = second;
+ ++first_iter;
+ } else {
+ return nullptr;
+ }
+ continue;
+ }
+ // Entries intersect but one is not fully enclosed in the other
+ return nullptr;
+ }
+ }
+ if (ret == nullptr) {
+ // The common part is the same
+ return second_iter == second->end() ? first : second;
+ } else {
+ if ((ret == first && second_iter == second->end()) ||
+ (ret == second && first_iter == first->end())) {
+ return ret;
+ }
+ }
+ return nullptr;
+}
+
+void MergeSets(const IntervalVec** const my_set,
+ const IntervalVec* const other_set,
+ std::vector<std::unique_ptr<const IntervalVec>>* const storage)
noexcept {
+ if ((*my_set == nullptr) || (*my_set)->size() == 0) {
+ *my_set = other_set;
+ return;
+ }
+ if (other_set == nullptr || other_set->size() == 0) {
+ return;
+ }
+ auto* larger_set = LargerSet(*my_set, other_set);
+ if (larger_set != nullptr) {
+ *my_set = larger_set;
+ return;
+ }
+ auto my_iter = (*my_set)->cbegin();
+ auto other_iter = other_set->cbegin();
+ auto new_set = IntervalVec();
+ int last_end = -10; // less than -1
+ while (my_iter != (*my_set)->cend() &&
+ other_iter != other_set->cend()) {
+ const auto& mine = *my_iter;
+ const auto& other = *other_iter;
+ if (other.second < mine.first - 1) {
+ // other interval is before ours
+ if (last_end >= other.first - 1) {
+ new_set.back().second = other.second;
+ } else {
+ new_set.emplace_back(other);
+ }
+ last_end = other.second;
+ ++other_iter;
+ } else if (other.first > mine.second + 1) {
+ // other interval is after ours
+ if (last_end >= mine.first - 1) {
+ new_set.back().second = mine.second;
+ } else {
+ new_set.emplace_back(mine);
+ }
+ last_end = mine.second;
+ ++my_iter;
+ } else {
+ // Intervals can be merged together
+ Interval n(std::min(mine.first, other.first),
+ std::max(mine.second, other.second));
+ if (last_end >= n.first - 1) {
+ new_set.back().second = n.second;
+ } else {
+ new_set.emplace_back(n);
+ }
+ last_end = n.second;
+ if (other.second >= mine.second) {
+ ++my_iter;
+ }
+ if (mine.second >= other.second) {
+ ++other_iter;
+ }
+ }
+ }
+ auto remaining_iter = my_iter == (*my_set)->cend() ? other_iter : my_iter;
+ auto remaining_end = my_iter == (*my_set)->cend() ? other_set->cend() :
(*my_set)->cend();
+ // Add the rest of entries
+ for (; remaining_iter != remaining_end; ++remaining_iter) {
+ auto& mine = new_set.back();
+ const auto& other = *remaining_iter;
+ if (other.second < mine.first - 1) {
+ // other interval is before ours, should never happen
+ continue;
+ } else if (other.first > mine.second + 1) {
+ // other interval is after ours
+ new_set.emplace_back(other);
+ } else {
+ // Intervals can be merged together
+ mine.first = std::min(mine.first, other.first);
+ mine.second = std::max(mine.second, other.second);
+ }
+ }
+ storage->emplace_back(std::make_unique<IntervalVec>(std::move(new_set)));
+ *my_set = storage->back().get();
+}
+
+bool Intersect(const IntervalVec& checked_sets,
+ const IntervalVec& excluded_sets) noexcept {
+ size_t current_interval = 0, current_other_interval = 0;
+ while (current_interval < checked_sets.size() &&
+ current_other_interval < excluded_sets.size()) {
+ const auto& mine = checked_sets[current_interval];
+ const auto& other = excluded_sets[current_other_interval];
+ if (other.second < mine.first) {
+ // other interval is before ours
+ ++current_other_interval;
+ } else if (other.first > mine.second) {
+ // other interval is after ours
+ ++current_interval;
+ } else {
+ // Intervals intersect
+ return true;
+ }
+ }
+ return false;
+}
+
+void AddSet(const IntervalVec** const sets, const int set_to_add,
+ std::vector<std::unique_ptr<const IntervalVec>>* const storage)
noexcept {
+ if (*sets != nullptr && (*sets)->size() != 0) {
+ for (auto& interval : (**sets)) {
+ if (set_to_add >= interval.first &&
+ set_to_add <= interval.second) {
+ return;
+ }
+ }
+ }
+ storage->emplace_back(
+ std::make_unique<IntervalVec>(1, std::make_pair(set_to_add,
set_to_add)));
+ MergeSets(sets, storage->back().get(), storage);
+}
+
+int GetSetMapping(const int set, std::vector<int>* const set_mapping) noexcept
{
+ if (set == -1) return -1;
+ int temp = set;
+ while ((*set_mapping)[temp] != temp) {
+ temp = (*set_mapping)[temp];
+ }
+ (*set_mapping)[set] = temp;
+ return temp;
+}
+
+void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const
combined_excluded_sets_ptr,
+ const IntervalVec* const
new_excluded_sets,
+ std::vector<const IntervalVec*>* const
excluded_sets_ptr,
+ const int set_id,
+ const int first_node_in_set,
+ const size_t new_node_id,
+ const std::vector<int>& set_assignment,
+ std::vector<int>* const
set_mapping_ptr,
+ const IntervalVec& inverse_set_mapping,
+ std::vector<std::unique_ptr<const
IntervalVec>>* const
+ storage) noexcept {
+ const auto* previous_excluded_sets = *combined_excluded_sets_ptr;
+ MergeSets(combined_excluded_sets_ptr, new_excluded_sets, storage);
+ if (new_excluded_sets != nullptr) {
+ if (previous_excluded_sets == nullptr ||
+ *previous_excluded_sets != **(combined_excluded_sets_ptr)) {
+ // Their set's excluded sets list got larger, need to update the descendants
+ // of their set
+ auto& excluded_sets = *excluded_sets_ptr;
+ for (size_t j = first_node_in_set; j < new_node_id; ++j) {
+ if (GetSetMapping(set_assignment[j], set_mapping_ptr) == set_id ||
+ (excluded_sets[j] != nullptr &&
+ Intersect(inverse_set_mapping, *excluded_sets[j]))) {
+ MergeSets(&excluded_sets[j], *combined_excluded_sets_ptr, storage);
+ }
+ }
+ }
+ }
+}
+
+} // namespace detail
+
+} // namespace exec
+} // namespace mxnet
diff --git a/src/executor/simple_partition_pass.h
b/src/executor/simple_partition_pass.h
index 1ca0086..2a135b4 100644
--- a/src/executor/simple_partition_pass.h
+++ b/src/executor/simple_partition_pass.h
@@ -18,10 +18,10 @@
*/
/*!
- * Copyright (c) 2019 by Contributors
+ * Copyright (c) 2019-2020 by Contributors
* \file simple_partition_pass.h
* \brief Simple pass for partitioning a graph.
- * \author Clement Fuji Tsang
+ * \author Clement Fuji Tsang, Przemyslaw Tredak
*/
#ifndef MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_
#define MXNET_EXECUTOR_SIMPLE_PARTITION_PASS_H_
@@ -34,440 +34,237 @@
#include <deque>
#include <algorithm>
#include <vector>
+#include <tuple>
#include "exec_pass.h"
namespace mxnet {
namespace exec {
+namespace detail {
-/*!
- * \brief Custom graph class, which contains bi-directional nodes
- * required for traversing in both directions (from outputs to inputs
- * and vice versa). It is a non-owning layer on top of NNVM graph, since
- * NNVM graph enables traversing only in 1 direction (from outputs to inputs).
+using Interval = std::pair<int, int>;
+using IntervalVec = std::vector<Interval>;
+
+/* \brief Return the set that fully contains the other set, or nullptr
+ * if neither set is a subset of the other.
*/
-class BidirectionalGraph {
- public:
- struct Node {
- nnvm::Node* nnvmptr;
- std::vector<Node*> inputs;
- std::vector<Node*> outputs;
- };
+const IntervalVec* LargerSet(const IntervalVec* const first,
+ const IntervalVec* const second) noexcept;
- explicit BidirectionalGraph(const Graph &g) {
- auto& idx = g.indexed_graph();
- auto num_nodes = idx.num_nodes();
- nodes.reserve(num_nodes);
- nnvm2nid.reserve(num_nodes);
- outputs.reserve(idx.outputs().size());
- // Create all the nodes in a new graph from
- // nodes in the NNVM graph and store them
- // in nodes array
- DFSVisit(g.outputs, [this](const nnvm::ObjectPtr& n) {
- Node new_node;
- new_node.nnvmptr = n.get();
- nnvm2nid[n.get()] = static_cast<uint32_t>(nodes.size());
- nodes.emplace_back(std::move(new_node));
- });
- // Create all connections between nodes in
- // the graph (both directions)
- for (const auto& it : nnvm2nid) {
- nnvm::Node* nnvmnode = it.first;
- uint32_t nid = it.second;
- for (auto& n : nnvmnode->inputs) {
- uint32_t input_nid = nnvm2nid[n.node.get()];
- nodes[input_nid].outputs.emplace_back(&nodes[nid]);
- nodes[nid].inputs.emplace_back(&nodes[input_nid]);
- }
- }
- // Create output connections from the graph
- for (auto& e : g.outputs) {
- uint32_t nid = nnvm2nid[e.node.get()];
- outputs.emplace_back(&nodes[nid]);
- }
- }
+/* \brief Compute the union of the 2 sets and store it in my_set.
+ */
+void MergeSets(const IntervalVec** const my_set,
+ const IntervalVec* const other_set,
+ std::vector<std::unique_ptr<const IntervalVec>>* const storage)
noexcept;
- /* \brief Get all subsets of nodes, where:
- * - graph constructed from nodes in each subset is a connected graph
- * - every node fulfills a predicate is_compatible
- * - if nodes u and v are part of a subset, then for each path between
- * u and v in the original directed graph, all nodes on those paths
- * are also part of the subset
- * \param is_compatible A function taking nnvm::Node* and returning bool
- * which identifies which nodes should be included in
- * subsets.
- */
- template<typename FCompatible>
- std::vector<std::unordered_set<Node*>> get_subsets(FCompatible
is_compatible) {
- std::vector<std::unordered_set<Node*>> subgraphs;
- std::unordered_set<Node*> incomp_set;
- std::vector<std::pair<bool, PairSet>> separation_sets;
- // Check each node for compatibility
- // and, if it is incompatible, mark nodes
- // on each side of it as not possible to be
- // in the same subset
- for (Node& node : nodes) {
- if (!is_compatible(node.nnvmptr)) {
- incomp_set.insert(&node);
- }
- }
- for (Node& node : nodes) {
- if (incomp_set.count(&node) != 0) {
- // Check if all your inputs are incompatible too.
- // If so, then your separation set does not matter,
- // because it will covered by the sets of your inputs
- bool inside_node = true;
- for (Node* input : node.inputs) {
- if (incomp_set.count(input) == 0) {
- inside_node = false;
- }
- }
- if (!inside_node) {
- std::unordered_set<Node*> in_graph;
- std::unordered_set<Node*> out_graph;
- std::vector<Node*> dummy_head;
- dummy_head.emplace_back(&node);
- DFS(dummy_head, false, [&out_graph](Node* node) {
- out_graph.insert(node);
- });
- DFS(dummy_head, true, [&in_graph](Node* node) {
- in_graph.insert(node);
- });
- separation_sets.push_back(std::make_pair(true,
- std::make_pair(in_graph,
out_graph)));
- } else {
- separation_sets.push_back(std::make_pair(false, PairSet()));
- }
- } else {
- separation_sets.push_back(std::make_pair(false, PairSet()));
- }
- }
- IncompMap incomp_map;
- // For each node construct the map of nodes that cannot be in
- // the same subset
- index_t num_nodes = nodes.size();
- for (index_t i = 0; i < num_nodes; ++i) {
- const auto n = &(nodes[i]);
- if (incomp_set.count(n) == 0) {
- for (index_t j = i + 1; j < num_nodes; ++j) {
- const auto& sep_set_pair = separation_sets[j];
- if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
- const auto& p = sep_set_pair.second;
- if (p.first.count(n)) {
- incomp_map[n].insert(p.second.begin(), p.second.end());
- } else if (p.second.count(n)) {
- incomp_map[n].insert(p.first.begin(), p.first.end());
- }
- }
- }
- for (index_t j = i - 1; j >= 0; --j) {
- const auto& sep_set_pair = separation_sets[j];
- if (sep_set_pair.first && incomp_map[n].count(&nodes[j]) == 0) {
- const auto& p = sep_set_pair.second;
- if (p.first.count(n)) {
- incomp_map[n].insert(p.second.begin(), p.second.end());
- } else if (p.second.count(n)) {
- incomp_map[n].insert(p.first.begin(), p.first.end());
- }
- }
- }
- for (Node* incomp_n : incomp_set) {
- incomp_map[n].erase(incomp_n);
- }
- }
- }
- std::unordered_set<Node*> unused_set;
+/* \brief Returns true if there is a non-empty intersection
+ * between the 2 sets.
+ */
+bool Intersect(const IntervalVec& checked_sets,
+ const IntervalVec& excluded_sets) noexcept;
- for (auto& n : nodes) {
- if (incomp_set.count(&n) == 0) {
- unused_set.insert(&n);
- }
- }
- std::unordered_set<Node*> visited;
- std::deque<Node*> stack(outputs.begin(), outputs.end());
- // Create subsets
- while (!stack.empty()) {
- Node* vertex = stack.front();
- stack.pop_front();
- if (!visited.count(vertex)) {
- visited.insert(vertex);
- if (unused_set.count(vertex)) {
- subgraphs.emplace_back(naive_grow_subgraph(vertex, &unused_set,
&incomp_map));
- }
- for (Node* input : vertex->inputs) {
- stack.emplace_back(input);
- }
- }
- }
- return subgraphs;
- }
+/* \brief Add a single entry to the sets.
+ */
+void AddSet(const IntervalVec** const sets, const int set_to_add,
+ std::vector<std::unique_ptr<const IntervalVec>>* const storage)
noexcept;
- private:
- using PairSet = std::pair<std::unordered_set<Node*>,
std::unordered_set<Node*>>;
- using PairVec = std::pair<std::vector<Node*>, std::vector<Node*>>;
- using IncompMap = std::unordered_map<Node*, std::unordered_set<Node*>>;
+/* \brief Get the true mapping of the set (which could change
+ * due to merging of multiple sets).
+ */
+int GetSetMapping(const int set, std::vector<int>* const set_mapping) noexcept;
- /* \brief Traverse the graph using DFS in either direction.
- * \param heads Starting nodes for the DFS algorithm.
- * \param reverse If true, DFS will traverse the graph from
- * outputs to inputs. Otherwise, it will
- * traverse the graph from inputs to outputs.
- * \param fvisit Function to call on each visisted node.
- */
- template <typename FVisit>
- void DFS(const std::vector<Node*>& heads, bool reverse, FVisit fvisit) {
- std::unordered_set<Node*> visited;
- std::vector<Node*> vec(heads.begin(), heads.end());
- visited.reserve(heads.size());
- while (!vec.empty()) {
- Node* vertex = vec.back();
- vec.pop_back();
- if (visited.count(vertex) == 0) {
- visited.insert(vertex);
- fvisit(vertex);
- std::vector<Node*> nexts = reverse ? vertex->inputs : vertex->outputs;
- for (Node* node : nexts) {
- if (visited.count(node) == 0) {
- vec.emplace_back(node);
- }
- }
- }
- }
- }
+/* \brief Check if 2 ids are on the same side of the cutoff
+ * (so either both on the FWD side or the BWD side).
+ *
+ * An id equal to the cutoff is on the same side as the ids below it.
+ *
+ * \param my_id first id to compare
+ * \param their_id second id to compare
+ * \param cutoff largest id that still belongs to the lower side
+ * \return true if both ids are greater than cutoff, or both are
+ *         less than or equal to cutoff
+ */
+inline bool IsSamePass(const int my_id, const int their_id, const int cutoff)
noexcept {
+ return (my_id > cutoff && their_id > cutoff) ||
+ (my_id <= cutoff && their_id <= cutoff);
+}
- /* \brief Get the connected subgraph that contains the head node,
- * only previously unused nodes, according to the rules
- * from incompatibility map.
- * \param head Node which needs to be part of the returned subgraph.
- * \param unused_set Only nodes from this set will be considered when
- * adding to the growing subgraph.
- * \param incomp_map Map containing data on which nodes are incompatible
- * to be in the same subgraph.
- */
- std::unordered_set<Node*> naive_grow_subgraph(Node* head,
- std::unordered_set<Node*>*
unused_set,
- IncompMap* incomp_map) {
- std::unordered_set<Node*> subgraph;
- std::unordered_set<Node*> incomp_set;
- std::deque<Node*> stack;
- stack.emplace_back(head);
- while (!stack.empty()) {
- Node* vertex = stack.back();
- stack.pop_back();
- if (unused_set->count(vertex) && !incomp_set.count(vertex)) {
- unused_set->erase(vertex);
- subgraph.insert(vertex);
- incomp_set.insert((*incomp_map)[vertex].begin(),
(*incomp_map)[vertex].end());
- // Traverse the grpah in both directions
- for (Node* input : vertex->inputs) {
- if (unused_set->count(input) && !incomp_set.count(input)) {
- stack.emplace_back(input);
- }
- }
- for (Node* output : vertex->outputs) {
- if (unused_set->count(output) && !incomp_set.count(output)) {
- stack.emplace_back(output);
- }
- }
- }
- }
- return subgraph;
- }
+/* \brief Check if adding a new node to the set changes the excluded set of
the future
+ * fused node. If so, update all descendants of the fused node.
+ *
+ * \param combined_excluded_sets_ptr pointer to the set's list of excluded sets
+ * before adding the new node
+ * \param new_excluded_sets list of excluded sets of the new node
+ * \param excluded_sets_ptr pointer to the lists of excluded sets of all the
nodes
+ * \param set_id number of the set, to which the new node is added
+ * \param first_node_in_set id of the first node in the set, according to
topological ordering
+ * \param new_node_id id of the node added to the set
+ * \param set_assignment assignment of sets
+ * \param set_mapping_ptr pointer to the mappings of sets
+ * \param inverse_set_mapping inverse mapping of the set
+ * \param storage memory storage
+ */
+void CheckAndUpdateCombinedExcludedSets(const IntervalVec** const
combined_excluded_sets_ptr,
+ const IntervalVec* const
new_excluded_sets,
+ std::vector<const IntervalVec*>* const
excluded_sets_ptr,
+ const int set_id,
+ const int first_node_in_set,
+ const size_t new_node_id,
+ const std::vector<int>& set_assignment,
+ std::vector<int>* const
set_mapping_ptr,
+ const IntervalVec& inverse_set_mapping,
+ std::vector<std::unique_ptr<const
IntervalVec>>* const
+ storage) noexcept;
- friend class Graph;
+} // namespace detail
- std::vector<Node> nodes;
- std::unordered_map<nnvm::Node*, uint32_t> nnvm2nid;
- std::vector<Node*> outputs;
-}; // class BidirectionalGraph
-using NodeEntrySet = std::unordered_set<nnvm::NodeEntry, nnvm::NodeEntryHash,
- nnvm::NodeEntryEqual>;
-using NodeRawPtrSet = std::unordered_set<nnvm::Node*>;
+/* \brief Get all subsets of nodes, where:
+ * - graph constructed from nodes in each subset is a connected graph
+ * - every node fulfills a predicate is_compatible
+ * - if nodes u and v are part of a subset, then for each path between
+ * u and v in the original directed graph, all nodes on those paths
+ * are also part of the subset
+ * \param g NNVM graph
+ * \param num_forward_outputs Number of outputs from the graph that come
+ * from the forward pass
+ * \param is_compatible A function taking nnvm::Node* and returning bool
+ * which identifies which nodes could be included in
+ * subsets.
+ * \param is_input_only_compatible A function taking nnvm::Node* and
+ * returning bool which identifies which
+ * nodes could be included in subsets only
+ * as the first operations (their inputs
+ * need to be excluded).
+ * \return tuple (subset assignment, number of found subsets)
+ */
+template<typename FCompatible, typename FInputOnlyCompatible>
+std::tuple<std::vector<int>, int> GetCompatibleSubsets(
+ const Graph& g,
+ const size_t num_forward_outputs,
+ FCompatible is_compatible,
+ FInputOnlyCompatible is_input_only_compatible) {
-/*!
- * \brief Get the output nodes of the subgraph in the main graph.
- * \return a map between the node in the main graph and the output index of
the subgraph node
-*/
-nnvm::NodeEntryMap<uint32_t> GetSubgraphOutputs(Graph g, NodeRawPtrSet
subgraph_set) {
- nnvm::NodeEntryMap<uint32_t> outputs;
- uint32_t count = 0;
- for (auto& e : g.outputs) {
- if (subgraph_set.count(e.node.get()) && !outputs.count(e)) {
- outputs.insert({e, count++});
+ using namespace detail;
+ const auto& idx = g.indexed_graph();
+ // Bookkeeping:
+ //  - set_assignment: per node, id of its set (-1 = not in any set)
+ //  - excluded_sets: per node, list of excluded sets (entries start as nullptr)
+ //  - set_mapping: per set, the set it was merged into (itself when created)
+ //  - combined_excluded_sets: per set, the excluded sets list of the whole set
+ //  - first_node_in_set: per set, smallest node id seen in it
+ //  - inverse_set_mapping: per set, intervals of set ids merged into it
+ //  - storage: owns the interval vectors referenced by the raw pointers above
+ std::vector<int> set_assignment(idx.num_nodes(), -1);
+ std::vector<const std::vector<Interval>*> excluded_sets(idx.num_nodes());
+ std::vector<int> set_mapping;
+ std::vector<const std::vector<Interval>*> combined_excluded_sets;
+ std::vector<int> first_node_in_set;
+ std::vector<const std::vector<Interval>*> inverse_set_mapping;
+ std::vector<std::unique_ptr<const std::vector<Interval>>> storage;
+
+ // The largest node id among the first num_forward_outputs graph outputs
+ // is the cutoff handed to IsSamePass below.
+ int last_forward_node = -1;
+ for (size_t i = 0; i < num_forward_outputs; ++i) {
+ const int output_id = idx.outputs()[i].node_id;
+ if (last_forward_node < output_id) {
+ last_forward_node = output_id;
}
}
- DFSVisit(g.outputs, [&subgraph_set, &outputs, &count](const nnvm::ObjectPtr
&node){
- if (!subgraph_set.count(node.get())) {
- for (auto& e : node->inputs) {
- if (subgraph_set.count(e.node.get()) && !outputs.count(e)) {
- outputs.insert({e, count++});
- }
- }
- }
- });
- return outputs;
-}
-/*!
- * \brief Create new input nodes of the subgraph and plug them.
- * \return the inputs of the subgraph node in the main graph
-*/
-std::vector<nnvm::NodeEntry> GetSubgraphInputs(Graph g, NodeRawPtrSet
subgraph_set) {
- std::vector<nnvm::NodeEntry> inputs;
- nnvm::NodeEntryMap<nnvm::NodeEntry> entry_map;
- DFSVisit(g.outputs, [&subgraph_set, &inputs, &entry_map](const
nnvm::ObjectPtr &node){
- if (subgraph_set.count(node.get())) {
- for (auto &e : node->inputs) {
- if (!subgraph_set.count(e.node.get())) {
- if (entry_map.count(e)) {
- e = entry_map[e];
+ int num_sets = 0;
+ for (size_t i = 0; i < idx.num_nodes(); ++i) {
+ const auto& node = idx[i];
+ auto& my_excluded_sets = excluded_sets[i];
+ // Start from everything that is excluded for any of the inputs.
+ for (const auto& input : node.inputs) {
+ MergeSets(&my_excluded_sets, excluded_sets[input.node_id], &storage);
+ }
+ if (is_compatible(node.source)) {
+ int my_set = -1;
+ // Try to join the set of every input that has a set, is on the same
+ // side of the forward/backward cutoff, and whose set (including the
+ // sets merged into it) is not excluded for this node.
+ for (const auto& input : node.inputs) {
+ int their_set = GetSetMapping(set_assignment[input.node_id],
&set_mapping);
+ if (their_set != -1 &&
+ their_set != my_set &&
+ IsSamePass(i, input.node_id, last_forward_node) &&
+ (my_excluded_sets == nullptr ||
+ !Intersect(*inverse_set_mapping[their_set], *my_excluded_sets))) {
+ if (my_set == -1) {
+ // First acceptable input: join its set.
+ my_set = their_set;
+
CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]),
+ my_excluded_sets,
+ &excluded_sets,
+ their_set,
+ first_node_in_set[their_set],
+ i,
+ set_assignment,
+ &set_mapping,
+
*(inverse_set_mapping[their_set]),
+ &storage);
} else {
- auto new_node = nnvm::Node::Create();
- new_node->attrs.name = "input_" + std::to_string(inputs.size());
- entry_map.insert({e, nnvm::NodeEntry{new_node, 0, 0}});
- inputs.push_back(e);
- e.node = new_node;
- e.index = 0;
+ // A further acceptable input from a different set: merge that set
+ // into my_set and redirect its mapping.
+ MergeSets(&inverse_set_mapping[my_set],
+ inverse_set_mapping[their_set],
+ &storage);
+ set_mapping[their_set] = my_set;
+ first_node_in_set[my_set] = std::min(first_node_in_set[my_set],
+ first_node_in_set[their_set]);
+
CheckAndUpdateCombinedExcludedSets(&(combined_excluded_sets[their_set]),
+ combined_excluded_sets[my_set],
+ &excluded_sets,
+ my_set,
+ first_node_in_set[my_set],
+ i,
+ set_assignment,
+ &set_mapping,
+ *(inverse_set_mapping[my_set]),
+ &storage);
}
}
}
+ // No input set could be joined: open a new set holding just this node.
+ if (my_set == -1) {
+ set_mapping.emplace_back(num_sets);
+ combined_excluded_sets.emplace_back(my_excluded_sets);
+ first_node_in_set.emplace_back(i);
+ storage.emplace_back(std::make_unique<std::vector<Interval>>(
+ 1, std::make_pair(num_sets,
+ num_sets)));
+ inverse_set_mapping.emplace_back(storage.back().get());
+ my_set = num_sets++;
+ }
+ set_assignment[i] = my_set;
+ } else {
+ // Node is not compatible: add the sets of its inputs to its excluded
+ // sets. Later nodes pick these up through the merge at the top of the loop.
+ for (const auto& input : node.inputs) {
+ int their_set = GetSetMapping(set_assignment[input.node_id],
&set_mapping);
+ if (their_set != -1) {
+ AddSet(&my_excluded_sets, their_set, &storage);
+ }
+ }
+ // An input-only compatible node opens a new set of its own; the sets of
+ // its inputs were excluded just above.
+ if ((is_input_only_compatible != nullptr) &&
+ is_input_only_compatible(node.source)) {
+ set_mapping.emplace_back(num_sets);
+ combined_excluded_sets.emplace_back(my_excluded_sets);
+ first_node_in_set.emplace_back(i);
+ storage.emplace_back(std::make_unique<std::vector<Interval>>(
+ 1, std::make_pair(num_sets,
+ num_sets)));
+ inverse_set_mapping.emplace_back(storage.back().get());
+ set_assignment[i] = num_sets++;
+ }
}
- });
- // Fix ordering of w.r.t to topology
- Graph _g;
- _g.outputs = g.outputs;
- const auto &idx = _g.indexed_graph();
- std::sort(inputs.begin(), inputs.end(),
- [&idx, &entry_map](const nnvm::NodeEntry lhs, const nnvm::NodeEntry rhs)
{
- return idx.entry_id(entry_map.at(lhs)) <
idx.entry_id(entry_map.at(rhs));
- });
- return inputs;
-}
-
-std::unordered_map<uint32_t, uint32_t> GetGraphInputsMap(const Graph& g) {
- std::unordered_map<uint32_t, uint32_t> outputs;
- auto& idx = g.indexed_graph();
- outputs.reserve(idx.num_nodes());
- std::vector<uint32_t> input_nodes = idx.input_nodes();
- for (size_t i = 0; i < input_nodes.size(); ++i) {
- outputs[input_nodes[i]] = static_cast<uint32_t>(i);
}
- return outputs;
-}
-/*!
- * \brief Helper function to display what nodes are in a specific subset.
- */
-void dispNodesSet(Graph g, NodeRawPtrSet s) {
- DFSVisit(g.outputs, [&s](const nnvm::ObjectPtr n){
- if (s.count(n.get())) {
- std::cout << " Y " << n->attrs.name << std::endl;
- } else {
- std::cout << " N " << n->attrs.name << std::endl;
- }
- });
-}
+ // Replace every assignment with the set it maps to after all merges.
+ for (int& set : set_assignment) {
+ set = GetSetMapping(set, &set_mapping);
+ }
-/*!
- * \brief Replace a set of nodes by a subgraph node.
- */
-template<typename FCreateNode>
-Graph ReplaceSubgraphs(Graph&& g, const std::vector<NodeRawPtrSet>&
subgraph_sets,
- FCreateNode create_subgraph_node) {
- for (auto subgraph_set : subgraph_sets) {
- // Create MXNet subgraph
- Graph subgraph;
- const auto sub_outputs_in_main = GetSubgraphOutputs(g, subgraph_set);
- subgraph.outputs.resize(sub_outputs_in_main.size());
- for (auto p : sub_outputs_in_main) {
- subgraph.outputs[p.second] = p.first;
+ std::vector<int> set_reorder(num_sets, 0);
+ // First count the number of elements in each set.
+ for (int& set : set_assignment) {
+ if (set != -1) {
+ ++set_reorder[set];
}
- // To generate a subgraph an input has to be replaced by data node (no op)
- // and it has to be agnostic to the node from which it's an output
- // (For example, even if two inputs are two different outputs from the
same node,
- // they need to be replaced by two completely separate data nodes)
- auto inputs = GetSubgraphInputs(subgraph, subgraph_set);
- auto subgraph_node = create_subgraph_node(subgraph);
- subgraph_node->inputs = inputs;
- // replug inputs of node out of subgraph to be output of the subgraph node
- // if it was a node in the subgraph
- DFSVisit(g.outputs,
- [&subgraph_node, &subgraph_set, &sub_outputs_in_main](const
nnvm::ObjectPtr node) {
- if (!subgraph_set.count(node.get())) {
- for (auto &e : node->inputs) {
- auto it = sub_outputs_in_main.find(e);
- if (it != sub_outputs_in_main.end()) {
- e.node = subgraph_node;
- e.index = it->second;
- }
- }
- }
- });
- // replug outputs of the graph to be output of the subgraph node
- // if it was a node in the subgraph
- for (auto &e : g.outputs) {
- auto it = sub_outputs_in_main.find(e);
- if (it != sub_outputs_in_main.end()) {
- e.node = subgraph_node;
- e.index = it->second;
- }
+ }
+ // Then reorder them, removing sets that have
+ // only a single element.
+ // After this loop set_reorder[old id] holds the new consecutive id, or -1
+ // for a set that is dropped.
+ int final_num_sets = 0;
+ for (int& set : set_reorder) {
+ if (set > 1) {
+ set = final_num_sets++;
+ } else {
+ set = -1;
}
- // move control dependencies between nodes of the subgraph and out of the
subgraph
- // to a dependencies between the subgraph node and the nodes out of the
subgraph
- DFSVisit(g.outputs, [&subgraph_node, &subgraph_set](const nnvm::ObjectPtr&
node) {
- for (auto &e : node->control_deps) {
- if (subgraph_set.count(e.get()))
- e = subgraph_node;
- }
- });
- DFSVisit(subgraph.outputs, [&subgraph_node, &subgraph_set](const
nnvm::ObjectPtr& node) {
- auto it = node->control_deps.begin();
- while (it != node->control_deps.end()) {
- if (subgraph_set.count(it->get())) {
- ++it;
- } else {
- subgraph_node->control_deps.push_back(*it);
- it = node->control_deps.erase(it);
- }
- }
- });
}
- Graph new_graph;
- new_graph.outputs = g.outputs;
- return new_graph;
-}
-/* \brief Get all subsets of nodes, where:
- * - graph constructed from nodes in each subset is a connected graph
- * - every node fulfills a predicate is_compatible
- * - if nodes u and v are part of a subset, then for each path between
- * u and v in the original directed graph, all nodes on those paths
- * are also part of the subset
- * \param g NNVM graph
- * \param is_compatible A function taking nnvm::Node* and returning bool
- * which identifies which nodes should be included in
- * subsets.
- */
-template<typename FCompatible>
-std::vector<NodeRawPtrSet> GetCompatibleSubsets(const Graph& g, FCompatible
is_compatible) {
- BidirectionalGraph biG = BidirectionalGraph(g);
- std::vector<std::unordered_set<BidirectionalGraph::Node*>> subsets =
- biG.get_subsets(is_compatible);
- std::vector<NodeRawPtrSet> nnvm_subsets;
- nnvm_subsets.reserve(subsets.size());
- for (auto& subset : subsets) {
- if (subset.size() > 1) {
- NodeRawPtrSet node_set;
- node_set.reserve(subset.size());
- for (auto& n : subset) {
- node_set.insert(n->nnvmptr);
- }
- nnvm_subsets.push_back(node_set);
+ // Rewrite the per-node assignment using the new ids.
+ for (int& set : set_assignment) {
+ if (set != -1) {
+ set = set_reorder[set];
}
}
- return nnvm_subsets;
+
+ return std::make_tuple(std::move(set_assignment), final_num_sets);
}
} // namespace exec
diff --git a/src/imperative/cached_op.h b/src/imperative/cached_op.h
index c56d8cf..23ab4a0 100644
--- a/src/imperative/cached_op.h
+++ b/src/imperative/cached_op.h
@@ -238,10 +238,7 @@ void OptimizeGraph(nnvm::Graph * full_graph, nnvm::Graph *
fwd_graph, nnvm::Grap
common::CopyGraph(&unoptimized_graph, *full_graph, false);
if
(common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
- full_graph->attrs["num_forward_outputs"] =
std::make_shared<nnvm::any>(num_forward_outputs);
- *full_graph = exec::FusePointwiseForward(std::move(*full_graph));
- full_graph->attrs["num_forward_outputs"] =
std::make_shared<nnvm::any>(num_forward_outputs);
- *full_graph = exec::FusePointwiseBackward(std::move(*full_graph));
+ *full_graph = exec::FusePointwise(*full_graph, num_forward_outputs);
// Check the topological order of inputs
const auto &original_inputs =
unoptimized_graph.indexed_graph().input_nodes();
const auto &new_inputs = full_graph->indexed_graph().input_nodes();
diff --git a/src/operator/fusion/fused_op.cu b/src/operator/fusion/fused_op.cu
index cb13dbf..1b914cc 100644
--- a/src/operator/fusion/fused_op.cu
+++ b/src/operator/fusion/fused_op.cu
@@ -261,12 +261,40 @@ std::string FusedOp::GenerateCode(const
std::vector<OpReqType> &req,
const auto& var_name = g[node_id].source->attrs.name;
const auto vec_name = "vec_" + var_name + "_" + std::to_string(i);
load_index[node_id] = 0;
- auto parse_tuple = [](const std::string& input, const std::string
def) {
+ auto parse_tuple = [ndim](const std::string& input, const
std::string& def) {
std::string out = input;
- replaceString(&out, "(", "{");
- replaceString(&out, ")", "}");
+ replaceString(&out, " ", "");
+ if (out[0] == '(') {
+ replaceString(&out, "(", "{");
+ replaceString(&out, ")", "}");
+ // First check if out is ()
+ int n_entries = out.size() != 2;
+ for (size_t i = 1; i < out.size() - 1; ++i) {
+ if (out[i] == ',') {
+ ++n_entries;
+ }
+ }
+ if (n_entries != ndim) {
+ out.pop_back();
+ for (int i = n_entries; i < ndim; ++i) {
+ out += "," + def;
+ }
+ out += "}";
+ }
+ } else {
+ out = "{" + std::move(out);
+ for (int i = 1; i < ndim; ++i) {
+ out += "," + def;
+ }
+ out += "}";
+ }
replaceString(&out, "None", def);
+ return out;
+ };
+ auto parse_int = [](const std::string& input, const std::string&
def) {
+ std::string out = input;
replaceString(&out, " ", "");
+ replaceString(&out, "None", def);
return out;
};
auto build_tuple = [ndim](int axis, const std::string str, const
std::string def) {
@@ -279,11 +307,11 @@ std::string FusedOp::GenerateCode(const
std::vector<OpReqType> &req,
}
std::string tuple = "{";
for (int i = 0; i < axis; i++) {
- tuple = tuple + def + ",";
+ tuple += def + ",";
}
tuple += str;
for (int i = axis + 1; i < ndim; i++) {
- tuple = tuple + "," + def;
+ tuple += "," + def;
}
tuple += "}";
return tuple;
@@ -295,12 +323,6 @@ std::string FusedOp::GenerateCode(const
std::vector<OpReqType> &req,
}
return false;
};
- auto build_string_axis = [ndim](int axis) {
- if (axis < 0) {
- axis = ndim + axis;
- }
- return std::to_string(axis);
- };
auto build_string_end = [i, ndim, var_name](std::string* code) {
std::string end_var_name = var_name + "_" + std::to_string(i) +
"_end";
*code += "op::Shape<" + std::to_string(ndim) + "> "+ end_var_name
+ ";\n";
@@ -323,12 +345,15 @@ std::string FusedOp::GenerateCode(const
std::vector<OpReqType> &req,
}
end = extra_var_name;
} else {
- begin = parse_tuple(source->attrs.dict.at("begin"), "0");
- end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
if (op_name == "slice_axis") {
+ begin = parse_int(source->attrs.dict.at("begin"), "0");
+ end = parse_int(source->attrs.dict.at("end"), "INT_MAX");
int axis = std::stoi(source->attrs.dict.at("axis"));
begin = build_tuple(axis, begin, "0");
end = build_tuple(axis, end, "INT_MAX");
+ } else {
+ begin = parse_tuple(source->attrs.dict.at("begin"), "0");
+ end = parse_tuple(source->attrs.dict.at("end"), "INT_MAX");
}
if (check_shapes) {
if (check_tuple(begin) && check_tuple(end)) {
diff --git a/tests/python/gpu/test_fusion.py b/tests/python/gpu/test_fusion.py
index 1febf8d..e81c794 100644
--- a/tests/python/gpu/test_fusion.py
+++ b/tests/python/gpu/test_fusion.py
@@ -340,6 +340,27 @@ def test_fusion_reshape_executor():
out = f.forward(is_train=False, data1=data, data2=data)
assert out[0].sum().asscalar() == 150
+# Runs a hybridized two-input block on GPU; no output values are asserted,
+# the test passes if execution completes.
+# NOTE(review): presumably guards against the fusion pass creating a cycle
+# between fused nodes (per the test name) - confirm.
+@with_seed()
+def test_fusion_cycle():
+ from mxnet.gluon import HybridBlock
+ class Test(HybridBlock):
+ def __init__(self, **kwargs):
+ super(Test, self).__init__(**kwargs)
+
+ def hybrid_forward(self, F, x, y):
+ x = F.relu(x)
+ y = F.relu(y)
+ # Each reduction is computed from one relu result ...
+ z1 = F.expand_dims(F.sum_axis(x, axis=1), axis=1)
+ z2 = F.expand_dims(F.sum_axis(y, axis=1), axis=1)
+ # ... and added to the other relu result, so the two outputs cross-depend.
+ return x + z2, y + z1
+
+ t = Test()
+ a = mx.nd.zeros(shape=(10,1), ctx=mx.gpu())
+ b = mx.nd.zeros(shape=(10,1), ctx=mx.gpu())
+ t.hybridize(static_alloc=True, static_shape=True)
+ out = t(a, b)
+ # Block until all pending operations have finished.
+ mx.nd.waitall()
+
if __name__ == '__main__':
import nose
nose.runmodule()