ptrendx opened a new pull request #16553: Fix for wrong reqs set after switching from training to inference
URL: https://github.com/apache/incubator-mxnet/pull/16553

## Description ##
`StaticAllocMemory` in `CachedOp` used the `storage_inplace_index` attribute (generated by the `MXPlanMemory` graph pass) to assign reqs for edges in the graph. However, the `MXPlanMemory` pass is called only once per type of memory plan (full, forward or backward). The memory plan itself is stored in a separate attribute for each type (`forward_mem_plan`, `full_mem_plan`, etc.), but `storage_inplace_index` (and therefore the reqs) was overwritten by whichever `MXPlanMemory` call ran last.

The following code:
```
import mxnet as mx

# Assumed: net is a hybridized Gluon HybridBlock (static_alloc=True), x is an input NDArray
with mx.autograd.record():
    result = net(x)
result.backward()
result2 = net(x)
with mx.autograd.record():
    result3 = net(x)
result3.backward()
```
first runs memory planning for the full graph and then for the forward-only graph. The third invocation of `net` does not run a memory planning pass at all.

Let us assume that `net` contains an op that produces an output needed by the backward pass only when that output is requested (its req is not set to `kNullOp`). Because the reqs are overwritten during the second `net` invocation, that output's req is set to `kNullOp` (there is no backward pass there). The third invocation of `net` does not change the req, so the op does not produce the required output, and the gradient computed from `result3` uses stale values and is therefore wrong.

This PR fixes the problem by changing `StaticAllocMemory` to assign reqs from per-plan attributes (`storage_inplace_index_forward`, etc.). This keeps the benefit of caching (`MXPlanMemory` is still called once per plan type) while ensuring correctness.

@eric-haibin-lin

## Checklist ##
### Essentials ###
Please feel free to remove inapplicable items for your PR.
- [x] Changes are complete (i.e. I finished coding on this PR)
- [ ] All changes have test coverage:
  - Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
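To make the failure mode in the description concrete, here is a minimal, hypothetical Python model of the caching pattern. It is not MXNet code: `ToyGraph`, `plan_memory`, `static_alloc`, and `run_sequence` are illustrative names, the `kWriteTo`/`kNullOp` strings only stand in for MXNet's req values, and each in-place attribute is simplified to a single req for the one backward-only output. `per_plan_reqs=False` models the old shared `storage_inplace_index` attribute; `per_plan_reqs=True` models the per-plan attributes this PR introduces.

```python
"""Toy model of per-plan req caching (hypothetical; not MXNet source)."""

K_WRITE_TO = "kWriteTo"  # stand-in: the output is written
K_NULL_OP = "kNullOp"    # stand-in: the output is skipped


class ToyGraph:
    """Hypothetical cached graph whose planning pass runs once per plan type."""

    def __init__(self, per_plan_reqs):
        self.attrs = {}                  # mimics graph attributes
        self.per_plan_reqs = per_plan_reqs
        self.plan_calls = 0              # how many times planning actually ran

    def _inplace_key(self, plan):
        # Old behavior: every plan type shares one attribute name.
        # New behavior: the attribute name carries the plan type.
        if self.per_plan_reqs:
            return "storage_inplace_index_" + plan
        return "storage_inplace_index"

    def plan_memory(self, plan):
        """Toy planning pass, skipped if this plan type is already cached."""
        if plan + "_mem_plan" in self.attrs:
            return
        self.plan_calls += 1
        self.attrs[plan + "_mem_plan"] = "<" + plan + " memory plan>"
        # The backward-only output is needed only when backward runs.
        req = K_WRITE_TO if plan == "full" else K_NULL_OP
        self.attrs[self._inplace_key(plan)] = req

    def static_alloc(self, plan):
        """Return the req for the backward-only output under this plan."""
        self.plan_memory(plan)
        return self.attrs[self._inplace_key(plan)]


def run_sequence(per_plan_reqs):
    """Replay the training -> inference -> training sequence from the PR."""
    g = ToyGraph(per_plan_reqs)
    reqs = [
        g.static_alloc("full"),     # record + backward
        g.static_alloc("forward"),  # inference only
        g.static_alloc("full"),     # record + backward again (plan cached)
    ]
    return reqs, g.plan_calls


if __name__ == "__main__":
    print(run_sequence(per_plan_reqs=False))
    print(run_sequence(per_plan_reqs=True))
```

With the shared attribute, `run_sequence(per_plan_reqs=False)` returns `(['kWriteTo', 'kNullOp', 'kNullOp'], 2)`: the third, cached full-graph call reads the stale `kNullOp` left by the forward plan. With per-plan attributes it returns `(['kWriteTo', 'kNullOp', 'kWriteTo'], 2)`, so the req is correct again while planning still runs only once per plan type.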