anirudh2290 commented on a change in pull request #12157: Subgraph API for
integrating accelerators with MXNet
URL: https://github.com/apache/incubator-mxnet/pull/12157#discussion_r212480073
##########
File path: src/executor/graph_executor.cc
##########
@@ -1428,6 +1430,146 @@ GraphExecutor::CachedSegOpr
GraphExecutor::CreateCachedSegOpr(size_t topo_start,
iter->c_str());
return ret;
}
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph.
+// Takes the graph by value and returns it with the "shape", "dtype" and
+// "storage_type" attributes populated by the standard inference passes.
+// The arg_* vectors supply known attributes for the graph inputs; entries the
+// passes cannot resolve trigger the matching HandleInfer*Error below.
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+ nnvm::ShapeVector arg_shapes,
+ nnvm::DTypeVector arg_dtypes,
+ StorageTypeVector arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>&
ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>&
aux_state_ctxes) {
+ // NOTE(review): indexed_graph is a reference obtained from g before the
+ // passes below move-assign g; this relies on the cached index object staying
+ // valid across those moves — confirm against nnvm::Graph's ownership rules.
+ const auto& indexed_graph = g.indexed_graph();
+ const auto num_forward_inputs = indexed_graph.input_nodes().size();
+ // Forward-only graph: the two {} arguments presumably stand in for
+ // gradient-related vectors that are empty here — verify against
+ // AssignContext's signature.
+ g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+ aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+ g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+ // Report any input whose shape could not be inferred.
+ if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+ HandleInferShapeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::ShapeVector>("shape"));
+ }
+ g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+ // Report any input whose dtype could not be inferred.
+ if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+ HandleInferTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<nnvm::DTypeVector>("dtype"));
+ }
+ g = InferStorageType(std::move(g), std::move(arg_stypes),
"__storage_type__");
+ // Report any input whose storage type could not be inferred.
+ if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+ HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+ g.GetAttr<StorageTypeVector>("storage_type"));
+ }
+ return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal
to prop_name.
+// This is a common function for bind and simple_bind flows.
+// Steps: copy the source symbol, run attribute inference over its forward
+// graph, attach the annotated graph to the subgraph property created from the
+// registry, apply the "PartitionGraph" pass, and return the rewritten symbol.
+// NOTE(review): prop_name is forwarded to the registry without validation
+// here; unknown names are presumably rejected inside CreateSubgraphProperty —
+// verify.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const nnvm::ShapeVector& arg_shapes,
+ const nnvm::DTypeVector& arg_dtypes,
+ const StorageTypeVector& arg_stypes,
+ const Context& default_ctx,
+ const std::map<std::string, Context>&
ctx_map,
+ const std::vector<Context>& in_arg_ctxes,
+ const std::vector<Context>&
aux_state_ctxes) {
+ auto subgraph_prop =
op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+ // Copy so the caller's symbol is left untouched; only the copy's outputs
+ // are rewritten by the partitioning pass.
+ nnvm::Symbol ret = src.Copy();
+ nnvm::Graph g;
+ g.outputs = ret.outputs;
+ g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+ ctx_map, in_arg_ctxes, aux_state_ctxes);
+ // Give the property access to the fully annotated graph before partitioning.
+ subgraph_prop->SetAttr("graph", g);
+ auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+ // assign an op name set to the subgraph property if it has been provided by
+ // users (testing hook — see the LOG message below)
+ if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+ LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " <<
prop_name
+ << " has been assigned a value. Please make sure it is
initialized"
+ " only for the testing purpose.";
+ subgraph_prop->SetAttr("op_names", it->second);
+ }
+ // The pass reads the property out of this graph attribute.
+ g.attrs["subgraph_property"] =
std::make_shared<nnvm::any>(std::move(subgraph_prop));
+ g = ApplyPass(std::move(g), "PartitionGraph");
+ ret.outputs = g.outputs;
+ return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to prop_name.
+// This is for simple_bind flow.
+// Expands the per-name attribute dicts into dense vectors aligned with the
+// symbol's full input list and delegates to the vector-based overload; inputs
+// absent from a dict keep their "unknown" default so inference can fill them.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::unordered_map<std::string, TShape>& arg_shape_map,
+                                   const std::unordered_map<std::string, int>& arg_dtype_map,
+                                   const std::unordered_map<std::string, int>& arg_stype_map,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& aux_state_ctxes) {
+  const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
+  const size_t num_inputs = input_names.size();
+  // Defaults: unknown shape, unknown dtype (-1), undefined storage type.
+  nnvm::ShapeVector arg_shapes(num_inputs, TShape());
+  nnvm::DTypeVector arg_dtypes(num_inputs, -1);
+  StorageTypeVector arg_stypes(num_inputs, kUndefinedStorage);
+  for (size_t idx = 0; idx < num_inputs; ++idx) {
+    const std::string& name = input_names[idx];
+    const auto shape_it = arg_shape_map.find(name);
+    if (shape_it != arg_shape_map.end()) {
+      arg_shapes[idx] = shape_it->second;
+    }
+    const auto dtype_it = arg_dtype_map.find(name);
+    if (dtype_it != arg_dtype_map.end()) {
+      arg_dtypes[idx] = dtype_it->second;
+    }
+    const auto stype_it = arg_stype_map.find(name);
+    if (stype_it != arg_stype_map.end()) {
+      arg_stypes[idx] = stype_it->second;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to
prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+ const std::string& prop_name,
+ const std::vector<NDArray> &in_args,
+ const std::vector<NDArray> &aux_states,
+ const Context& default_ctx,
+ const std::map<std::string, Context>&
ctx_map) {
+ const std::vector<std::string> input_names =
src.ListInputNames(Symbol::kAll);
+ const std::vector<std::string> arg_names =
src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+ const std::vector<std::string> aux_names =
src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+ CHECK_EQ(arg_names.size(), in_args.size());
+ CHECK_EQ(aux_names.size(), aux_states.size());
+ nnvm::ShapeVector arg_shapes; // all input shapes
+ arg_shapes.reserve(input_names.size());
+ nnvm::DTypeVector arg_dtypes; // all input dtypes
+ arg_dtypes.reserve(input_names.size());
+ StorageTypeVector arg_stypes; // all input stypes
+ arg_stypes.reserve(input_names.size());
+ std::vector<Context> in_arg_ctxes(in_args.size());
+ std::vector<Context> aux_state_ctxes(aux_states.size());
+
+ size_t i1 = 0, i2 = 0;
+ for (size_t i = 0; i < input_names.size(); ++i) {
+ if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+ arg_shapes.push_back(aux_states[i2].shape());
+ arg_dtypes.push_back(aux_states[i2].dtype());
+ arg_stypes.push_back(aux_states[i2].storage_type());
+ aux_state_ctxes[i2] = aux_states[i2].ctx();
+ ++i2;
+ } else {
+ CHECK(i1 < arg_names.size());
+ CHECK_EQ(arg_names[i1], input_names[i]);
+ arg_shapes.push_back(in_args[i1].shape());
+ arg_dtypes.push_back(in_args[i1].dtype());
+ arg_stypes.push_back(in_args[i1].storage_type());
+ in_arg_ctxes[i1] = in_args[i1].ctx();
+ ++i1;
+ }
+ }
+ return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
Review comment:
Should we check prop_name here? Currently only the default is supported,
correct?
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services