piiswrong closed pull request #11951: [MXNET-750] fix nested call on CachedOp.
URL: https://github.com/apache/incubator-mxnet/pull/11951
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index d4da99ea9e8..1e7f8e0de1b 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -821,12 +821,11 @@ OpStatePtr CachedOp::DynamicForward(
 
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
-  if (recording && !inlining_) Imperative::Get()->set_is_recording(false);
-
+  // If we are already recording, we don't need RunGraph to record all
+  // computation again.
   RunGraph(false, idx, arrays, 0, idx.num_nodes(), std::move(array_reqs),
-           std::move(ref_count), &states, dispatch_modes);
-
-  Imperative::Get()->set_is_recording(recording);
+           std::move(ref_count), &states, dispatch_modes,
+           !recording || inlining_);
 
   return op_state;
 }
@@ -947,7 +946,8 @@ void CachedOp::DynamicBackward(
   const auto& dispatch_modes = g.GetAttr<DispatchModeVector>("dispatch_mode");
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
+           Imperative::Get()->is_recording());
 
   if (retain_graph) {
     buff.resize(num_forward_entries);
diff --git a/src/imperative/imperative.cc b/src/imperative/imperative.cc
index e1654259a2f..0c5ff841775 100644
--- a/src/imperative/imperative.cc
+++ b/src/imperative/imperative.cc
@@ -495,7 +495,8 @@ std::vector<NDArray*> Imperative::Backward(
   int prev_bulk_size = Engine::Get()->set_bulk_size(backward_bulk_size_);
 
   RunGraph(retain_graph, idx, arrays, num_forward_nodes, idx.num_nodes(),
-           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes);
+           std::move(array_reqs), std::move(ref_count), &states, dispatch_modes,
+           is_recording());
 
   Engine::Get()->set_bulk_size(prev_bulk_size);
   set_is_recording(prev_recording);
diff --git a/src/imperative/imperative_utils.cc b/src/imperative/imperative_utils.cc
index 464aefc220d..c84a3b9be50 100644
--- a/src/imperative/imperative_utils.cc
+++ b/src/imperative/imperative_utils.cc
@@ -30,7 +30,8 @@ void RunGraph(
     std::vector<OpReqType>&& array_reqs,
     std::vector<uint32_t>&& ref_count,
     std::vector<OpStatePtr> *p_states,
-    const DispatchModeVector &dispatch_modes) {
+    const DispatchModeVector &dispatch_modes,
+    bool recording) {
   using namespace nnvm;
   using namespace imperative;
   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
@@ -40,7 +41,6 @@ void RunGraph(
   const auto imp = Imperative::Get();
 
   std::vector<OpStatePtr>& states = *p_states;
-  bool recording = imp->is_recording();
 
   std::vector<NDArray*> ndinputs, ndoutputs;
   ShapeVector arg_shapes;
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 6daf96e60d0..9c86843ca7a 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -994,7 +994,8 @@ void RunGraph(const bool retain_graph,
               std::vector<OpReqType>&& array_reqs,
               std::vector<uint32_t>&& ref_count,
               std::vector<OpStatePtr> *p_states,
-              const DispatchModeVector &dispatch_modes);
+              const DispatchModeVector &dispatch_modes,
+              bool recording);
 
 }  // namespace imperative
 }  // namespace mxnet
diff --git a/tests/python/unittest/test_contrib_control_flow.py b/tests/python/unittest/test_contrib_control_flow.py
index 67ed78ee030..f1188b53d81 100644
--- a/tests/python/unittest/test_contrib_control_flow.py
+++ b/tests/python/unittest/test_contrib_control_flow.py
@@ -1159,6 +1159,7 @@ def check_contrib_rnn(cell_type, num_states):
 
     configs = [
             {},
+            {'inline_limit': 0},
             {'static_alloc': True},
             {'static_alloc': True, 'static_shape': True} ]
     for config in configs:


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to