zhanghang1989 closed pull request #12853: Add GN
URL: https://github.com/apache/incubator-mxnet/pull/12853
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:


diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index d26841977ac..9ec4d5d5e4b 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,7 +19,8 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda']
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 'HybridLambda',
+           'GroupNorm']
 import warnings
 import numpy as np
 
@@ -27,6 +28,7 @@
 from ..block import Block, HybridBlock
 from ..utils import _indent
 from ... import nd, sym
+from ... import autograd
 
 
 class Sequential(Block):
@@ -700,3 +702,83 @@ def hybrid_forward(self, F, x, *args):
     def __repr__(self):
         return '{name}({function})'.format(name=self.__class__.__name__,
                                            function=self._func_name)
+
+
+class GroupNorm(Block):
+    """GroupNorm normalization layer (Wu and He, 2014).
+    Parameters
+    ----------
+    ngroups : int
+        Number of channel groups in GN.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+    axis : int, default 1
+        The axis that should be normalized. This is typically the channels
+        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
+        set `axis=1` in `GroupNorm`. If `layout='NHWC'`, then set `axis=3`.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+    """
+    def __init__(self, ngroups, in_channels=0, axis=1, epsilon=1e-5,
+                 beta_initializer='zeros', gamma_initializer='ones', **kwargs):
+        super(GroupNorm, self).__init__(**kwargs)
+        self._kwargs = {'axis': axis, 'eps': epsilon, 'momentum': 0,
+                        'fix_gamma': True, 'use_global_stats': False}
+        self.ngroups = ngroups
+        assert in_channels % ngroups == 0, "Channel number should be divisible by groups."
+        if in_channels != 0:
+            self.in_channels = in_channels
+
+        self.gamma = self.params.get('gamma', grad_req='write',
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True, differentiable=True)
+        self.beta = self.params.get('beta', grad_req='write',
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True, differentiable=True)
+        # hacky
+        self.hacky_zeros = self.params.get('hacky_zeros', grad_req='null',
+                                           shape=(ngroups,), init='zeros',
+                                           allow_deferred_init=True, differentiable=False)
+        self.hacky_ones = self.params.get('hacky_ones', grad_req='null',
+                                          shape=(ngroups,), init='ones',
+                                          allow_deferred_init=True, differentiable=False)
+
+
+    def cast(self, dtype):
+        if np.dtype(dtype).name == 'float16':
+            dtype = 'float32'
+        super(GroupNorm, self).cast(dtype)
+
+    def forward(self, x):
+        xshape = x.shape
+        # normalization
+        with autograd.train_mode():
+            y = nd.BatchNorm(x.reshape(xshape[0], self.ngroups, -1),
+                             self.hacky_ones.data(), self.hacky_zeros.data(),
+                             self.hacky_zeros.data(), self.hacky_ones.data(),
+                             name='fwd', **self._kwargs)
+        # scale and shift
+        y = y.reshape(xshape[0], xshape[1], -1)
+        y = y * self.gamma.data().reshape(1, -1, 1) + self.beta.data().reshape(1, -1, 1)
+        return y.reshape(xshape)
+
+    def __repr__(self):
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels if in_channels else None)
+        s += ')'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))


 

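The forward pass above reshapes the (N, C, ...) input to (N, ngroups, -1),
runs nd.BatchNorm over the group axis in train mode (with constant ones and
zeros standing in for BatchNorm's own parameters and running statistics), and
then applies the per-channel gamma and beta. For comparison, here is a naive
sketch of group normalization as defined in the paper, i.e. per-sample,
per-group statistics computed with plain NDArray ops; the helper name is
illustrative and not part of the PR:

    import mxnet as mx

    def group_norm_reference(x, ngroups, gamma, beta, eps=1e-5):
        # Naive group normalization over an NCHW NDArray (not part of the PR).
        n, c = x.shape[0], x.shape[1]
        # Collapse each group of channels, plus all spatial positions, into one axis.
        xg = x.reshape((n, ngroups, -1))
        mean = xg.mean(axis=2, keepdims=True)
        var = ((xg - mean) ** 2).mean(axis=2, keepdims=True)
        xg = (xg - mean) / mx.nd.sqrt(var + eps)
        # Restore the channel axis and apply the per-channel affine transform.
        y = xg.reshape((n, c, -1))
        y = y * gamma.reshape((1, -1, 1)) + beta.reshape((1, -1, 1))
        return y.reshape(x.shape)

    x = mx.nd.random.uniform(shape=(2, 16, 8, 8))
    gamma = mx.nd.ones(16)
    beta = mx.nd.zeros(16)
    ref = group_norm_reference(x, ngroups=4, gamma=gamma, beta=beta)
    print(ref.shape)  # (2, 16, 8, 8)
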
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services
