zheng-da commented on a change in pull request #14393: [MXNET-1352] Allow dynamic shape in while_loop and if conditionals
URL: https://github.com/apache/incubator-mxnet/pull/14393#discussion_r282183624
##########
File path: src/executor/graph_executor.cc
##########
@@ -76,20 +77,77 @@ void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
 }
 void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
-  const auto& idx = graph_.indexed_graph();
-  if (num_forward_inputs_ != idx.input_nodes().size()) {
-    for (size_t i = 0; i < head_grad_array_.size(); ++i) {
-      if (!head_grad_array_[i].is_none()) {
-        CHECK(i < head_grads.size() && !head_grads[i].is_none())
-          << "Because the last operator is not Loss function, "
-          << "head_gradient is required when calling backward. "
-          << "If you are attempting to minimize the output as "
-          << "an objective, please modify your network and "
-          << "pass it through the make_loss symbol.";
-        CopyFromTo(head_grads[i], &(head_grad_array_[i]));
+  {
+    const auto& idx = graph_.indexed_graph();
+    if (num_forward_inputs_ != idx.input_nodes().size()) {
+      for (size_t i = 0; i < head_grad_array_.size(); ++i) {
+        if (!head_grad_array_[i].is_none()) {
+          CHECK(i < head_grads.size() && !head_grads[i].is_none())
+            << "Because the last operator is not Loss function, "
+            << "head_gradient is required when calling backward. "
+            << "If you are attempting to minimize the output as "
+            << "an objective, please modify your network and "
+            << "pass it through the make_loss symbol.";
+          const NDArray &from = head_grads[i];
+          NDArray &to = head_grad_array_[i];
+          if (this->is_dynamic_) {
+            to.WaitToRead();
+            if (!shape_is_known(to.shape())) {
+              to.Init(from.shape());
+            }
+          }
+          CopyFromTo(from, &to);
+        }
+      }
+    }
+  }
+  if (this->is_dynamic_) {
+    graph_ = InferShape(std::move(graph_), {}, "");
+    mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
+    const auto& idx = graph_.indexed_graph();
+    for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
+      const auto& inode = idx[nid];
+      if (inode.source->is_variable()) continue;
+      OpNode& opnode = op_nodes_[nid];
+      if (opnode.skip_exec_node) continue;
+      for (NDArray &array : opnode.exec->in_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+      }
+      int i = 0;
+      for (NDArray &array : opnode.exec->in_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+        if (!shape_is_known(array.shape())) {
+          mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
+          if (shape_is_known(shape)) {
+            array.ReshapeAndAlloc(shape);
+          }
+        }
+        ++i;
+      }
+      i = 0;
+      for (NDArray &array : opnode.exec->out_array) {
+        array.WaitToRead();
+        if (!shape_is_known(array.shape())) {
+          array.SetShapeFromChunk();
+        }
+        if (!shape_is_known(array.shape())) {
+          mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
+          if (shape_is_known(shape)) {
+            array.ReshapeAndAlloc(shape);
+          }
+        }
+        ++i;
+      }
+    }
+    graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
+  }
+  const auto& idx = graph_.indexed_graph();
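The repeated WaitToRead / SetShapeFromChunk / ReshapeAndAlloc sequences in the hunk above all express one pattern: before the backward pass runs, give each array a concrete shape, preferring the shape the runtime chunk already carries and only falling back to the statically re-inferred shape from InferShape. A minimal sketch of that pattern as a standalone helper (the helper name MaterializeShape is hypothetical; NDArray, shape_is_known, and the member calls are the ones used in the hunk):

#include <mxnet/ndarray.h>
#include <mxnet/tuple.h>

// Hypothetical helper mirroring the hunk above: make `array`'s shape
// concrete, or leave it unknown if neither source can provide one.
static void MaterializeShape(mxnet::NDArray *array,
                             const mxnet::TShape &inferred) {
  // Block until pending writes finish so shape() reflects real data.
  array->WaitToRead();
  // Prefer the shape the underlying chunk acquired at runtime.
  if (!mxnet::shape_is_known(array->shape())) {
    array->SetShapeFromChunk();
  }
  // Fall back to the shape recovered by the InferShape pass, allocating
  // storage for it when that shape is fully known.
  if (!mxnet::shape_is_known(array->shape()) &&
      mxnet::shape_is_known(inferred)) {
    array->ReshapeAndAlloc(inferred);
  }
}

In the hunk, this logic runs over opnode.exec->in_array with rshape[idx.entry_id(inode.inputs[i])] as the fallback, and over opnode.exec->out_array with rshape[idx.entry_id(nid, i)].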
Review comment:
What is this used for?
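For background on the line the question points at: graph_.indexed_graph() returns nnvm's IndexedGraph view of the graph, built once and then cached, which the hunk uses for input_nodes(), node iteration, and entry_id lookups into rshape. A small illustrative walk using only the calls that appear above (the function itself is hypothetical, not code from the PR):

#include <nnvm/graph.h>
#include <mxnet/tuple.h>

// Illustrative only: resolve the inferred shape of every input entry of
// every operator node through the IndexedGraph addressing scheme.
void WalkInputShapes(const nnvm::Graph &graph,
                     const mxnet::ShapeVector &rshape) {
  const nnvm::IndexedGraph &idx = graph.indexed_graph();  // cached index
  for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto &inode = idx[nid];
    if (inode.source->is_variable()) continue;  // variables have no inputs
    for (const auto &e : inode.inputs) {
      // entry_id flattens (producer node, output slot) into an index
      // into per-entry attribute vectors such as rshape.
      const mxnet::TShape &shape = rshape[idx.entry_id(e)];
      (void)shape;  // a real pass would consume `shape` here
    }
  }
}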