eric-haibin-lin commented on a change in pull request #10882: move exec.reshape to backend
URL: https://github.com/apache/incubator-mxnet/pull/10882#discussion_r187844712
 
 

 ##########
 File path: src/executor/graph_executor.cc
 ##########
 @@ -1043,6 +1043,108 @@ void GraphExecutor::Init(nnvm::Symbol symbol,
   FinishInitGraph(symbol, g, shared_exec, feed_dict);
 }
 
+/*!
+ * \brief Return a new executor with the same symbol and shared memory,
+ * but different input/output shapes.
+ * For runtime reshaping, variable length sequences, etc.
+ * The returned executor shares state with the current one,
+ * and cannot be used in parallel with it.
+ */
+Executor* GraphExecutor::Reshape(const bool partial_shaping,
+                                 const bool allow_up_sizing,
+                                 const Context& default_ctx,
+                                 const std::map<std::string, Context>& ctx_map,
+                                 const std::unordered_map<std::string, TShape>&
+                                   provided_arg_shapes,
+                                 std::vector<NDArray>* in_args,
+                                 std::vector<NDArray>* arg_grads,
+                                 std::vector<NDArray>* aux_states) {
+  nnvm::Graph g;
+  g.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
+    graph_.outputs.begin() + num_forward_outputs_);
+  nnvm::Symbol symbol;
+  symbol.outputs = g.outputs;
+  const nnvm::IndexedGraph& idx = g.indexed_graph();
+  nnvm::ShapeVector arg_shapes(idx.input_nodes().size(), TShape());
+  for (size_t i = 0; i < num_forward_inputs_; ++i) {
+    const uint32_t nid = idx.input_nodes().at(i);
+    const std::string& name = idx[nid].source->attrs.name;
+    auto it = provided_arg_shapes.find(name);
+    if (provided_arg_shapes.end() != it) {
+      arg_shapes[i] = it->second;
+    }
+  }
+  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
+  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
+    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
+                          g.GetAttr<nnvm::ShapeVector>("shape"));
+  }
+  const nnvm::ShapeVector& shape_vec = g.GetAttr<nnvm::ShapeVector>("shape");
+  std::vector<OpReqType> grad_req_types;
+  size_t grad_top = 0;
+
+  std::ostringstream up_sizing_msg, unspecified_msg;
+  up_sizing_msg << ": Arg of new shape which is larger than original."
+                << "First making a big executor and then down sizing it "
+                << "is more efficient than the reverse."
+                << "If you really want to up size, set allow_up_sizing=True "
+                << "to enable allocation of new arrays.";
+  unspecified_msg << ": unspecified array's shape changed. "
+                  << "This can cause the new executor to not share parameters "
+                  << "with the old one. Please check for error in network."
+                  << "If this is intended, set partial_shaping=True to 
suppress this warning.";
+  for (uint32_t nid : idx.input_nodes()) {
+    std::string name = idx[nid].source->attrs.name;
+    const TShape& new_shape = shape_vec[idx.entry_id(nid, 0)];
+    if (idx.mutable_input_nodes().count(nid) == 0) {
+      NDArray& arr = in_arg_map_.at(name);
+      auto it = arg_grad_map_.find(name);
+      if (partial_shaping || provided_arg_shapes.count(name) || new_shape == 
arr.shape()) {
+        if (new_shape.Size() > arr.shape().Size()) {
+          CHECK(allow_up_sizing) << name << up_sizing_msg.str();
 
 Review comment:
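   Needs a better error message: it should state the argument's original and requested shapes, and the concatenated sentences are missing separating spaces (e.g. "original.First").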

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to