anirudh2290 commented on a change in pull request #12157: Subgraph API for 
integrating accelerators with MXNet
URL: https://github.com/apache/incubator-mxnet/pull/12157#discussion_r212480073
 
 

 ##########
 File path: src/executor/graph_executor.cc
 ##########
 @@ -1428,6 +1430,146 @@ GraphExecutor::CachedSegOpr 
GraphExecutor::CreateCachedSegOpr(size_t topo_start,
     iter->c_str());
   return ret;
 }
+
+// Infer shapes, dtypes, stypes, contexts for the forward graph
+static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
+                                     nnvm::ShapeVector arg_shapes,
+                                     nnvm::DTypeVector arg_dtypes,
+                                     StorageTypeVector arg_stypes,
+                                     const Context& default_ctx,
+                                     const std::map<std::string, Context>& 
ctx_map,
+                                     const std::vector<Context>& in_arg_ctxes,
+                                     const std::vector<Context>& 
aux_state_ctxes) {
+  const auto& indexed_graph = g.indexed_graph();
+  const auto num_forward_inputs = indexed_graph.input_nodes().size();
+  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
+                   aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs, indexed_graph,
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
+  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
+    HandleInferTypeError(num_forward_inputs, indexed_graph,
+                         g.GetAttr<nnvm::DTypeVector>("dtype"));
+  }
+  g = InferStorageType(std::move(g), std::move(arg_stypes), 
"__storage_type__");
+  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
+    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
+                                g.GetAttr<StorageTypeVector>("storage_type"));
+  }
+  return g;
+}
+
+// Given input attr arrays, partition the graph using the backend name equal 
to prop_name.
+// This is a common function for bind and simple_bind flows.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const nnvm::ShapeVector& arg_shapes,
+                                   const nnvm::DTypeVector& arg_dtypes,
+                                   const StorageTypeVector& arg_stypes,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& 
aux_state_ctxes) {
+  auto subgraph_prop = 
op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  nnvm::Symbol ret = src.Copy();
+  nnvm::Graph g;
+  g.outputs = ret.outputs;
+  g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
+                        ctx_map, in_arg_ctxes, aux_state_ctxes);
+  subgraph_prop->SetAttr("graph", g);
+  auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name);
+  // assign a op name set to the subgraph property if it has been provided by 
users
+  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
+    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << 
prop_name
+              << " has been assigned a value. Please make sure it is 
initialized"
+                 " only for the testing purpose.";
+    subgraph_prop->SetAttr("op_names", it->second);
+  }
+  g.attrs["subgraph_property"] = 
std::make_shared<nnvm::any>(std::move(subgraph_prop));
+  g = ApplyPass(std::move(g), "PartitionGraph");
+  ret.outputs = g.outputs;
+  return ret;
+}
+
+// Given input attr dicts, partition the graph using the backend name equal to 
prop_name.
+// This is for simple_bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::unordered_map<std::string, 
TShape>& arg_shape_map,
+                                   const std::unordered_map<std::string, int>& 
arg_dtype_map,
+                                   const std::unordered_map<std::string, int>& 
arg_stype_map,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map,
+                                   const std::vector<Context>& in_arg_ctxes,
+                                   const std::vector<Context>& 
aux_state_ctxes) {
+  const std::vector<std::string> input_names = 
src.ListInputNames(Symbol::kAll);
+  nnvm::ShapeVector arg_shapes(input_names.size(), TShape());
+  nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
+  StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    auto it1 = arg_shape_map.find(input_names[i]);
+    if (arg_shape_map.end() != it1) {
+      arg_shapes[i] = it1->second;
+    }
+    auto it2 = arg_dtype_map.find(input_names[i]);
+    if (arg_dtype_map.end() != it2) {
+      arg_dtypes[i] = it2->second;
+    }
+    auto it3 = arg_stype_map.find(input_names[i]);
+    if (arg_stype_map.end() != it3) {
+      arg_stypes[i] = it3->second;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+}
+
+// Given input ndarrays, partition the graph using the backend name equal to 
prop_name.
+// This is for bind flow.
+static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
+                                   const std::string& prop_name,
+                                   const std::vector<NDArray> &in_args,
+                                   const std::vector<NDArray> &aux_states,
+                                   const Context& default_ctx,
+                                   const std::map<std::string, Context>& 
ctx_map) {
+  const std::vector<std::string> input_names = 
src.ListInputNames(Symbol::kAll);
+  const std::vector<std::string> arg_names = 
src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
+  const std::vector<std::string> aux_names = 
src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
+  CHECK_EQ(arg_names.size(), in_args.size());
+  CHECK_EQ(aux_names.size(), aux_states.size());
+  nnvm::ShapeVector arg_shapes;  // all input shapes
+  arg_shapes.reserve(input_names.size());
+  nnvm::DTypeVector arg_dtypes;  // all input dtypes
+  arg_dtypes.reserve(input_names.size());
+  StorageTypeVector arg_stypes;  // all input stypes
+  arg_stypes.reserve(input_names.size());
+  std::vector<Context> in_arg_ctxes(in_args.size());
+  std::vector<Context> aux_state_ctxes(aux_states.size());
+
+  size_t i1 = 0, i2 = 0;
+  for (size_t i = 0; i < input_names.size(); ++i) {
+    if (i2 < aux_names.size() && aux_names[i2] == input_names[i]) {
+      arg_shapes.push_back(aux_states[i2].shape());
+      arg_dtypes.push_back(aux_states[i2].dtype());
+      arg_stypes.push_back(aux_states[i2].storage_type());
+      aux_state_ctxes[i2] = aux_states[i2].ctx();
+      ++i2;
+    } else {
+      CHECK(i1 < arg_names.size());
+      CHECK_EQ(arg_names[i1], input_names[i]);
+      arg_shapes.push_back(in_args[i1].shape());
+      arg_dtypes.push_back(in_args[i1].dtype());
+      arg_stypes.push_back(in_args[i1].storage_type());
+      in_arg_ctxes[i1] = in_args[i1].ctx();
+      ++i1;
+    }
+  }
+  return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
 
 Review comment:
   Should we check prop_name here? Currently only default is supported, 
correct?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to