zheng-da commented on a change in pull request #13419: [MXNET-1233] Enable
dynamic shape in CachedOp
URL: https://github.com/apache/incubator-mxnet/pull/13419#discussion_r240890932
##########
File path: src/imperative/imperative_utils.cc
##########
@@ -22,6 +22,114 @@
namespace mxnet {
namespace imperative {
+
+void NaiveRunGraph(
+ const bool retain_graph,
+ const Context& default_ctx,
+ const nnvm::IndexedGraph& idx,
+ const std::vector<NDArray*> arrays,
+ size_t node_start, size_t node_end,
+ std::vector<OpReqType>&& array_reqs,
+ std::vector<uint32_t>&& ref_count,
+ std::vector<OpStatePtr> *p_states,
+ const DispatchModeVector &dispatch_modes,
+ bool recording) {
+ using namespace nnvm;
+ using namespace imperative;
+ static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
+ static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
+ static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
+
+ const auto imp = Imperative::Get();
+
+ std::vector<OpStatePtr>& states = *p_states;
+
+ for (size_t i = node_start; i < node_end; ++i) {
+ const nnvm::IndexedGraph::Node& node = idx[i];
+ if (node.source->op() == nullptr) {
+ continue;
+ }
+ size_t num_outputs = node.source->num_outputs();
+ // construct `ndinputs`
+ std::vector<NDArray*> ndinputs;
+ ndinputs.reserve(node.inputs.size());
+ for (const auto& j : node.inputs) {
+ ndinputs.emplace_back(arrays[idx.entry_id(j)]);
+ CHECK(!ndinputs.back()->is_none()) << idx[j.node_id].source->attrs.name
<< " " << j.index;
+ }
+ // construct `ndoutputs` and `req`
+ std::vector<NDArray*> ndoutputs;
+ ndoutputs.reserve(num_outputs);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ size_t eid = idx.entry_id(i, j);
+ ndoutputs.emplace_back(arrays[eid]);
+ }
+ // other auxiliary data
+ Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs,
default_ctx);
+ auto invoke = [&](const OpStatePtr &state) {
+ DispatchMode dispatch_mode = DispatchMode::kUndefined;
+ SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs,
&dispatch_mode);
+ std::vector<OpReqType> req;
+ SetWriteInplaceReq(ndinputs, ndoutputs, &req);
+ imp->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs, req,
dispatch_mode, state);
+ for (size_t i = 0; i < ndoutputs.size(); i++) {
+ if (ndoutputs[i]->shape().ndim() == 0) {
+ ndoutputs[i]->WaitToRead();
+ ndoutputs[i]->SetShapeFromChunk();
+ }
+ }
+ if (recording) {
+ imp->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs,
state);
+ }
+ };
+ if (node.source->op() == bwd_cached_op) {
+ // case 1: backward cached op
+ std::vector<OpReqType> req;
+ req.reserve(num_outputs);
+ for (size_t j = 0; j < num_outputs; ++j) {
+ size_t eid = idx.entry_id(i, j);
+ req.push_back(array_reqs[eid]);
+ CHECK(array_reqs[eid] == kNullOp || !ndoutputs.back()->is_none());
+ }
+ const auto& cached_op =
dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
+ nnvm::Node* fwd_node = node.source->control_deps[0].get();
+ auto fwd_node_id = idx.node_id(fwd_node);
+ cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req,
ndoutputs);
+ } else if (createop.count(node.source->op())) {
+ // case 2: node is in createop
Review comment:
I think this branch handles stateful operators (ops that register `FCreateOpState`).
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services