piiswrong closed pull request #10470: CUDNN not training on backward pass similar to pytorch
URL: https://github.com/apache/incubator-mxnet/pull/10470
 
 
   

This PR was merged from a forked repository.
Because GitHub hides the original diff on merge, it is reproduced below for
the sake of provenance:

Because this is a foreign pull request (from a fork), the diff is supplied
below (GitHub would not display it otherwise):

diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h 
b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index e3d5dd9204b..d4b9f84ed2f 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -165,8 +165,6 @@ class CuDNNBatchNormOp {
     using namespace mshadow::expr;
     CHECK_EQ(inputs.size(), 8U);
     CHECK_EQ(outputs.size(), 3U);
-    CHECK(ctx.is_train && !param_.use_global_stats)
-        << "use global statistics is not yet supported in CuDNNBatchNorm";
 
     // Rename the inputs and outputs.
     const TBlob &out_grad = inputs[0];
@@ -183,6 +181,8 @@ class CuDNNBatchNormOp {
       in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s);
     Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, 
s);
 
+    const bool global_stats = !ctx.is_train || param_.use_global_stats;
+
 #if CUDNN_VERSION >= 4007
 #if CUDNN_VERSION >= 7002
     auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT;
@@ -226,8 +226,8 @@ class CuDNNBatchNormOp {
         dgamma.dptr_,
         dbeta.dptr_,
         param_.eps,
-        save_mean.dptr_,
-        save_inv_var.dptr_));
+        global_stats ? nullptr : save_mean.dptr_,
+        global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
     })
 #else  // CUDNN_VERSION < 4007
@@ -264,8 +264,8 @@ class CuDNNBatchNormOp {
                                                  dgamma.dptr_,
                                                  dbeta.dptr_,
                                                  param_.eps,
-                                                 save_mean.dptr_,
-                                                 save_inv_var.dptr_));
+                                                 global_stats ? nullptr : 
save_mean.dptr_,
+                                                 global_stats ? nullptr : 
save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
     })
 #endif
diff --git a/tests/python/gpu/test_operator_gpu.py 
b/tests/python/gpu/test_operator_gpu.py
index 2dd66ee2d10..36af524737d 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -27,6 +27,7 @@
 from nose.tools import assert_raises
 from mxnet.test_utils import check_consistency, set_default_context, 
assert_almost_equal
 from mxnet.base import MXNetError
+from mxnet import autograd
 from numpy.testing import assert_allclose
 
 curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
@@ -1786,7 +1787,24 @@ def test_incorrect_gpu():
     # Try setting dev_id to a really big number
     assert_raises(MXNetError, mx.nd.ones, (2,2), ctx=mx.gpu(100001))
 
+@with_seed()
+def test_batchnorm_backwards_notrain():
+    for ctx in [mx.cpu(0), mx.gpu(0)]:
+        for cudnn_o in [False, True]:
+            B,C,H,W = 4,3,2,2
+            x = mx.nd.random.poisson(1,shape=(B,C,H,W)).as_in_context(ctx)
+            gamma = mx.nd.random.normal(shape=(C)).as_in_context(ctx)
+            beta = mx.nd.random.normal(shape=(C)).as_in_context(ctx)
+            mean = mx.nd.random.normal(shape=(C)).as_in_context(ctx)
+            std = mx.nd.random.normal(shape=(C)).as_in_context(ctx)
+            x.attach_grad()
+
+            with autograd.record(False):
+                y = mx.ndarray.BatchNorm(x, gamma, beta, mean, std.square(),
+                                         fix_gamma=False, cudnn_off=cudnn_o)
+                loss=y.square().sum()
+            loss.backward(train_mode=False)
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
-


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log in to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to