This is an automated email from the ASF dual-hosted git repository.

skm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 0bab6d5  [MXNET-979] Add fix_beta support in BatchNorm (#12625)
0bab6d5 is described below

commit 0bab6d529343f0ce186859ba75c9bb02067e9cfe
Author: Sandeep Krishnamurthy <[email protected]>
AuthorDate: Wed Oct 10 11:38:34 2018 -0700

    [MXNET-979] Add fix_beta support in BatchNorm (#12625)
    
    * Add fix_beta support in BatchNorm CPU implementation
    
    * Fix lint checks. Update GPU tests
    
    * Fix gpu tests
    
    * make fix_beta not available for sparse. Update fix_beta for mkldnn
    
    * Make default fix_beta to False for backward compatibility
    
    * Add fix_beta to cudnn batchnorm operator
    
    * Add tests for missing fix_beta and fix_gamma params
    
    * fix indentation
    
    * Fix failing tests
    
    * simplify the cases with defaults for gamma, beta
---
 python/mxnet/gluon/nn/basic_layers.py         |  2 +-
 src/operator/nn/batch_norm-inl.h              |  5 ++
 src/operator/nn/batch_norm.cc                 | 59 +++++++++--------
 src/operator/nn/batch_norm.cu                 | 30 +++++++--
 src/operator/nn/cudnn/cudnn_batch_norm-inl.h  |  4 ++
 tests/python/gpu/test_operator_gpu.py         | 92 +++++++++++++++++----------
 tests/python/mkl/test_mkldnn.py               |  2 +-
 tests/python/unittest/test_operator.py        | 16 ++---
 tests/python/unittest/test_sparse_operator.py | 20 ++++--
 9 files changed, 150 insertions(+), 80 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d268419..26ef64d 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -324,7 +324,7 @@ class BatchNorm(HybridBlock):
                  in_channels=0, **kwargs):
         super(BatchNorm, self).__init__(**kwargs)
         self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': momentum,
-                        'fix_gamma': not scale, 'use_global_stats': use_global_stats}
+                        'fix_gamma': not scale, 'fix_beta': not center, 'use_global_stats': use_global_stats}
         if in_channels != 0:
             self.in_channels = in_channels
 
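
For reference, a minimal Gluon sketch of the mapping above (assuming a build that includes this change; the existing center/scale flags on nn.BatchNorm are simply forwarded to the backend operator):

    from mxnet.gluon import nn

    # center=False is forwarded to the backend op as fix_beta=True (beta pinned at 0),
    # mirroring how scale=False already maps to fix_gamma=True.
    bn = nn.BatchNorm(scale=True, center=False)
    print(bn._kwargs)  # expected to include 'fix_gamma': False and 'fix_beta': True
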
diff --git a/src/operator/nn/batch_norm-inl.h b/src/operator/nn/batch_norm-inl.h
index 3f47d58..f8b381c 100644
--- a/src/operator/nn/batch_norm-inl.h
+++ b/src/operator/nn/batch_norm-inl.h
@@ -62,6 +62,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
   double eps;
   float momentum;
   bool fix_gamma;
+  bool fix_beta;
   bool use_global_stats;
   bool output_mean_var;
   int axis;
@@ -75,6 +76,8 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     .describe("Momentum for moving average");
     DMLC_DECLARE_FIELD(fix_gamma).set_default(true)
     .describe("Fix gamma while training");
+    DMLC_DECLARE_FIELD(fix_beta).set_default(false)
+    .describe("Fix beta while training");
     DMLC_DECLARE_FIELD(use_global_stats).set_default(false)
     .describe("Whether use global moving statistics instead of local 
batch-norm. "
               "This will force change batch-norm into a scale shift 
operator.");
@@ -90,6 +93,7 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
     return this->eps == other.eps &&
            this->momentum == other.momentum &&
            this->fix_gamma == other.fix_gamma &&
+           this->fix_beta == other.fix_beta &&
            this->use_global_stats == other.use_global_stats &&
            this->output_mean_var == other.output_mean_var &&
            this->axis == other.axis &&
@@ -107,6 +111,7 @@ struct hash<mxnet::op::BatchNormParam> {
     size_t ret = 0;
     ret = dmlc::HashCombine(ret, val.momentum);
     ret = dmlc::HashCombine(ret, val.fix_gamma);
+    ret = dmlc::HashCombine(ret, val.fix_beta);
     ret = dmlc::HashCombine(ret, val.use_global_stats);
     ret = dmlc::HashCombine(ret, val.output_mean_var);
     ret = dmlc::HashCombine(ret, val.axis);
diff --git a/src/operator/nn/batch_norm.cc b/src/operator/nn/batch_norm.cc
index be542ba..ec90a30 100644
--- a/src/operator/nn/batch_norm.cc
+++ b/src/operator/nn/batch_norm.cc
@@ -155,35 +155,34 @@ void BatchNormForwardImpl(mshadow::Stream<cpu> *,
 
     // compute output
     AccReal *w = weights.dptr<AccReal>();
-    const AccReal *b = bias.dptr<AccReal>();
+    AccReal *b = bias.dptr<AccReal>();
+
+    // Ignore gamma
+    if (param_.fix_gamma) {
+      if (IsBNWriting(req[batchnorm::kGamma])) {
+        w[channel] = AccReal(1);
+      }
+    }
+
+    // Ignore beta
+    if (param_.fix_beta) {
+       if (IsBNWriting(req[batchnorm::kBeta])) {
+          b[channel] = AccReal(0);
+        }
+    }
 
     const AccReal thisMean = mean[channel];
     const AccReal thisInvstd = var[channel];
     const AccReal thisWeight = w[channel];
     const AccReal thisBias = b[channel];
 
-    // note that var is still invstd
-    if (!param_.fix_gamma) {
-      if (IsBNWriting(req[batchnorm::kData])) {
-        ForEachFast(inputData, outputData, channel,
-                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                 DType *out_data) {
-                      *out_data = static_cast<DType>(
-                        ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
-                    });
-      }
-    } else {
-      if (IsBNWriting(req[batchnorm::kGamma])) {
-        w[channel] = AccReal(1);
-      }
-      if (IsBNWriting(req[batchnorm::kData])) {
-        ForEachFast(inputData, outputData, channel,
-                    [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
-                                                                 DType *out_data) {
-                      *out_data = static_cast<DType>(
-                        ((*in_data - thisMean) * thisInvstd) + thisBias);
-                    });
-      }
+    if (IsBNWriting(req[batchnorm::kData])) {
+          ForEachFast(inputData, outputData, channel,
+                      [thisWeight, thisBias, thisMean, thisInvstd](const DType *in_data,
+                                                                  DType *out_data) {
+                        *out_data = static_cast<DType>(
+                          ((*in_data - thisMean) * thisInvstd) * thisWeight + thisBias);
+                      });
     }
   }
 }
@@ -309,7 +308,11 @@ void BatchNormBackwardImpl(mshadow::Stream<cpu> *,
     }
 
     if (IsBNWriting(req[batchnorm::kBeta])) {
-      gradBiasData[channel] = scale * sumGradOut;
+      if (!param_.fix_beta) {
+        gradBiasData[channel] = scale * sumGradOut;
+      } else {
+        gradBiasData[channel] = AccReal(0);
+      }
     }
   }
 }
@@ -478,6 +481,9 @@ static inline bool BatchNormStorageType(const nnvm::NodeAttrs &attrs,
   if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_gamma) {
     LOG(FATAL) << "fix_gamma=True is not supported for sparse ndarrays. Tracked at #11647";
   }
+  if (!common::ContainsOnlyStorage(*in_attrs, kDefaultStorage) && param.fix_beta) {
+    LOG(FATAL) << "fix_beta=True is not supported for sparse ndarrays. Tracked at #11647";
+  }
   return dispatched;
 }
 
@@ -565,11 +571,12 @@ the 'channel' (separately normalized groups).  The default is 1.  Specifying -1
 axis to be the last item in the input shape.
 
 Both ``gamma`` and ``beta`` are learnable parameters. But if ``fix_gamma`` is true,
-then set ``gamma`` to 1 and its gradient to 0.
+then set ``gamma`` to 1 and its gradient to 0. If ``fix_beta`` is true, then set ``beta`` to 0
+and its gradient to 0.
 
 Note::
 
-When fix_gamma is set to True, no sparse support is provided. If fix_gamma is set to False,
+When fix_gamma/fix_beta is set to True, no sparse support is provided. If fix_gamma/fix_beta is set to False,
 the sparse tensors will fallback.
 
 )code" ADD_FILELINE)
diff --git a/src/operator/nn/batch_norm.cu b/src/operator/nn/batch_norm.cu
index 03962cb..309542d 100644
--- a/src/operator/nn/batch_norm.cu
+++ b/src/operator/nn/batch_norm.cu
@@ -32,8 +32,9 @@
 #define WRITE_GAMMA_FLAG      2
 #define WRITE_BETA_FLAG       4
 #define FIX_GAMMA_FLAG        8
-#define IS_TRAINING_FLAG      16
-#define USE_GLOBAL_STATS_FLAG 32
+#define FIX_BETA_FLAG         16
+#define IS_TRAINING_FLAG      32
+#define USE_GLOBAL_STATS_FLAG 64
 
 #if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 5
 #include "./cudnn/cudnn_batch_norm-inl.h"
@@ -223,8 +224,9 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
   AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                   ? ScalarConvert<DType, AccReal>::to(weight[plane])
                   : ScalarConvert<int, AccReal>::to(1);
-  AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                        : ScalarConvert<int, AccReal>::to(0);
+  AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+                  ? ScalarConvert<DType, AccReal>::to(bias[plane])
+                  : ScalarConvert<int, AccReal>::to(0);
   if (threadIdx.x == 0) {
     saveMean[plane] = runningMean[plane];
     saveInvStd[plane] = VARIANCE_TO_INVSTD(runningVar[plane], epsilon);
@@ -232,6 +234,10 @@ __global__ void BatchNormalizationUpdateOutputInferenceKernel(
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
+    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+        && bias.numElements() > 0) {
+      bias[plane] = AccReal(0);
+    }
   }
   // Write normalized and update the output
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
@@ -282,14 +288,19 @@ __global__ void BatchNormalizationUpdateOutputKernel(
         && weight.numElements() > 0) {
       weight[plane] = AccReal(1);
     }
+    if ((flags & WRITE_BETA_FLAG) != 0 && (flags & FIX_BETA_FLAG) != 0
+        && bias.numElements() > 0) {
+      bias[plane] = AccReal(0);
+    }
   }
 
   // Write normalized and update the output
   const AccReal gamma = ((flags & FIX_GAMMA_FLAG) == 0 && weight.numElements() > 0)
                         ? ScalarConvert<DType, AccReal>::to(weight[plane])
                         : ScalarConvert<int, AccReal>::to(1);
-  const AccReal beta = bias.numElements() > 0 ? ScalarConvert<DType, AccReal>::to(bias[plane])
-                                              : ScalarConvert<int, AccReal>::to(0);
+  const AccReal beta = ((flags & FIX_BETA_FLAG) == 0 && bias.numElements() > 0)
+                        ? ScalarConvert<DType, AccReal>::to(bias[plane])
+                        : ScalarConvert<int, AccReal>::to(0);
   for (int batch = 0, nbatch = input.OuterSize(); batch < nbatch; ++batch) {
   for (int x = threadIdx.x, nx = input.InnerSize(); x < nx; x += blockDim.x) {
       const DType inp = input.get_ref(batch, plane, x);
@@ -388,7 +399,11 @@ static __global__ void BatchNormalizationBackwardKernel(
   }
 
   if (tensors.gradBias.numElements() > 0 && threadIdx.x == 0 && (flags & WRITE_BETA_FLAG) != 0) {
-    tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    if ((flags & FIX_BETA_FLAG) == 0) {
+      tensors.gradBias[plane] = ScalarConvert<AccReal, DType>::to(gradOutputSum);
+    } else {
+      tensors.gradBias[plane] = DType(0);
+    }
   }
 }
 
@@ -582,6 +597,7 @@ static inline uint32_t SetupFlags(const OpContext &ctx,
   uint32_t flags = 0;
   flags |= ctx.is_train ? IS_TRAINING_FLAG : 0;
   flags |= params.fix_gamma ? FIX_GAMMA_FLAG : 0;
+  flags |= params.fix_beta ? FIX_BETA_FLAG : 0;
   flags |= params.use_global_stats ? USE_GLOBAL_STATS_FLAG : 0;
   if (IsBNWriting(req[batchnorm::kData])) {
     flags |= WRITE_DATA_FLAG;
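
The flag words above shift the training/global-stats bits to make room for FIX_BETA_FLAG. A small Python sketch of the resulting bit layout (WRITE_DATA_FLAG is assumed to be 1, from the unchanged part of the file):

    # Bit masks mirrored from batch_norm.cu after this change.
    WRITE_DATA_FLAG, WRITE_GAMMA_FLAG, WRITE_BETA_FLAG = 1, 2, 4
    FIX_GAMMA_FLAG, FIX_BETA_FLAG = 8, 16
    IS_TRAINING_FLAG, USE_GLOBAL_STATS_FLAG = 32, 64

    # SetupFlags composes these, e.g. training with beta fixed:
    flags = IS_TRAINING_FLAG | FIX_BETA_FLAG | WRITE_DATA_FLAG
    assert (flags & FIX_BETA_FLAG) != 0 and (flags & USE_GLOBAL_STATS_FLAG) == 0
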
diff --git a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
index d4b9f84..9caa9d3 100644
--- a/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
+++ b/src/operator/nn/cudnn/cudnn_batch_norm-inl.h
@@ -115,6 +115,8 @@ class CuDNNBatchNormOp {
 
       if (param_.fix_gamma) gamma = 1.f;
 
+      if (param_.fix_beta) beta = 0.f;
+
       if (ctx.is_train) {
         Tensor<gpu, 1, DTypeParam> save_mean =
           out_data[cudnnbatchnorm::kMean].get_with_shape<gpu, 1, DTypeParam>(Shape1(shape_[1]), s);
@@ -229,6 +231,7 @@ class CuDNNBatchNormOp {
         global_stats ? nullptr : save_mean.dptr_,
         global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
+      if (param_.fix_beta) dbeta = 0.f;
     })
 #else  // CUDNN_VERSION < 4007
     MSHADOW_REAL_TYPE_SWITCH(dtype_param_, DTypeParam, {
@@ -267,6 +270,7 @@ class CuDNNBatchNormOp {
                                                  global_stats ? nullptr : save_mean.dptr_,
                                                  global_stats ? nullptr : save_inv_var.dptr_));
       if (param_.fix_gamma) dgamma = 0.f;
+      if (param_.fix_beta) dbeta = 0.f;
     })
 #endif
   }
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index dd7ec98..13022c1 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -303,35 +303,52 @@ def test_batchnorm_with_type():
 
 
   # V2, 2D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_2D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
+  # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_2D)
+  # Don't specify fix_gamma. Default i.e., fix_gamma=True will be verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_2D)
 
   # V2, 1D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_1D)
+  # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_1D)
-  #
+  # Don't specify fix_gamma. Default i.e., fix_gamma=True will be verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=True, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_1D)
+
   # # V2, 3D
-  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, cudnn_off=True)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=False, fix_beta=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_3D)
+  sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_3D)
+  # Don't specify fix_beta. Default i.e., fix_beta=False will be verified.
   sym = mx.sym.BatchNorm(name='norm', fix_gamma=True, cudnn_off=True)
   check_consistency(sym, ctx_list_v2_3D)
-
+  # Don't specify fix_gamma. Default i.e., fix_gamma=True will be verified.
+  sym = mx.sym.BatchNorm(name='norm', fix_beta=False, cudnn_off=True)
+  check_consistency(sym, ctx_list_v2_3D)
 
 @with_seed()
 def test_batchnorm_versions():
-  def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, use_global_stats):
+  def test_batchnorm_versions_helper(batchnorm_op_list, data, fix_gamma, fix_beta, use_global_stats):
     ctx_list = []
     sym_list = []
     # BatchNormV1 cpu
@@ -352,6 +369,7 @@ def test_batchnorm_versions():
     if 'batchnorm_cpu' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.cpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm'))
 
@@ -359,6 +377,7 @@ def test_batchnorm_versions():
     if 'batchnorm_gpu' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm', cudnn_off=True))
 
@@ -366,47 +385,54 @@ def test_batchnorm_versions():
     if 'batchnorm_cudnn' in batchnorm_op_list:
       ctx_list.append({'ctx': mx.gpu(0), 'batchnorm_data': data, 'type_dict': {'batchnorm_data': np.float32}})
       sym_list.append(mx.sym.BatchNorm(fix_gamma=fix_gamma,
+                                       fix_beta=fix_beta,
                                        use_global_stats=use_global_stats,
                                        name='batchnorm', cudnn_off=False))
 
     check_consistency(sym_list, ctx_list)
 
-  def test_1d_batchnorm(fix_gamma, use_global_stats):
+  def test_1d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 20)
     test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu', 'batchnorm_cudnn'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
 
-  def test_2d_batchnorm(fix_gamma, use_global_stats):
+  def test_2d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 10, 10)
-    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu',
-                                                      'batchnorm_cpu',
+    # batchnorm_v1 is deprecated.
+    # `fix_beta` parameter is available only in the new batchnorm operator.
+    # Checking consistency separately for batchnorm_v1 and batchnorm.
+    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_v1_cpu', 'batchnorm_v1_gpu'],
+                                   data=data,
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+    test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu', 'batchnorm_cudnn'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
 
-  def test_3d_batchnorm(fix_gamma, use_global_stats):
+  def test_3d_batchnorm(fix_gamma, fix_beta, use_global_stats):
     data = (2, 3, 3, 5, 5)
     test_batchnorm_versions_helper(batchnorm_op_list=['batchnorm_cpu',
                                                       'batchnorm_gpu'],
                                    data=data,
-                                   fix_gamma=fix_gamma, use_global_stats=use_global_stats)
-
-  test_1d_batchnorm(True,  False)
-  test_1d_batchnorm(False, False)
-  test_1d_batchnorm(False, True)
-  test_1d_batchnorm(True,  True)
-
-  test_2d_batchnorm(True,  False)
-  test_2d_batchnorm(False, False)
-  test_2d_batchnorm(False, True)
-  test_2d_batchnorm(True,  True)
-
-  test_3d_batchnorm(True,  False)
-  test_3d_batchnorm(False, False)
-  test_3d_batchnorm(False, True)
-  test_3d_batchnorm(True,  True)
+                                   fix_gamma=fix_gamma, fix_beta=fix_beta, use_global_stats=use_global_stats)
+
+  test_1d_batchnorm(True,  False, False)
+  test_1d_batchnorm(False, True, False)
+  test_1d_batchnorm(False, False, True)
+  test_1d_batchnorm(True,  True, True)
+
+  test_2d_batchnorm(True,  False, False)
+  test_2d_batchnorm(False, True, False)
+  test_2d_batchnorm(False, False, True)
+  test_2d_batchnorm(True,  True, True)
+
+  test_3d_batchnorm(True,  False, False)
+  test_3d_batchnorm(False, True, False)
+  test_3d_batchnorm(False, False, True)
+  test_3d_batchnorm(True,  True, True)
 
 
 @with_seed(1234)
diff --git a/tests/python/mkl/test_mkldnn.py b/tests/python/mkl/test_mkldnn.py
index e549009..a3c39ff 100644
--- a/tests/python/mkl/test_mkldnn.py
+++ b/tests/python/mkl/test_mkldnn.py
@@ -235,7 +235,7 @@ def test_batchnorm():
                            mx.nd.array(beta).tostype(stype)]
             mean_std = [mx.nd.array(rolling_mean).tostype(stype), mx.nd.array(rolling_std).tostype(stype)]
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
     stypes = ['row_sparse', 'default']
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 5332517..5a5d956 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -1534,25 +1534,25 @@ def test_batchnorm_training():
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             test = mx.symbol.BatchNorm_v1(data, fix_gamma=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
             check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-2, rtol=0.16, atol=1e-2)
 
             # Test varying channel axis
@@ -1581,16 +1581,16 @@ def test_batchnorm_training():
                 xmean_std = [mx.nd.array(xrolling_mean).tostype(stype),
                              mx.nd.array(xrolling_std).tostype(stype)]
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=True, fix_beta=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
-                test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+                test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
                 check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-2, rtol=0.2, atol=0.01)
 
     check_batchnorm_training('default')
diff --git a/tests/python/unittest/test_sparse_operator.py b/tests/python/unittest/test_sparse_operator.py
index 5780824..bddab11 100644
--- a/tests/python/unittest/test_sparse_operator.py
+++ b/tests/python/unittest/test_sparse_operator.py
@@ -2124,13 +2124,19 @@ def test_batchnorm_fallback():
         test = mx.symbol.BatchNorm(data, fix_gamma=True)
         assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
+        test = mx.symbol.BatchNorm(data, fix_beta=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
         test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True)
         assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False)
+        test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True)
+        assertRaises(MXNetError, check_numeric_gradient, test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
+
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
-        test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True)
+        test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True)
         check_numeric_gradient(test, in_location, mean_std, numeric_eps=1e-3, rtol=0.16, atol=1e-2)
 
         # Test varying channel axis
@@ -2161,14 +2167,20 @@ def test_batchnorm_fallback():
 
             test = mx.symbol.BatchNorm(data, fix_gamma=True, axis=chaxis)
             assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_beta=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
             test = mx.symbol.BatchNorm(data, fix_gamma=True, use_global_stats=True, axis=chaxis)
             assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, axis=chaxis)
+            test = mx.symbol.BatchNorm(data, fix_beta=True, use_global_stats=True, axis=chaxis)
+            assertRaises(MXNetError, check_numeric_gradient, test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
+
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, axis=chaxis)
             check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
-            test = mx.symbol.BatchNorm(data, fix_gamma=False, use_global_stats=True, axis=chaxis)
+            test = mx.symbol.BatchNorm(data, fix_gamma=False, fix_beta=False, use_global_stats=True, axis=chaxis)
             check_numeric_gradient(test, in_location, xmean_std, numeric_eps=1e-3, rtol=0.2, atol=0.01)
 
 
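
For a quick end-to-end check of the new behaviour, a minimal sketch with the imperative API (assumes a build containing this commit; BatchNorm runs in training mode under autograd):

    import mxnet as mx
    from mxnet import autograd

    x = mx.nd.random.uniform(shape=(4, 3, 8, 8))
    gamma, beta = mx.nd.ones((3,)), mx.nd.random.uniform(shape=(3,))
    mean, var = mx.nd.zeros((3,)), mx.nd.ones((3,))
    for arr in (x, gamma, beta):
        arr.attach_grad()
    with autograd.record():
        y = mx.nd.BatchNorm(x, gamma, beta, mean, var, fix_gamma=False, fix_beta=True)
    y.backward()
    print(beta.grad)  # expected to be all zeros, since beta is fixed
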
