This is an automated email from the ASF dual-hosted git repository.
skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 0bab6d5 [MXNET-979] Add fix_beta support in BatchNorm (#12625)
0bab6d5 is described below
commit 0bab6d529343f0ce186859ba75c9bb02067e9cfe
Author: Sandeep Krishnamurthy <[email protected]>
AuthorDate: Wed Oct 10 11:38:34 2018 -0700
[MXNET-979] Add fix_beta support in BatchNorm (#12625)
* Add fix_beta support in BatchNorm CPU implementation
* Fix lint checks. Update GPU tests
* Fix gpu tests
* Make fix_beta not available for sparse. Update fix_beta for mkldnn
* Make default fix_beta to False for backward compatibility
* Add fix_beta to cudnn batchnorm operator
* Add tests for missing fix_beta and fix_gamma params
* Fix indentation
* Fix failing tests
* Simplify the cases with defaults for gamma and beta
---
python/mxnet/gluon/nn/basic_layers.py | 2 +-
src/operator/nn/batch_norm-inl.h | 5 ++
src/operator/nn/batch_norm.cc | 59 +++++++++--------
src/operator/nn/batch_norm.cu | 30 +++++++--
src/operator/nn/cudnn/cudnn_batch_norm-inl.h | 4 ++
tests/python/gpu/test_operator_gpu.py | 92 +++++++++++++++++----------
tests/python/mkl/test_mkldnn.py | 2 +-
tests/python/unittest/test_operator.py | 16 ++---
tests/python/unittest/test_sparse_operator.py | 20 ++++--
9 files changed, 150 insertions(+), 80 deletions(-)
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d268419..26ef64d 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -324,7 +324,7 @@ class BatchNorm(HybridBlock):
in_channels=0, **kwargs):
super(BatchNorm, self).__init__(**kwargs)
        self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                       'fix_gamma': not scale, 'use_global_stats': use_global_stats}
+                       'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
if in_channels != 0:
self.in_channels = in_channels
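[Editor's note] In the Gluon layer this wires the existing `center` argument through to the new `fix_beta` parameter, mirroring how `scale` maps to `fix_gamma`. A minimal usage sketch follows; the shapes and data are illustrative assumptions, not part of this commit:

    import mxnet as mx
    from mxnet.gluon import nn

    # center=False pins beta at 0; scale=False pins gamma at 1.
    net = nn.BatchNorm(center=False, scale=False, in_channels=3)
    net.initialize()
    x = mx.nd.random.uniform(shape=(2, 3, 4, 4))  # hypothetical NCHW batch
    y = net(x)  # normalizes with fixed gamma=1 and beta=0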
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 3f47d58..f8b381c 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -62,6 +62,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
double eps;
float momentum;
bool fix_gamma;
+ bool fix_beta;
bool use_global_stats;
bool output_mean_var;
int axis;
@@ -75,6 +76,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
.describe("Momentum for moving average");
DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
.describe("Fix gamma while training");
+ DMLC_DECLARE_FIELD(fix_beta).set_default(false)
+ .describe("Fix beta while training");
DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
.describe("Whether use global moving statistics instead of local
batch-norm. "
"This will force change batch-norm into a scale shift
operator.");
@@ -90,6 +93,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
return this->eps == other.eps &&
this->momentum == other.momentum &&
this->fix_gamma == other.fix_gamma &&
+ this->fix_beta == other.fix_beta &&
this->use_global_stats == other.use_global_stats &&
this->output_mean_var == other.output_mean_var &&
this->axis == other.axis &&
@@ -107,6 +111,7 @@ struct hash<mxnet::op::BatchNormParam> {
size_t ret = 0;
ret = dmlc::HashCombine(ret, val.momentum);
ret = dmlc::HashCombine(ret, val.fix_gamma);
+ ret = dmlc::HashCombine(ret, val.fix_beta);
ret = dmlc::HashCombine(ret, val.use_global_stats);
ret = dmlc::HashCombine(ret, val.output_mean_var);
ret = dmlc::HashCombine(ret, val.axis);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index be542ba..ec90a30 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -155,35 +155,34 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
// compute output
AccReal *w = weights.dptr<AccReal>();
- const AccReal *b = bias.dptr<AccReal>();
+ AccReal *b = bias.dptr<AccReal>();
+
+ // Ignore gamma
+ if (param_.fix_gamma) {
+ if (IsBNWriting(req[batchnorm::kGamma])) {
+ w[channel] = AccReal(1);
+ }
+ }
+
+ // Ignore beta
+ if (param_.fix_beta) {
+ if (IsBNWriting(req[batchnorm::kBeta])) {
+ b[channel] = AccReal(0);
+ }
+ }
const AccReal thisMean = mean[channel];
const AccReal thisInvstd = var[channel];
const AccReal thisWeight = w[channel];
const AccReal thisBias = b[channel];
-      // note that var is still invstd
-      if (!param_.fix_gamma) {
-        if (IsBNWriting(req[batchnorm::kData])) {
-          ForEachFast(inputData, outputData, channel,
-                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                   DType *out_data) {
-                        *out_data = static_cast<DType>(
-                          ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
-                      });
-        }
-      } else {
-        if (IsBNWriting(req[batchnorm::kGamma])) {
-          w[channel] = AccReal(1);
-        }
-        if (IsBNWriting(req[batchnorm::kData])) {
-          ForEachFast(inputData, outputData, channel,
-                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                   DType *out_data) {
-                        *out_data = static_cast<DType>(
-                          ((*in_data - thisMean) * thisInvstd) + thisBias);
-                      });
-        }
+      if (IsBNWriting(req[batchnorm::kData])) {
+        ForEachFast(inputData, outputData, channel,
+                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                 DType *out_data) {
+                      *out_data = static_cast<DType>(
+                        ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
+                    });
      }
}
}
@@ -309,7 +308,11 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
}
if (IsBNWriting(req[batchnorm::kBeta])) {
- gradBiasData[channel] = scale * sumGradOut;
+ if (!param_.fix_beta) {
+ gradBiasData[channel] = scale * sumGradOut;
+ } else {
+ gradBiasData[channel] = AccReal(0);
+ }
}
}
}
@@ -478,6 +481,9 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
    LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
  }
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
+    LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
return dispatched;
}
@@ -565,11 +571,12 @@ the 'channel' (separately normalized groups). The default is 1. Specifying -1
axis to be the last item in the input shape.
Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0.
+then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
+and its gradient to 0.
Note::
-When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
the sparse tensors will fallback.
)code" ADD_FILELINE)
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 03962cb..309542d 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -32,8 +32,9 @@
#define WRITE_GAMMA_FLAG 2
#define WRITE_BETA_FLAG 4
#define FIX_GAMMA_FLAG 8
-#define IS_TRAINING_FLAG 16
-#define USE_GLOBAL_STATS_FLAG 32
+#define FIX_BETA_FLAG 16
+#define IS_TRAINING_FLAG 32
+#define USE_GLOBAL_STATS_FLAG 64
#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
#include "./cudnn/cudnn_batch_norm-inl.h"
@@ -223,8 +224,9 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
-  AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                        : ScalarConvert<int, AccReal>::to(0);
+ AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+ ? ScalarConvert<DType, AccReal>::to(bias[plane])
+ : ScalarConvert<int, AccReal>::to(0);
if (threadIdx.x == 0) {
saveMean[plane] = runningMean[plane];
saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
@@ -232,6 +234,10 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
+ if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+ && bias.numElements() > 0) {
+ bias[plane] = AccReal(0);
+ }
}
// Write normalized and update the output
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
@@ -282,14 +288,19 @@ __global__ void BatchNormalizationUpdateOutputKernel(
&& weight.numElements() > 0) {
weight[plane] = AccReal(1);
}
+ if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+ && bias.numElements() > 0) {
+ bias[plane] = AccReal(0);
+ }
}
// Write normalized and update the output
  const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
? ScalarConvert<DType, AccReal>::to(weight[plane])
: ScalarConvert<int, AccReal>::to(1);
-  const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                              : ScalarConvert<int, AccReal>::to(0);
+ const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+ ? ScalarConvert<DType, AccReal>::to(bias[plane])
+ : ScalarConvert<int, AccReal>::to(0);
for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
    for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
const DType inp = input.get_ref(batch, plane, x);
@@ -388,7 +399,11 @@ static __global__ void BatchNormalizationBackwardKernel(
}
  if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
- tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+ if ((flags & FIX_BETA_FLAG) == 0) {
+      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+ } else {
+ tensors.gradBias[plane] = DType(0);
+ }
}
}
@@ -582,6 +597,7 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
uint32_t flags = 0;
flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
+ flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
if (IsBNWriting(req[batchnorm::kData])) {
flags |= WRITE_DATA_FLAG;
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index d4b9f84..9caa9d3 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -115,6 +115,8 @@ class CuDNNBatchNormOp {
if (param_.fix_gamma) gamma = 1.f;
+ if (param_.fix_beta) beta = 0.f;
+
if (ctx.is_train) {
Tensor<gpu, 1, DTypeParam> save_mean =
        out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
@@ -229,6 +231,7 @@ class CuDNNBatchNormOp {
global_stats ? nullptr : save_mean.dptr_,
global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
+ if (param_.fix_beta) dbeta = 0.f;
})
#else // CUDNN_VERSION < 4007
MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
@@ -267,6 +270,7 @@ class CuDNNBatchNormOp {
                                             global_stats ? nullptr : save_mean.dptr_,
                                             global_stats ? nullptr : save_inv_var.dptr_));
if (param_.fix_gamma) dgamma = 0.f;
+ if (param_.fix_beta) dbeta = 0.f;
})
#endif
}
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index dd7ec98..13022c1 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -303,35 +303,52 @@ def test_batchnorm_with_type():
# V2, 2D
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_2D)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
+ # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_2D)
+ # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
+ sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_2D)
# V2, 1D
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_1D)
+ # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_1D)
- #
+ # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
+ sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_1D)
+
# # V2, 3D
- sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)
+    sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_3D)
+ # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
check_consistency(sym, ctx_list_v2_3D)
-
+ # Don't specify fix_gamma. Default i.e., fix_gamma=False will be verified.
+ sym = mx.sym.BatchNorm(name='norm', fix_beta=False, cudnn_off=True)
+ check_consistency(sym, ctx_list_v2_3D)
@with_seed()
def test_batchnorm_versions():
-    def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
+    def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta, use_global_stats):
ctx_list = []
sym_list = []
# BatchNormV1 cpu
@@ -352,6 +369,7 @@ def test_batchnorm_versions():
if 'batchnorm_cpu' in batchnorm_op_list:
            ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+ fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm'))
@@ -359,6 +377,7 @@ def test_batchnorm_versions():
if 'batchnorm_gpu' in batchnorm_op_list:
            ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+ fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm', cudnn_off=True))
@@ -366,47 +385,54 @@ def test_batchnorm_versions():
if 'batchnorm_cudnn' in batchnorm_op_list:
            ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+ fix_beta=fix_beta,
use_global_stats=use_global_stats,
name='batchnorm', cudnn_off=False))
check_consistency(sym_list, ctx_list)
- def test_1d_batchnorm(fix_gamma, use_global_stats):
+ def test_1d_batchnorm(fix_gamma, fix_beta, use_global_stats):
data = (2, 3, 20)
        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu', 'batchnorm_gpu',
                                                          'batchnorm_cudnn'],
                                       data=data,
-                                      fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                      fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
- def test_2d_batchnorm(fix_gamma, use_global_stats):
+ def test_2d_batchnorm(fix_gamma, fix_beta, use_global_stats):
data = (2, 3, 10, 10)
-        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu',
-                                                          'batchnorm_cpu',
+        # batchnorm_v1 is deprecated.
+        # The `fix_beta` parameter is available only in the new batchnorm operator.
+        # Checking consistency separately for batchnorm_v1 and batchnorm.
+        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu'],
+                                       data=data,
+                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu', 'batchnorm_gpu',
                                                           'batchnorm_cudnn'],
                                        data=data,
-                                       fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                       fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
- def test_3d_batchnorm(fix_gamma, use_global_stats):
+ def test_3d_batchnorm(fix_gamma, fix_beta, use_global_stats):
data = (2, 3, 3, 5, 5)
        test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu', 'batchnorm_gpu'],
                                       data=data,
-                                      fix_gamma=fix_gamma, use_global_stats=use_global_stats)
-
- test_1d_batchnorm(True, False)
- test_1d_batchnorm(False, False)
- test_1d_batchnorm(False, True)
- test_1d_batchnorm(True, True)
-
- test_2d_batchnorm(True, False)
- test_2d_batchnorm(False, False)
- test_2d_batchnorm(False, True)
- test_2d_batchnorm(True, True)
-
- test_3d_batchnorm(True, False)
- test_3d_batchnorm(False, False)
- test_3d_batchnorm(False, True)
- test_3d_batchnorm(True, True)
+                                      fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+ test_1d_batchnorm(True, False, False)
+ test_1d_batchnorm(False, True, False)
+ test_1d_batchnorm(False, False, True)
+ test_1d_batchnorm(True, True, True)
+
+ test_2d_batchnorm(True, False, False)
+ test_2d_batchnorm(False, True, False)
+ test_2d_batchnorm(False, False, True)
+ test_2d_batchnorm(True, True, True)
+
+ test_3d_batchnorm(True, False, False)
+ test_3d_batchnorm(False, True, False)
+ test_3d_batchnorm(False, False, True)
+ test_3d_batchnorm(True, True, True)
@with_seed(1234)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index e549009..a3c39ff 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -235,7 +235,7 @@ def test_batchnorm():
mx.nd.array(beta).tostype(stype)]
mean_std = [mx.nd.array(rolling_mean).tostype(stype),
mx.nd.array(rolling_std).tostype(stype)]
- test = mx.symbol.BatchNorm(data, fix_gamma=False)
+ test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5332517..5a5d956 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1534,25 +1534,25 @@ def test_batchnorm_training():
test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
- test = mx.symbol.BatchNorm(data, fix_gamma=True)
+ test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
        test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+        test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
- test = mx.symbol.BatchNorm(data, fix_gamma=False)
+ test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
        test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
check_numeric_gradient(test, in_location, mean_std,
numeric_eps=1e-2, rtol=0.16, atol=1e-2)
# Test varying channel axis
@@ -1581,16 +1581,16 @@ def test_batchnorm_training():
xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
mx.nd.array(xrolling_std).tostype(stype)]
- test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std,
numeric_eps=1e-2, rtol=0.2, atol=0.01)
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=False, use_global_stats=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std,
numeric_eps=1e-2, rtol=0.2, atol=0.01)
- test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std,
numeric_eps=1e-2, rtol=0.2, atol=0.01)
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
check_numeric_gradient(test, in_location, xmean_std,
numeric_eps=1e-2, rtol=0.2, atol=0.01)
check_batchnorm_training('default')
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 5780824..bddab11 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2124,13 +2124,19 @@ def test_batchnorm_fallback():
    test = mx.symbol.BatchNorm(data, fix_gamma=True)
    assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+    test = mx.symbol.BatchNorm(data, fix_beta=True)
+    assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
    test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
    assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
-    test = mx.symbol.BatchNorm(data, fix_gamma=False)
+    test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True)
+    assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
+    test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
    check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
-    test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+    test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
    check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
# Test varying channel axis
@@ -2161,14 +2167,20 @@ def test_batchnorm_fallback():
        test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
        assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+
+        test = mx.symbol.BatchNorm(data, fix_beta=True, axis=chaxis)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
        test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
        assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+        test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True, axis=chaxis)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, axis=chaxis)
        check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
        check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
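[Editor's note] A hedged sketch of the sparse behavior these tests exercise; the array shapes are illustrative assumptions. With fix_beta=False a row_sparse input falls back to the dense kernel, while fix_beta=True should be rejected by the storage-type check added in batch_norm.cc:

    import mxnet as mx
    from mxnet.base import MXNetError

    x = mx.nd.ones((4, 3)).tostype('row_sparse')       # hypothetical sparse input
    gamma, beta = mx.nd.ones((3,)), mx.nd.zeros((3,))
    mean, var = mx.nd.zeros((3,)), mx.nd.ones((3,))

    # fix_beta=False: the sparse input falls back to the dense implementation.
    y = mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=False, fix_beta=False)

    # fix_beta=True: expected to raise MXNetError per the new check.
    try:
        mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=False, fix_beta=True)
    except MXNetError:
        print('fix_beta=True rejected for sparse ndarrays, as expected')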