larroy commented on a change in pull request #11641: [WIP] make CachedOp a 
normal operator
URL: https://github.com/apache/incubator-mxnet/pull/11641#discussion_r201624918
 
 

 ##########
 File path: src/imperative/cached_op.cc
 ##########
 @@ -1047,6 +1047,105 @@ void CachedOp::Backward(
   Engine::Get()->set_bulk_size(prev_bulk_size);
 }
 
+// Operator state used by CachedOpForward/CachedOpBackward: it bundles the
+// CachedOp being executed with the state produced by its forward pass.
+struct CachedOpActualState {
+  // The CachedOp that the forward/backward entry points delegate to.
+  std::shared_ptr<CachedOp> op;
+  // Assigned by CachedOpForward from the value returned by op->Forward().
+  OpStatePtr forward_state;
+
+  // `explicit` prevents a std::shared_ptr<CachedOp> from being silently
+  // converted into a CachedOpActualState. The member is initialized directly
+  // instead of being default-constructed and then assigned.
+  explicit CachedOpActualState(std::shared_ptr<CachedOp> op) : op(op) {}
+};
+
+// Forward entry point for CachedOp: runs the CachedOp held in the operator
+// state on `inputs` and makes sure the results end up in `outputs`.
+// `ctx` and `req` are part of the signature but are not read here.
+void CachedOpForward(const OpStatePtr& state_ptr,
+                     const OpContext& ctx,
+                     const std::vector<NDArray>& inputs,
+                     const std::vector<OpReqType>& req,
+                     const std::vector<NDArray>& outputs) {
+  CachedOpActualState &state = state_ptr.get_state<CachedOpActualState>();
+  // CachedOp::Forward takes non-const NDArray pointers, so work on local
+  // copies of the const argument vectors.
+  std::vector<NDArray> input_arrays(inputs);
+  std::vector<NDArray> output_arrays(outputs);
+  std::vector<NDArray *> input_ptrs;
+  std::vector<NDArray *> output_ptrs;
+  input_ptrs.reserve(input_arrays.size());
+  output_ptrs.reserve(output_arrays.size());
+  for (NDArray &arr : input_arrays)
+    input_ptrs.push_back(&arr);
+  for (NDArray &arr : output_arrays)
+    output_ptrs.push_back(&arr);
+  state.forward_state = state.op->Forward(nullptr, input_ptrs, output_ptrs);
+  // CachedOp may have replaced some of the output arrays. Any array that is
+  // no longer the same as the caller's output must be copied back into it.
+  for (size_t idx = 0; idx < output_arrays.size(); ++idx) {
+    if (output_arrays[idx].IsSame(outputs[idx]))
+      continue;
+    CopyFromTo(output_arrays[idx], outputs[idx]);
+  }
+}
+
+void CachedOpBackward(const OpStatePtr& state_ptr,
+                      const OpContext& ctx,
+                      const std::vector<NDArray>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<NDArray>& outputs) {
+  using namespace nnvm;
+  using namespace imperative;
+  CachedOpActualState &s = state_ptr.get_state<CachedOpActualState>();
+  std::vector<NDArray> in_bufs = inputs;
+  std::vector<NDArray> out_bufs = outputs;
+  std::vector<NDArray *> in_ptrs;
+  std::vector<NDArray *> out_ptrs;
+  in_ptrs.reserve(s.op->num_backward_inputs());
+  out_ptrs.reserve(s.op->num_inputs());
+  for (size_t i = 0; i < in_ptrs.size(); i++)
+    in_ptrs[i] = &in_bufs[i];
+  for (size_t i = 0; i < out_ptrs.size(); i++)
+    out_ptrs[i] = &out_bufs[i];
+
+  auto ograds_begin = in_bufs.begin();
+  auto ograds_end = in_bufs.begin() + s.op->num_outputs();
+  auto in_begin = ograds_end;
+  auto in_end = in_begin + s.op->num_inputs();
+  auto out_begin = in_end;
+  auto out_end = in_bufs.end();
+
+  for (auto it = ograds_begin; it != ograds_end; it++)
+    in_ptrs.push_back(&(*it));
+
+  const std::vector<bool> &save_inputs = s.op->save_inputs();
+  const std::vector<bool> &save_outputs = s.op->save_outputs();
+  CHECK_EQ(save_inputs.size(), in_end - in_begin);
+  CHECK_EQ(s.op->num_outputs(), out_end - out_begin);
+  for (auto it = in_begin; it != in_end; it++) {
 
 Review comment:
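   Prefer `++it` over `it++` here: pre-increment is potentially faster because it does not need to create a temporary copy of the iterator.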

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to