eric-haibin-lin commented on a change in pull request #11641: [MXNET-876] make CachedOp a normal operator
URL: https://github.com/apache/incubator-mxnet/pull/11641#discussion_r219712647
 
 

 ##########
 File path: src/imperative/cached_op.cc
 ##########
 @@ -1067,34 +1067,145 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
-bool CachedOp::ForwardStorageType(const nnvm::NodeAttrs& attrs,
-                                  const int dev_mask,
-                                  DispatchMode* dispatch_mode,
-                                  std::vector<int> *in_attrs,
-                                  std::vector<int> *out_attrs) {
-  using namespace imperative;
-  nnvm::Graph g(fwd_graph_);
-  const auto& idx = g.indexed_graph();
-  const auto &outputs = idx.outputs();
+/*
+ * This is the operator state of CachedOp when CachedOp is used in the symbol
+ * executor. This is different from the OpState returned by CachedOp::Forward.
 + * The main reason why we need this OpState is that CachedOp and the symbol executor
 + * maintain OpState differently. The symbol executor generates OpState in advance
+ * while CachedOp generates OpState after Forward is called. We need this data
+ * structure to keep the OpState generated by CachedOp::Forward and pass it to
+ * Backward.
+ */
+struct CachedOpActualState {
+  std::shared_ptr<CachedOp> op;
+  OpStatePtr forward_state;
 
-  // Prepare stypes and contexts based on inputs
-  StorageTypeVector storage_type_inputs;
-  storage_type_inputs.reserve(in_attrs->size());
-  for (size_t i = 0; i < in_attrs->size(); ++i) {
-    storage_type_inputs.emplace_back(in_attrs->at(i));
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) {
+    this->op = op;
   }
-  exec::DevMaskVector dev_masks(idx.num_nodes(), dev_mask);
+};
 
-  // Forward graph storage type inference
 -  CheckAndInferStorageType(&g, std::move(dev_masks), std::move(storage_type_inputs), true);
-  // Retrieve result and set outputs
-  const auto& inferred_stypes = g.GetAttr<StorageTypeVector>("storage_type");
-  for (size_t i = 0; i < out_attrs->size(); i++) {
-    const auto eid = idx.entry_id(outputs[i]);
-    STORAGE_TYPE_ASSIGN_CHECK(*out_attrs, i, inferred_stypes[eid]);
+/*
+ * This is the forward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs(in_bufs.size());
+  std::vector<NDArray *> out_ptrs(out_bufs.size());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  // Set is_recording correct for the imperative executor.
+  bool orig_is_record;
+  if (ctx.need_grad)
+    orig_is_record = Imperative::Get()->set_is_recording(true);
+  else
+    orig_is_record = Imperative::Get()->is_recording();
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  s.forward_state = s.op->Forward(nullptr, in_ptrs, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+  Imperative::Get()->set_is_recording(orig_is_record);
+  // The arrays in out_ptrs may be changed by CachedOp.
+  // If it is, we need to copy data back.
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    if (!out_bufs[i].IsSame(outputs[i]))
+      CopyFromTo(out_bufs[i], outputs[i]);
+}
+
+/*
+ * This is the backward computation when CachedOp is used as an operator in
+ * a symbol executor.
+ */
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  CHECK_EQ(s.op->num_backward_inputs(), inputs.size());
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  size_t bwd_in_dep = s.op->num_inputs();
+  size_t bwd_out_dep = s.op->num_outputs();
+  CHECK(s.op->num_backward_inputs() > bwd_in_dep + bwd_out_dep);
 +  size_t bwd_ograd_dep = s.op->num_backward_inputs() - bwd_in_dep - bwd_out_dep;
+
+  // Find inputs, outputs and ograds
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + bwd_ograd_dep;
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + bwd_in_dep;
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
+    auto i = it - in_begin;
+    if (save_inputs[i])
+      in_ptrs.push_back(&(*it));
   }
-  DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
-  return true;
+  for (auto it = out_begin; it != out_end; it++) {
+    auto i = it - out_begin;
+    if (save_outputs[i])
+      in_ptrs.push_back(&(*it));
+  }
+  CHECK_EQ(in_ptrs.size(), s.op->num_backward_inputs());
+  for (size_t i = 0; i < out_bufs.size(); i++)
+    out_ptrs.push_back(&out_bufs[i]);
+  CHECK_EQ(out_ptrs.size(), s.op->num_backward_outputs());
+  // Set is_training correct for the imperative executor.
+  bool orig_is_train;
+  if (ctx.is_train)
+    orig_is_train = Imperative::Get()->set_is_training(true);
+  else
+    orig_is_train = Imperative::Get()->is_training();
+  // TODO(zhengda) is it right to use false here?
+  s.op->Backward(false, s.forward_state, in_ptrs, req, out_ptrs);
+  Imperative::Get()->set_is_training(orig_is_train);
+
+  // Clean up what we recorded.
+  s.forward_state.reset();
+
+  // The arrays in out_ptrs may be changed by CachedOp.
 
 Review comment:
   Thanks for updating the comments

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to