This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch v1.x
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/v1.x by this push:
     new 36bd144  fix batchnorm (#18377) (#18470)
36bd144 is described below

commit 36bd144757586825fe78c7c2ec3db898daf1e19b
Author: Xinyu Chen <xinyu1.c...@intel.com>
AuthorDate: Wed Jun 3 13:46:42 2020 +0800

    fix batchnorm (#18377) (#18470)
    
    Update basic_layers.py
    
    fix
    
    fix
    
    Update basic_layers.py
    
    fix bug
    
    Co-authored-by: Xingjian Shi <xsh...@connect.ust.hk>
---
 python/mxnet/gluon/contrib/nn/basic_layers.py | 14 +++++++----
 python/mxnet/gluon/nn/basic_layers.py         | 36 +++++++++++++++------------
 2 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/python/mxnet/gluon/contrib/nn/basic_layers.py b/python/mxnet/gluon/contrib/nn/basic_layers.py
index bc7c3ce..3c48a74 100644
--- a/python/mxnet/gluon/contrib/nn/basic_layers.py
+++ b/python/mxnet/gluon/contrib/nn/basic_layers.py
@@ -220,11 +220,15 @@ class SyncBatchNorm(BatchNorm):
                  center=True, scale=True, use_global_stats=False, beta_initializer='zeros',
                  gamma_initializer='ones', running_mean_initializer='zeros',
                  running_variance_initializer='ones', **kwargs):
-        fuse_relu = False
-        super(SyncBatchNorm, self).__init__(1, momentum, epsilon, center, scale, use_global_stats,
-                                            fuse_relu, beta_initializer, gamma_initializer,
-                                            running_mean_initializer, running_variance_initializer,
-                                            in_channels, **kwargs)
+        super(SyncBatchNorm, self).__init__(
+            axis=1, momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            use_global_stats=use_global_stats,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
         num_devices = self._get_num_devices() if num_devices is None else num_devices
         self._kwargs = {'eps': epsilon, 'momentum': momentum,
                         'fix_gamma': not scale, 'use_global_stats': use_global_stats,
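
The SyncBatchNorm change above replaces the old positional super() call, whose argument order no longer matched the parent signature once fuse_relu was removed, with explicit keyword forwarding. A minimal sanity check, assuming an MXNet 1.x build that already contains this commit and relying on the internal _kwargs dict visible in the context lines above:

    # Hypothetical check, not part of the commit: constructor arguments should
    # now reach the parent BatchNorm instead of being shifted out of position.
    from mxnet.gluon.contrib.nn import SyncBatchNorm

    bn = SyncBatchNorm(in_channels=16, momentum=0.8, epsilon=1e-4)
    assert bn._kwargs['momentum'] == 0.8   # forwarded, not silently dropped
    assert bn._kwargs['eps'] == 1e-4       # forwarded, not silently dropped
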
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 70b0a71..72230fe 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -410,7 +410,6 @@ class BatchNorm(_BatchNorm):
         If True, use global moving statistics instead of local batch-norm. This will force
         change batch-norm into a scale shift operator.
         If False, use local batch-norm.
-    fuse_relu: False
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -432,17 +431,20 @@ class BatchNorm(_BatchNorm):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False, fuse_relu=False,
+                 use_global_stats=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
-        assert not fuse_relu, "Please use BatchNormReLU with Relu fusion"
         super(BatchNorm, self).__init__(
-            axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-            use_global_stats=False, fuse_relu=False,
-            beta_initializer='zeros', gamma_initializer='ones',
-            running_mean_initializer='zeros', running_variance_initializer='ones',
-            in_channels=0, **kwargs)
+            axis=axis, momentum=momentum, epsilon=epsilon, center=center,
+            scale=scale,
+            use_global_stats=use_global_stats, fuse_relu=False,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
+
 
 class BatchNormReLU(_BatchNorm):
     """Batch normalization layer (Ioffe and Szegedy, 2014).
@@ -472,7 +474,6 @@ class BatchNormReLU(_BatchNorm):
         If True, use global moving statistics instead of local batch-norm. This will force
         change batch-norm into a scale shift operator.
         If False, use local batch-norm.
-    fuse_relu: True
     beta_initializer: str or `Initializer`, default 'zeros'
         Initializer for the beta weight.
     gamma_initializer: str or `Initializer`, default 'ones'
@@ -494,17 +495,20 @@ class BatchNormReLU(_BatchNorm):
         - **out**: output tensor with the same shape as `data`.
     """
     def __init__(self, axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-                 use_global_stats=False, fuse_relu=True,
+                 use_global_stats=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  running_mean_initializer='zeros', running_variance_initializer='ones',
                  in_channels=0, **kwargs):
-        assert fuse_relu, "Please use BatchNorm w/o Relu fusion"
         super(BatchNormReLU, self).__init__(
-            axis=1, momentum=0.9, epsilon=1e-5, center=True, scale=True,
-            use_global_stats=False, fuse_relu=True,
-            beta_initializer='zeros', gamma_initializer='ones',
-            running_mean_initializer='zeros', running_variance_initializer='ones',
-            in_channels=0, **kwargs)
+            axis=axis, momentum=momentum, epsilon=epsilon,
+            center=center, scale=scale,
+            use_global_stats=use_global_stats, fuse_relu=True,
+            beta_initializer=beta_initializer,
+            gamma_initializer=gamma_initializer,
+            running_mean_initializer=running_mean_initializer,
+            running_variance_initializer=running_variance_initializer,
+            in_channels=in_channels, **kwargs)
+
 
 class Embedding(HybridBlock):
     r"""Turns non-negative integers (indexes/tokens) into dense vectors
