ashokei closed pull request #12008: [WIP] Enable graph partitioning in executor
with types/shape awareness
URL: https://github.com/apache/incubator-mxnet/pull/12008
This is a PR merged from a forked repository.
Because this pull request comes from a fork and GitHub hides the
original diff once it is merged, the diff is reproduced below for
the sake of provenance:
diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index 33c6f574a04..06e54fb894c 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -33,6 +33,7 @@
#include "../profiler/profiler.h"
#include "../common/utils.h"
#include "../common/exec_utils.h"
+#include "../operator/subgraph/default_subgraph_op.h"
namespace mxnet {
namespace exec {
@@ -506,6 +507,61 @@ static void HandleInferStorageTypeError(const size_t
num_forward_inputs,
<< oss.str();
}
+/*!
+ * \brief Infer and add to graph attr using given info.
+ * given values are moved to graph attr
+ */
+void InferAttrs(nnvm::Graph *g, size_t num_forward_inputs, nnvm::ShapeVector
*arg_shapes,
+ nnvm::DTypeVector *arg_dtypes, StorageTypeVector *arg_stypes) {
+ *g = InferShape(std::move(*g), std::move(*arg_shapes), "__shape__");
+ if (g->GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+ HandleInferShapeError(num_forward_inputs, g->indexed_graph(),
+ g->GetAttr<nnvm::ShapeVector>("shape"));
+ }
+
+ *g = InferType(std::move(*g), std::move(*arg_dtypes), "__dtype__");
+ if (g->GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+ HandleInferTypeError(num_forward_inputs, g->indexed_graph(),
+ g->GetAttr<nnvm::DTypeVector>("dtype"));
+ }
+ if (arg_stypes->size() == 0) {
+ *g = InferStorageType(std::move(*g), StorageTypeVector(), "");
+ } else {
+ *g = InferStorageType(std::move(*g), std::move(*arg_stypes),
"__storage_type__");
+ }
+ if (g->GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+ HandleInferStorageTypeError(num_forward_inputs, g->indexed_graph(),
+ g->GetAttr<StorageTypeVector>("storage_type"));
+ }
+}
+
+void GetAttrs(const nnvm::Symbol &symbol, const std::vector<NDArray> &in_args,
+ const std::vector<NDArray> &aux_states, nnvm::ShapeVector
*arg_shapes,
+ nnvm::DTypeVector *arg_dtypes, StorageTypeVector *arg_stypes) {
+ auto num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();
+ nnvm::Graph g;
+ g.outputs = symbol.outputs;
+ const auto &idx = g.indexed_graph();
+ const auto &mutable_nodes = idx.mutable_input_nodes();
+ size_t arg_top = 0, aux_top = 0;
+ for (size_t i = 0; i < num_forward_inputs; ++i) {
+ const uint32_t nid = idx.input_nodes().at(i);
+ if (mutable_nodes.count(nid)) {
+ CHECK_LT(aux_top, aux_states.size());
+ arg_shapes->push_back(aux_states[aux_top].shape());
+ arg_dtypes->push_back(aux_states[aux_top].dtype());
+ arg_stypes->push_back(aux_states[aux_top].storage_type());
+ ++aux_top;
+ } else {
+ CHECK_LT(arg_top, in_args.size());
+ arg_shapes->push_back(in_args[arg_top].shape());
+ arg_dtypes->push_back(in_args[arg_top].dtype());
+ arg_stypes->push_back(in_args[arg_top].storage_type());
+ ++arg_top;
+ }
+ }
+}
+
/*!
* \brief GraphExecutor initializer for regular bind flow in which
* input arguments and gradients are provided by users. This initializer
@@ -533,8 +589,13 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::vector<Context> aux_state_ctxes(aux_states.size());
std::transform(aux_states.begin(), aux_states.end(),
aux_state_ctxes.begin(), get_ctx1);
+ nnvm::ShapeVector sym_arg_shapes;
+ nnvm::DTypeVector sym_arg_dtypes;
+ StorageTypeVector sym_arg_stypes;
+ GetAttrs(symbol, in_args, aux_states, &sym_arg_shapes, &sym_arg_dtypes,
&sym_arg_stypes);
nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
- arg_grad_ctxes, aux_state_ctxes, grad_req_types);
+ arg_grad_ctxes, aux_state_ctxes, &sym_arg_shapes,
+ &sym_arg_dtypes, &sym_arg_stypes, grad_req_types);
// create arg_shapes and arg_dtypes for shape and type inferences
const auto& idx = g.indexed_graph();
@@ -582,27 +643,12 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
}
}
+ g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(arg_stypes));
// expand arg_shapes and arg_dtypes to contain backward inputs
arg_shapes.resize(idx.input_nodes().size(), TShape());
- g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
- if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
- HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<nnvm::ShapeVector>("shape"));
- }
-
arg_dtypes.resize(idx.input_nodes().size(), -1);
- g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
- if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
- HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<nnvm::DTypeVector>("dtype"));
- }
-
- g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(arg_stypes));
- g = InferStorageType(std::move(g), StorageTypeVector(), "");
- if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
- HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<StorageTypeVector>("storage_type"));
- }
+ StorageTypeVector tmp_stypes;
+ InferAttrs(&g, num_forward_inputs_, &arg_shapes, &arg_dtypes, &tmp_stypes);
// Initialize the rest attributes of the graph.
// This function can be called by regular bind
@@ -941,6 +987,37 @@ void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
this->InitOpSegs();
}
+void GetAttrs(const nnvm::Symbol &symbol,
+ const std::unordered_map<std::string, TShape> &arg_shape_map,
+ const std::unordered_map<std::string, int> &arg_dtype_map,
+ const std::unordered_map<std::string, int> &arg_stype_map,
+ nnvm::ShapeVector *arg_shapes, nnvm::DTypeVector *arg_dtypes,
+ StorageTypeVector *arg_stypes) {
+ auto num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();
+ nnvm::Graph g;
+ g.outputs = symbol.outputs;
+ const auto &idx = g.indexed_graph();
+ arg_shapes->resize(idx.input_nodes().size(), TShape());
+ arg_dtypes->resize(idx.input_nodes().size(), -1);
+ arg_stypes->resize(idx.input_nodes().size(), -1);
+ for (size_t i = 0; i < num_forward_inputs; ++i) {
+ const uint32_t nid = idx.input_nodes().at(i);
+ const std::string &name = idx[nid].source->attrs.name;
+ auto it1 = arg_shape_map.find(name);
+ if (arg_shape_map.end() != it1) {
+ (*arg_shapes)[i] = it1->second;
+ }
+ auto it2 = arg_dtype_map.find(name);
+ if (arg_dtype_map.end() != it2) {
+ (*arg_dtypes)[i] = it2->second;
+ }
+ auto it3 = arg_stype_map.find(name);
+ if (arg_stype_map.end() != it3) {
+ (*arg_stypes)[i] = it3->second;
+ }
+ }
+}
+
/*!
* \brief GraphExecutor initializer for simple bind flow in
* which only certain input shapes and dtypes are provided by users.
@@ -972,8 +1049,14 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
std::unordered_map<std::string, NDArray>*
shared_buffer,
Executor* shared_exec,
const nnvm::NodeEntryMap<NDArray>& feed_dict) {
- nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
arg_grad_ctxes,
- aux_state_ctxes, grad_req_types);
+ nnvm::ShapeVector sym_arg_shapes;
+ nnvm::DTypeVector sym_arg_dtypes;
+ StorageTypeVector sym_arg_stypes;
+ GetAttrs(symbol, arg_shape_map, arg_dtype_map, arg_stype_map,
&sym_arg_shapes,
+ &sym_arg_dtypes, &sym_arg_stypes);
+ nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
+ arg_grad_ctxes, aux_state_ctxes, &sym_arg_shapes,
+ &sym_arg_dtypes, &sym_arg_stypes, grad_req_types);
// The following code of shape and dtype inferences and argument
// initialization is for simple_bind only. Regular bind operation
// should do this differently.
@@ -1000,23 +1083,7 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
arg_stypes[i] = it3->second;
}
}
- g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
- if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
- HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<nnvm::ShapeVector>("shape"));
- }
-
- g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
- if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
- HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<nnvm::DTypeVector>("dtype"));
- }
-
- g = InferStorageType(std::move(g), std::move(arg_stypes),
"__storage_type__");
- if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
- HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
- g.GetAttr<StorageTypeVector>("storage_type"));
- }
+ InferAttrs(&g, num_forward_inputs_, &arg_shapes, &arg_dtypes, &arg_stypes);
// Create in_args, arg_grads, and aux_states using
// the inferred shapes and dtypes.
@@ -1168,9 +1235,36 @@ Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
+ nnvm::ShapeVector *arg_shapes,
+ nnvm::DTypeVector *arg_dtypes,
+ StorageTypeVector *arg_stypes,
const std::vector<OpReqType>& grad_req_types) {
- // setup gradient
- nnvm::Graph g = InitFullGraph(symbol, grad_req_types);
+ // parition given symbol graph using shape & types.
+ nnvm::Graph g;
+ if (ctx_map.size() == 0) {
+ g.outputs = symbol.outputs;
+ auto num_forward_outputs = symbol.outputs.size();
+ auto num_forward_inputs = symbol.ListInputs(nnvm::Symbol::kAll).size();
+ g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
+ aux_state_ctxes, grad_req_types, num_forward_inputs,
+ num_forward_outputs);
+
+ // infer & add types/shape attrs to graph before calling partition pass
+ InferAttrs(&g, num_forward_inputs, arg_shapes, arg_dtypes, arg_stypes);
+ // partition pass with default subgraph property
+ mxnet::op::SubgraphPropertyPtr property =
+ std::make_shared<mxnet::op::DefaultSubgraphProperty>();
+ g.attrs["subgraph_property"] =
+ std::make_shared<nnvm::any>(std::move(property));
+ g = ApplyPass(std::move(g), "PartitionGraph");
+ auto sym = symbol.Copy();
+ sym.outputs = g.outputs;
+ // setup gradient
+ g = InitFullGraph(sym, grad_req_types);
+ } else {
+ // setup gradient
+ g = InitFullGraph(symbol, grad_req_types);
+ }
// create "device" and "context" attrs for the graph
g = AssignContext(g, default_ctx, ctx_map,
diff --git a/src/executor/graph_executor.h b/src/executor/graph_executor.h
index bfc415b4526..7c6de842645 100644
--- a/src/executor/graph_executor.h
+++ b/src/executor/graph_executor.h
@@ -184,6 +184,9 @@ class GraphExecutor : public Executor {
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& arg_grad_ctxes,
const std::vector<Context>& aux_state_ctxes,
+ nnvm::ShapeVector *arg_shapes,
+ nnvm::DTypeVector *arg_dtypes,
+ StorageTypeVector *arg_stypes,
const std::vector<OpReqType>& grad_req_types);
// intialize the full graph for simple bind, including gradient
Graph InitFullGraph(nnvm::Symbol symbol,
diff --git a/src/operator/subgraph/default_subgraph_op.h
b/src/operator/subgraph/default_subgraph_op.h
index 7d6624ef14d..008a4ef38d9 100644
--- a/src/operator/subgraph/default_subgraph_op.h
+++ b/src/operator/subgraph/default_subgraph_op.h
@@ -20,8 +20,9 @@
#ifndef MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
#define MXNET_OPERATOR_SUBGRAPH_DEFAULT_SUBGRAPH_OP_H_
-#include <vector>
+#include <mxnet/graph_attr_types.h>
#include <string>
+#include <vector>
#include "./common.h"
namespace mxnet {
@@ -35,18 +36,19 @@ namespace op {
*/
class SubgraphSelector {
public:
- virtual ~SubgraphSelector() {
- }
+ virtual ~SubgraphSelector() {}
// Determine if the node should be selected for a subgraph.
- virtual bool Select(const nnvm::Node &n) = 0;
+ virtual bool Select(const nnvm::Graph &g, const nnvm::Node &n) = 0;
// Determine if the input node should be selected for a subgraph.
- virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) =
0;
+ virtual bool SelectInput(const nnvm::Graph &g, const nnvm::Node &n,
+ const nnvm::Node &new_node) = 0;
// Determine if the output node should be selected for a subgraph.
- virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) =
0;
+ virtual bool SelectOutput(const nnvm::Graph &g, const nnvm::Node &n,
+ const nnvm::Node &new_node) = 0;
// Post processes pre-selected subgraph nodes. Return a list of nodes that
// users want to keep in subgraph(s).
- virtual std::vector<nnvm::Node*> Filter(nnvm::Graph* g,
- const std::vector<nnvm::Node*>&
candidates) {
+ virtual std::vector<nnvm::Node *> Filter(const nnvm::Graph &g,
+ const std::vector<nnvm::Node *>
&candidates) {
return candidates;
}
};
@@ -76,7 +78,7 @@ void RegisterSubgraphProperty(SubgraphPropertyPtr property);
* This selects nodes for a subgraph that only contains operators
* in a given set and it visits nodes via both input and output links.
*/
-class ContainOpSelector: public SubgraphSelector {
+class ContainOpSelector : public SubgraphSelector {
std::shared_ptr<const std::unordered_set<std::string>> op_names;
public:
@@ -84,15 +86,16 @@ class ContainOpSelector: public SubgraphSelector {
this->op_names = op_names;
}
- virtual bool Select(const nnvm::Node &n) {
+ virtual bool Select(const nnvm::Graph &g, const nnvm::Node &n) {
return !n.is_variable() && op_names->count(n.op()->name);
}
- virtual bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) {
+ virtual bool SelectInput(const nnvm::Graph &g, const nnvm::Node &n, const
nnvm::Node &new_node) {
return !new_node.is_variable() && op_names->count(new_node.op()->name);
}
- virtual bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) {
+ virtual bool SelectOutput(const nnvm::Graph &g, const nnvm::Node &n,
+ const nnvm::Node &new_node) {
return !new_node.is_variable() && op_names->count(new_node.op()->name);
}
};
@@ -101,10 +104,11 @@ class ContainOpSelector: public SubgraphSelector {
* This subgraph property finds a subgraph whose nodes have only operators
* within a set. The operators in the subgraph will be executed by
_default_subgraph_op.
*/
-class DefaultSubgraphProperty: public SubgraphProperty {
+class DefaultSubgraphProperty : public SubgraphProperty {
public:
- explicit DefaultSubgraphProperty(const std::unordered_set<std::string>
&op_names) :
- op_names_(std::make_shared<std::unordered_set<std::string>>(op_names)) {}
+ explicit DefaultSubgraphProperty(
+ const std::unordered_set<std::string> &op_names =
std::unordered_set<std::string>{})
+ : op_names_(std::make_shared<std::unordered_set<std::string>>(op_names))
{}
virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym,
const int subgraph_id = 0) const {
nnvm::NodePtr n = nnvm::Node::Create();
diff --git a/src/operator/subgraph/partition_graph.cc
b/src/operator/subgraph/partition_graph.cc
index 9672877eb1d..75aa75019de 100644
--- a/src/operator/subgraph/partition_graph.cc
+++ b/src/operator/subgraph/partition_graph.cc
@@ -163,7 +163,7 @@ bool LabelSubgraph(const Graph& g,
// get qualified adjacent input nodes
for (auto& e : cur_node->node->inputs) {
const bool select_input = (!excluded_nodes ||
!excluded_nodes->count(e.node.get()))
- && subgraph_selector->SelectInput(*cur_node->node, *e.node);
+ && subgraph_selector->SelectInput(g, *cur_node->node, *e.node);
if (select_input) {
// e.node is a subgraph node
const auto nid = indexed_graph.node_id(e.node.get());
@@ -181,7 +181,7 @@ bool LabelSubgraph(const Graph& g,
// get qualified output nodes
for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end();
++it) {
const bool select_output = (!excluded_nodes ||
!excluded_nodes->count(it->first))
- && subgraph_selector->SelectOutput(*cur_node->node, *it->first);
+ && subgraph_selector->SelectOutput(g, *cur_node->node, *it->first);
if (select_output) {
// it->first is a subgraph node
const auto nid = indexed_graph.node_id(it->first);
@@ -401,14 +401,14 @@ void FindSubgraphs(Graph* g,
for (size_t i = 0; i < simple_nodes.size(); ++i) {
nnvm::Node* node = simple_nodes[i]->node;
auto subgraph_selector = subg_prop.CreateSubgraphSelector();
- if (subgraph_selector->Select(*node) && simple_nodes[i]->label == -1) {
+ if (subgraph_selector->Select(*g, *node) && simple_nodes[i]->label == -1) {
// pre-select nodes that can be grouped in a subgraph
std::vector<nnvm::Node*> preselected_nodes;
PreSelectSubgraphNodes(*g, subgraph_selector, subgraph_id, i,
simple_nodes,
&preselected_nodes);
// filter out unqualified pre-selected nodes
- std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(g,
preselected_nodes);
+ std::vector<nnvm::Node*> filtered_nodes = subgraph_selector->Filter(*g,
preselected_nodes);
// make sure filtered_nodes is a subset of preselected_nodes
for (const auto n : filtered_nodes) {
@@ -741,6 +741,11 @@ Graph PartitionGraph(Graph&& g) {
#endif
CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], i,
&entry_top_order_map);
}
+#if SUBGRAPH_DEBUG
+ if (subgraph_nodes.size() == 0) {
+ LOG(INFO) << "The graph has no fuseable nodes, the original graph is
returned.";
+ }
+#endif
return g;
}
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to this message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services