piiswrong closed pull request #10470: CUDNN not training on backward pass similar to pytorch URL: https://github.com/apache/incubator-mxnet/pull/10470
This is a PR merged from a forked repository. Because GitHub hides the original diff of a foreign pull request (one from a fork) once it is merged, the diff is reproduced below for the sake of provenance: diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h index e3d5dd9204b..d4b9f84ed2f 100644 --- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h +++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h @@ -165,8 +165,6 @@ class CuDNNBatchNormOp { using namespace mshadow::expr; CHECK_EQ(inputs.size(), 8U); CHECK_EQ(outputs.size(), 3U); - CHECK(ctx.is_train && !param_.use_global_stats) - << "use global statistics is not yet supported in CuDNNBatchNorm"; // Rename the inputs and outputs. const TBlob &out_grad = inputs[0]; @@ -183,6 +181,8 @@ class CuDNNBatchNormOp { in_grad[cudnnbatchnorm::kData].get_with_shape<gpu, 4, DType>(shape_, s); Tensor<gpu, 4, DType> dy = out_grad.get_with_shape<gpu, 4, DType>(shape_, s); + const bool global_stats = !ctx.is_train || param_.use_global_stats; + #if CUDNN_VERSION >= 4007 #if CUDNN_VERSION >= 7002 auto mode = CUDNN_BATCHNORM_SPATIAL_PERSISTENT; @@ -226,8 +226,8 @@ class CuDNNBatchNormOp { dgamma.dptr_, dbeta.dptr_, param_.eps, - save_mean.dptr_, - save_inv_var.dptr_)); + global_stats ? nullptr : save_mean.dptr_, + global_stats ? nullptr : save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) #else // CUDNN_VERSION < 4007 @@ -264,8 +264,8 @@ class CuDNNBatchNormOp { dgamma.dptr_, dbeta.dptr_, param_.eps, - save_mean.dptr_, - save_inv_var.dptr_)); + global_stats ? nullptr : save_mean.dptr_, + global_stats ? 
nullptr : save_inv_var.dptr_)); if (param_.fix_gamma) dgamma = 0.f; }) #endif diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 2dd66ee2d10..36af524737d 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -27,6 +27,7 @@ from nose.tools import assert_raises from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal from mxnet.base import MXNetError +from mxnet import autograd from numpy.testing import assert_allclose curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) @@ -1786,7 +1787,24 @@ def test_incorrect_gpu(): # Try setting dev_id to a really big number assert_raises(MXNetError, mx.nd.ones, (2,2), ctx=mx.gpu(100001)) +@with_seed() +def test_batchnorm_backwards_notrain(): + for ctx in [mx.cpu(0), mx.gpu(0)]: + for cudnn_o in [False, True]: + B,C,H,W = 4,3,2,2 + x = mx.nd.random.poisson(1,shape=(B,C,H,W)).as_in_context(ctx) + gamma = mx.nd.random.normal(shape=(C)).as_in_context(ctx) + beta = mx.nd.random.normal(shape=(C)).as_in_context(ctx) + mean = mx.nd.random.normal(shape=(C)).as_in_context(ctx) + std = mx.nd.random.normal(shape=(C)).as_in_context(ctx) + x.attach_grad() + + with autograd.record(False): + y = mx.ndarray.BatchNorm(x, gamma, beta, mean, std.square(), + fix_gamma=False, cudnn_off=cudnn_o) + loss=y.square().sum() + loss.backward(train_mode=False) + if __name__ == '__main__': import nose nose.runmodule() - ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services