wkcn commented on a change in pull request #14542: Support SyncBatchNorm5D
URL: https://github.com/apache/incubator-mxnet/pull/14542#discussion_r284125264
##########
File path: tests/python/unittest/test_gluon.py
##########
@@ -583,6 +583,126 @@ def test_batchnorm():
check_layer_forward(layer, (2, 10, 10, 10))
+@with_seed()
+def test_sync_batchnorm():
+    def _check_batchnorm_result(input, num_devices=1, cuda=False):
+        from mxnet.gluon.utils import split_and_load
+
+        def _find_bn(module):
+            if isinstance(module, (mx.gluon.nn.BatchNorm,
+                                   mx.gluon.contrib.nn.SyncBatchNorm)):
+                return module
+            elif isinstance(module.module, (mx.gluon.nn.BatchNorm,
+                                            mx.gluon.contrib.nn.SyncBatchNorm)):
+                return module.module
+
+            raise RuntimeError('BN not found')
+
+        def _syncParameters(bn1, bn2, ctx):
+            ctx = input.context
+            bn2.gamma.set_data(bn1.gamma.data(ctx))
+            bn2.beta.set_data(bn1.beta.data(ctx))
+            bn2.running_mean.set_data(bn1.running_mean.data(ctx))
+            bn2.running_var.set_data(bn1.running_var.data(ctx))
+
+        input1 = input.copy()
+        input2 = input.copy()
+
+        if cuda:
+            input1 = input.as_in_context(mx.gpu(0))
+            ctx_list = [mx.gpu(i) for i in range(num_devices)]
+        else:
+            ctx_list = [mx.cpu(0) for _ in range(num_devices)]
+
+        nch = input.shape[1] if input.ndim > 1 else 1
+        bn1 = mx.gluon.nn.BatchNorm(in_channels=nch)
+        bn2 = mx.gluon.contrib.nn.SyncBatchNorm(
+            in_channels=nch, num_devices=num_devices)
+
+        bn1.initialize(ctx=ctx_list[0])
+        bn2.initialize(ctx=ctx_list)
+
+        # using the same values for gamma and beta
+        #_syncParameters(_find_bn(bn1), _find_bn(bn2), ctx_list[0])
+
+        input1.attach_grad()
+        inputs2 = split_and_load(input2, ctx_list, batch_axis=0)
+        for xi in inputs2:
+            xi.attach_grad()
+
+        with mx.autograd.record():
+            output1 = bn1(input1)
+            output2 = [bn2(xi) for xi in inputs2]
+            loss1 = (output1 ** 2).sum()
+            loss2 = [(output ** 2).sum() for output in output2]
+            mx.autograd.backward(loss1)
+            mx.autograd.backward(loss2)
+
+        output2 = mx.nd.concat(*[output.as_in_context(input.context)
+                                 for output in output2], dim=0)
+        # check bn1
+
+        momentum = 0.9
+        epsilon = 1e-5
+        axis = 1
+        data = input1
+        running_mean = mx.nd.zeros(nch, ctx=data.context)
+        running_var = mx.nd.ones(nch, ctx=data.context)
+
+        data_mean = data.mean(
+            axis=axis, exclude=True, keepdims=True)
+        data_var = (data - data_mean).square().mean(axis=axis,
+                                                    exclude=True,
+                                                    keepdims=True)
+
+        target_output = (data - data_mean) / (data_var + epsilon).sqrt()
+
+        # squeeze data_mean and data_var
+        data_mean_flat = data_mean.squeeze()
+        data_var_flat = data_var.squeeze()
+
+        running_mean = running_mean * momentum + \
+            data_mean_flat * (1 - momentum)
+        running_var = running_var * momentum + \
+            data_var_flat * (1 - momentum)
+
+        atol = 1e-2
+        rtol = 1e-2
+        assert_almost_equal(output1.asnumpy(), target_output.asnumpy(),
+                            atol=atol, rtol=rtol)
+        assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
+                            running_mean.asnumpy(),
+                            atol=atol, rtol=rtol)
+        assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
+                            running_var.asnumpy(),
+                            atol=atol, rtol=rtol)
+        # assert forwarding
+        assert_almost_equal(input1.asnumpy(), input2.asnumpy(),
+                            atol=atol, rtol=rtol)
+        assert_almost_equal(output1.asnumpy(),
+                            output2.asnumpy(), atol=atol, rtol=rtol)
+        assert_almost_equal(_find_bn(bn1).running_mean.data(ctx_list[0]).asnumpy(),
+                            _find_bn(bn2).running_mean.data(ctx_list[0]).asnumpy(),
+                            atol=atol, rtol=rtol)
+        assert_almost_equal(_find_bn(bn1).running_var.data(ctx_list[0]).asnumpy(),
+                            _find_bn(bn2).running_var.data(ctx_list[0]).asnumpy(),
+                            atol=atol, rtol=rtol)
+        input2grad = mx.nd.concat(
+            *[output.grad.as_in_context(input.context) for output in inputs2],
+            dim=0)
+        assert_almost_equal(input1.grad.asnumpy(),
+                            input2grad.asnumpy(), atol=atol, rtol=rtol)
+
+    cfgs = [(1, False)]
+    num_gpus = mx.context.num_gpus()
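
The hunk is cut off at the device configurations; presumably `num_gpus` is used to extend `cfgs` with GPU settings that are then fed to `_check_batchnorm_result`. A minimal sketch of that pattern (the input shape and loop structure below are assumptions for illustration, not the PR's exact code):

    for i in range(1, num_gpus + 1):
        cfgs.append((i, True))  # assumed: one (num_devices, cuda=True) config per GPU count
    for ndev, cuda in cfgs:
        # assumed input shape; any NDArray with the batch on axis 0 would do
        _check_batchnorm_result(mx.nd.random.uniform(shape=(4, 1, 4, 4)),
                                num_devices=ndev, cuda=cuda)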
Review comment:
I don't know why an unknown CUDA error was raised here. The linked line appears to be MXNet's CUDA error-check macro, so the message is probably coming from the CUDA runtime itself.
https://github.com/apache/incubator-mxnet/blob/master/include/mxnet/base.h#L424
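
If the error comes from querying the device count on the CI machine, one possible defensive workaround would be a guard like the sketch below (just an idea, not part of this PR; `_safe_num_gpus` is a hypothetical helper):

    from mxnet.base import MXNetError

    def _safe_num_gpus():
        # Fall back to CPU-only configs if querying the GPU count raises
        # a CUDA error, instead of failing the whole test.
        try:
            return mx.context.num_gpus()
        except MXNetError:
            return 0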