wkcn commented on a change in pull request #18492:
URL: https://github.com/apache/incubator-mxnet/pull/18492#discussion_r436195297



##########
File path: src/operator/nn/cudnn/cudnn_batch_norm-inl.h
##########
@@ -228,7 +228,7 @@ class CuDNNBatchNormOp {
         &a,
         &b,
         &a,
-        req[cudnnbatchnorm::kGamma] == kWriteTo ? &b: &b_add,
+        req[cudnnbatchnorm::kGamma] == kAddTo ? &b_add : &b,
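In the hunk above, the ternary selects the scaling factor passed to cuDNN for the argument keyed on `req[cudnnbatchnorm::kGamma]`. Judging by its position after `&a`, this is presumably the parameter-gradient blending factor (`betaParamDiff`) of `cudnnBatchNormalizationBackward`. cuDNN blends outputs as `dst = alpha * result + beta * dst`. If `b` is 0 and `b_add` is 1 (assumed, not shown in the hunk), then `&b` overwrites the gamma/beta gradient buffers and `&b_add` accumulates into them. The old condition accumulated for every request other than `kWriteTo`, including `kWriteInplace`. The new condition accumulates only for `kAddTo`. The sketch below models the two rules in plain Python. The enum mirrors the member order of MXNet's `OpReqType`, and `beta_old`, `beta_new`, and `blend` are hypothetical helpers for illustration. `kNullOp` is left out because the hunk does not show how the kernel handles it.

```python
from enum import IntEnum

import numpy as np


class OpReqType(IntEnum):
    # Same member order as MXNet's C++ OpReqType enum (values assumed for illustration).
    kNullOp = 0
    kWriteTo = 1
    kWriteInplace = 2
    kAddTo = 3


def beta_old(req):
    # Pre-patch rule: every req other than kWriteTo blends with the old buffer.
    return 0.0 if req == OpReqType.kWriteTo else 1.0


def beta_new(req):
    # Patched rule: only kAddTo blends with the old buffer.
    return 1.0 if req == OpReqType.kAddTo else 0.0


def blend(result, prior, alpha, beta):
    # cuDNN scaling convention: dst = alpha * result + beta * prior_dst.
    return alpha * result + beta * prior


stale = np.array([5.0, 5.0])  # whatever was already in the gamma/beta grad buffer
grad = np.array([1.0, 2.0])   # freshly computed parameter gradient
for req in (OpReqType.kWriteTo, OpReqType.kWriteInplace, OpReqType.kAddTo):
    old = blend(grad, stale, 1.0, beta_old(req))
    new = blend(grad, stale, 1.0, beta_new(req))
    print(f"{req.name:14s} old={old} new={new}")
```

Under the old rule, a `kWriteInplace` request would add the fresh gradient onto stale buffer contents. Under the new rule, that request overwrites the buffer, and only `kAddTo` accumulates.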

Review comment:
The gradients of the input, gamma and beta on CPU are wrong when `grad_req` is `add`: the gradient of the input is not accumulated, and the gradients of gamma and beta are both zero.
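```python
import mxnet as mx
from mxnet.gluon import nn

N = 1
C = 3
H = W = 2

# BatchNorm whose gamma and beta accumulate gradients across backward passes.
block = nn.BatchNorm()
block.collect_params().initialize()
block.collect_params().setattr('grad_req', 'add')

# Input on the default context (CPU); its gradient is accumulated as well.
x = mx.nd.arange(N*C*H*W).reshape((N, C, H, W))
x.attach_grad(grad_req='add')

# Two identical forward/backward passes: with grad_req='add', every gradient
# should end up as twice its single-pass value.
for i in range(2):
    with mx.autograd.record():
        y = block(x)
        loss = (y * y).sum()
    loss.backward()
print(x.grad, block.gamma.grad(), block.beta.grad())
```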
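The following NumPy sketch (not MXNet code) can serve as a reference for the values the repro should produce. It computes training-mode BatchNorm forward and backward over axes (0, 2, 3) for the repro's input and loss, then sums two passes the way `grad_req='add'` should. It assumes Gluon's default `gamma = 1`, `beta = 0` and `epsilon = 1e-5`, and `bn_train_backward` is a hypothetical helper written for this illustration. With this particular loss and those defaults, `y` is just the normalized input. The reference gradient of beta is therefore zero, and the input gradient is only on the order of `epsilon`. The gamma gradient, about 8 per channel per pass and about 16 after two passes, is the component that most clearly shows whether accumulation happens.

```python
import numpy as np


def bn_train_backward(x, gamma, beta, eps=1e-5):
    """Reference training-mode BatchNorm over axes (0, 2, 3) with loss = sum(y * y).

    Returns (dx, dgamma, dbeta) for a single forward/backward pass.
    """
    axes = (0, 2, 3)
    m = x.shape[0] * x.shape[2] * x.shape[3]  # elements per channel
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)     # biased batch variance
    inv_std = 1.0 / np.sqrt(var + eps)
    x_hat = (x - mean) * inv_std
    g = gamma.reshape(1, -1, 1, 1)
    y = g * x_hat + beta.reshape(1, -1, 1, 1)
    dy = 2.0 * y                              # d(sum(y * y)) / dy
    dbeta = dy.sum(axis=axes)
    dgamma = (dy * x_hat).sum(axis=axes)
    dx_hat = dy * g
    dx = inv_std / m * (m * dx_hat
                        - dx_hat.sum(axis=axes, keepdims=True)
                        - x_hat * (dx_hat * x_hat).sum(axis=axes, keepdims=True))
    return dx, dgamma, dbeta


N, C, H, W = 1, 3, 2, 2
x = np.arange(N * C * H * W, dtype=np.float64).reshape(N, C, H, W)
gamma, beta = np.ones(C), np.zeros(C)  # assumed Gluon BatchNorm defaults

# grad_req='add': two identical backward passes are summed into the buffers.
acc = [np.zeros_like(x), np.zeros(C), np.zeros(C)]
for _ in range(2):
    for buf, grad in zip(acc, bn_train_backward(x, gamma, beta)):
        buf += grad
dx_acc, dgamma_acc, dbeta_acc = acc
print(dgamma_acc, dbeta_acc, np.abs(dx_acc).max())
```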




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
