ptrendx opened a new pull request #16553: Fix for wrong reqs set after 
switching from training to inference
URL: https://github.com/apache/incubator-mxnet/pull/16553
 
 
   ## Description ##
   `StaticAllocMemory` in `CachedOp` used the `storage_inplace_index` attribute 
(generated by the `MXPlanMemory` graph pass) to assign reqs to edges in the graph. 
However, the `MXPlanMemory` pass is called only once per type of memory 
plan (full, forward or backward). While the memory plan itself is stored in 
separate attributes in the graph (`forward_mem_plan`, `full_mem_plan` etc.), 
`storage_inplace_index` (and therefore the reqs) was overwritten by the most 
recent `MXPlanMemory` call.
   
   The following code:
   
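   ```
   import mxnet as mx

   # `net` and `x` are assumed to be defined elsewhere.
   # 1st call (recording): memory is planned for the full graph.
   with mx.autograd.record():
      result = net(x)
   result.backward()
   
   # 2nd call (no recording): memory is planned for the forward graph only.
   result2 = net(x)
   
   # 3rd call (recording): no new memory planning pass is run.
   with mx.autograd.record():
      result3 = net(x)
   result3.backward()
   ```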
   
   first runs the memory planning pass for the full graph and then for just the 
forward graph. The third invocation of `net` does not run a memory planning 
pass. Assume that `net` contains an op that produces an output needed only for 
the backward pass, and only when that output's req is not set to `kNullOp`. 
Since the reqs are overwritten by the second `net` invocation, that output's 
req is set to `kNullOp` (because there is no backward pass there). The third 
invocation of `net` does not change the req value, so the op does not produce 
the required output. The gradient of `result3` is therefore computed from stale 
values and is wrong.
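
The sequence above can be mimicked with a small toy model. This is only a sketch of the caching behaviour described here, not MXNet code: `ToyCachedOp`, `shared_req` and the string constants are hypothetical names, and the single `shared_req` field stands in for the shared `storage_inplace_index` attribute.

```python
# Toy model of the bug; none of these names are MXNet APIs.
K_NULL_OP, K_WRITE_TO = "kNullOp", "kWriteTo"


class ToyCachedOp:
    def __init__(self):
        self.planned = set()     # plan types that have already been computed
        self.shared_req = None   # one shared slot, overwritten by every planning pass

    def _plan_memory(self, plan_type):
        # The backward-only output is requested only when a backward pass exists.
        self.shared_req = K_WRITE_TO if plan_type == "full" else K_NULL_OP

    def run(self, recording):
        plan_type = "full" if recording else "forward"
        if plan_type not in self.planned:   # planning runs once per plan type
            self._plan_memory(plan_type)
            self.planned.add(plan_type)
        # The req the op sees for its backward-only output on this call.
        return self.shared_req


op = ToyCachedOp()
reqs = [op.run(True), op.run(False), op.run(True)]
print(reqs)  # ['kWriteTo', 'kNullOp', 'kNullOp']
```

In this model the third call is a recording call, yet it still sees `kNullOp`, because the forward-only planning pass was the last one to write the shared slot.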
   
   This PR fixes the problem by changing `StaticAllocMemory` to assign reqs from 
per-memory-plan values (`storage_inplace_index_forward` etc.). This keeps the 
benefit of caching (`MXPlanMemory` is called once per plan type) while ensuring 
correctness.
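
The idea of the fix can be sketched with a toy model that stores one req per plan type instead of a single shared value. This is an illustration under simplifying assumptions, not the actual change: `ToyCachedOpFixed` and `req_by_plan` are hypothetical names standing in for the per-plan attributes.

```python
# Toy model of the fix; none of these names are MXNet APIs.
K_NULL_OP, K_WRITE_TO = "kNullOp", "kWriteTo"


class ToyCachedOpFixed:
    def __init__(self):
        # One cached req per plan type, so plans cannot overwrite each other.
        self.req_by_plan = {}

    def _plan_memory(self, plan_type):
        # The backward-only output is requested only when a backward pass exists.
        self.req_by_plan[plan_type] = (
            K_WRITE_TO if plan_type == "full" else K_NULL_OP
        )

    def run(self, recording):
        plan_type = "full" if recording else "forward"
        if plan_type not in self.req_by_plan:   # still planned once per type
            self._plan_memory(plan_type)
        # Look up the req that belongs to the plan used by this call.
        return self.req_by_plan[plan_type]


fixed = ToyCachedOpFixed()
fixed_reqs = [fixed.run(True), fixed.run(False), fixed.run(True)]
print(fixed_reqs)  # ['kWriteTo', 'kNullOp', 'kWriteTo']
```

Each call reads the req cached for its own plan type, so a recording call after an inference call still requests the backward-only output, and planning still happens only once per type.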
   
   @eric-haibin-lin 
   
   ## Checklist ##
   ### Essentials ###
   Please feel free to remove inapplicable items for your PR.
   - [x] Changes are complete (i.e. I finished coding on this PR)
   - [ ] All changes have test coverage:
   - Unit tests are added for small changes to verify correctness (e.g. adding 
a new operator)
