This is an automated email from the ASF dual-hosted git repository.

zhasheng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new d452162  instance norm and reflection padding (#7938)
d452162 is described below

commit d4521621a9c21e1ec8c23dab6e4139bd6ba45a65
Author: Hang Zhang <8041160+zhanghang1...@users.noreply.github.com>
AuthorDate: Fri Feb 2 23:30:32 2018 -0800

    instance norm and reflection padding (#7938)
    
    * instance norm and reflection padding
    
    * r prefix
    
    * indent and space
    
    * fix docs
    
    * change docs
    
    * spacing
    
    * typo
    
    * hybrid forward
    
    * spcaing
    
    * add test for instance norm
    
    * fix typo
    
    * add to __all__
    
    * rm white space
    
    * integer value
    
    * add test
    
    * make line short
    
    * rm white space
    
    * add docs ref
    
    * fix docs
    
    * RFpad2D docs
    
    * read shape from weight
    
    * rm condition
---
 docs/api/python/gluon/nn.md           |  2 +
 python/mxnet/gluon/nn/basic_layers.py | 85 +++++++++++++++++++++++++++++++++--
 python/mxnet/gluon/nn/conv_layers.py  | 34 +++++++++++++-
 tests/python/unittest/test_gluon.py   | 11 +++++
 4 files changed, 128 insertions(+), 4 deletions(-)

diff --git a/docs/api/python/gluon/nn.md b/docs/api/python/gluon/nn.md
index 5e2dbe0..4515644 100644
--- a/docs/api/python/gluon/nn.md
+++ b/docs/api/python/gluon/nn.md
@@ -20,6 +20,7 @@ This document lists the neural network blocks in Gluon:
     Activation
     Dropout
     BatchNorm
+    InstanceNorm
     LeakyReLU
     Embedding
     Flatten
@@ -62,6 +63,7 @@ This document lists the neural network blocks in Gluon:
     GlobalAvgPool1D
     GlobalAvgPool2D
     GlobalAvgPool3D
+    ReflectionPad2D
 ```
 
 
diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 43b4bda..1517539 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -19,8 +19,8 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Activation',
-           'Dropout', 'BatchNorm', 'LeakyReLU', 'Embedding', 'Flatten',
-           'Lambda', 'HybridLambda']
+           'Dropout', 'BatchNorm', 'InstanceNorm', 'LeakyReLU', 'Embedding',
+           'Flatten', 'Lambda', 'HybridLambda']
 import warnings
 import numpy as np
 
@@ -480,6 +480,86 @@ class Flatten(HybridBlock):
         return self.__class__.__name__
 
 
+class InstanceNorm(HybridBlock):
+    r"""
+    Applies instance normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array (where n > 2) and normalizes
+    the input using the following formula:
+
+    .. math::
+
+      out = \frac{x - mean[data]}{\sqrt{Var[data]} + \epsilon} * gamma + beta
+
+    Parameters
+    ----------
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default False
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+        When the next layer is linear (also e.g. `nn.relu`),
+        this can be disabled since the scaling
+        will be done by the next layer.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    in_channels : int, default 0
+        Number of channels (feature maps) in input data. If not specified,
+        initialization will be deferred to the first time `forward` is called
+        and `in_channels` will be inferred from the shape of input data.
+
+    Inputs:
+        - **data**: input tensor with arbitrary shape.
+
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+
+    References
+    ----------
+        `Instance Normalization: The Missing Ingredient for Fast Stylization
+        <https://arxiv.org/abs/1607.08022>`_
+
+    Examples
+    --------
+    >>> # Input of shape (2,1,2)
+    >>> x = mx.nd.array([[[ 1.1,  2.2]],
+    ...                 [[ 3.3,  4.4]]])
+    >>> # Instance normalization is calculated with the above formula
+    >>> layer = InstanceNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[[-0.99998355  0.99998331]]
+     [[-0.99998319  0.99998361]]]
+    <NDArray 2x1x2 @cpu(0)>
+    """
+    def __init__(self, epsilon=1e-5, center=True, scale=False,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 in_channels=0, **kwargs):
+        super(InstanceNorm, self).__init__(**kwargs)
+        self._kwargs = {'eps': epsilon}
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
+                                     shape=(in_channels,), init=gamma_initializer,
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 'null',
+                                    shape=(in_channels,), init=beta_initializer,
+                                    allow_deferred_init=True)
+
+    def hybrid_forward(self, F, x, gamma, beta):
+        return F.InstanceNorm(x, gamma, beta,
+                              name='fwd', **self._kwargs)
+
+    def __repr__(self):
+        s = '{name}({content}'
+        in_channels = self.gamma.shape[0]
+        s += ', in_channels={0}'.format(in_channels)
+        s += ')'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
+
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
 
@@ -526,7 +606,6 @@ class Lambda(Block):
 class HybridLambda(HybridBlock):
     r"""Wraps an operator or an expression as a HybridBlock object.
 
-
     Parameters
     ----------
     function : str or function
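
For readers following along, here is a minimal usage sketch of the new InstanceNorm
block. It is not part of the commit and assumes an mxnet build that already contains
this change; it only checks the property the docstring describes, namely that each
(sample, channel) slice is normalized over its spatial axes to roughly zero mean and
unit variance with the default gamma=1, beta=0.

    import mxnet as mx
    from mxnet.gluon import nn

    # Each (sample, channel) slice is normalized over its spatial axes.
    x = mx.nd.random.normal(loc=2.0, scale=3.0, shape=(2, 3, 8, 8))

    layer = nn.InstanceNorm(in_channels=3)
    layer.initialize()
    y = layer(x)

    print(y.shape)                      # (2, 3, 8, 8), same shape as the input
    print(y.mean(axis=(2, 3)))          # ~0 for every (sample, channel) slice
    print((y ** 2).mean(axis=(2, 3)))   # ~1 for every (sample, channel) slice
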
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index 822a81c..4e9ee68 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -23,7 +23,8 @@ __all__ = ['Conv1D', 'Conv2D', 'Conv3D',
            'MaxPool1D', 'MaxPool2D', 'MaxPool3D',
            'AvgPool1D', 'AvgPool2D', 'AvgPool3D',
            'GlobalMaxPool1D', 'GlobalMaxPool2D', 'GlobalMaxPool3D',
-           'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D']
+           'GlobalAvgPool1D', 'GlobalAvgPool2D', 'GlobalAvgPool3D',
+           'ReflectionPad2D']
 
 from ..block import HybridBlock
 from ... import symbol
@@ -1007,3 +1008,34 @@ class GlobalAvgPool3D(_Pooling):
         assert layout == 'NCDHW', "Only supports NCDHW layout for now"
         super(GlobalAvgPool3D, self).__init__(
             (1, 1, 1), None, 0, True, True, 'avg', **kwargs)
+
+
+class ReflectionPad2D(HybridBlock):
+    """Pads the input tensor using the reflection of the input boundary.
+
+    Parameters
+    ----------
+    padding: int
+        An integer padding width; the same amount of reflection padding is
+        applied to both sides of the height and width axes.
+
+    Shape:
+        - Input: :math:`(N, C, H_{in}, W_{in})`
+        - Output: :math:`(N, C, H_{out}, W_{out})` where
+          :math:`H_{out} = H_{in} + 2 * padding` and
+          :math:`W_{out} = W_{in} + 2 * padding`
+
+    Examples
+    --------
+    >>> m = nn.ReflectionPad2D(3)
+    >>> input = mx.nd.random.normal(shape=(16, 3, 224, 224))
+    >>> output = m(input)
+    """
+    def __init__(self, padding=0, **kwargs):
+        super(ReflectionPad2D, self).__init__(**kwargs)
+        if isinstance(padding, numeric_types):
+            padding = (0, 0, 0, 0, padding, padding, padding, padding)
+        assert len(padding) == 8
+        self._padding = padding
+
+    def hybrid_forward(self, F, x):
+        return F.pad(x, mode='reflect', pad_width=self._padding)
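
A similar sketch for the new ReflectionPad2D block (again not part of the commit,
assuming the same build): the border is filled with a mirror image of the rows and
columns just inside the edge rather than with zeros, and each spatial dimension
grows by 2 * padding.

    import mxnet as mx
    from mxnet.gluon import nn

    x = mx.nd.arange(9).reshape((1, 1, 3, 3))

    layer = nn.ReflectionPad2D(1)
    y = layer(x)

    # H and W each grow by 2 * padding; the padded border mirrors the values
    # just inside the edge instead of repeating the edge or zero-filling.
    print(y.shape)   # (1, 1, 5, 5)
    print(y)
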
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 80109cf..d239705 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -328,11 +328,22 @@ def test_pool():
     layer.collect_params().initialize()
     assert (layer(x).shape==(2, 2, 4, 4))
 
+
 def test_batchnorm():
     layer = nn.BatchNorm(in_channels=10)
     check_layer_forward(layer, (2, 10, 10, 10))
 
 
+def test_instancenorm():
+    layer = nn.InstanceNorm(in_channels=10)
+    check_layer_forward(layer, (2, 10, 10, 10))
+
+
+def test_reflectionpad():
+    layer = nn.ReflectionPad2D(3)
+    check_layer_forward(layer, (2, 3, 24, 24))
+
+
 def test_reshape():
     x = mx.nd.ones((2, 4, 10, 10))
     layer = nn.Conv2D(10, 2, in_channels=4)

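Finally, a hedged sketch of how the two blocks are commonly combined, for example in
style-transfer-style generators: reflection-pad before a 'valid' convolution so the
spatial size is preserved, then instance-normalize the activations. The layer sizes
below are illustrative only and not taken from this commit.

    import mxnet as mx
    from mxnet.gluon import nn

    net = nn.HybridSequential()
    with net.name_scope():
        net.add(nn.ReflectionPad2D(1))         # pad H and W by 1 on each side
        net.add(nn.Conv2D(16, kernel_size=3))  # the 3x3 conv consumes the padding
        net.add(nn.InstanceNorm())             # in_channels inferred on first call
        net.add(nn.Activation('relu'))
    net.initialize()
    net.hybridize()                            # both new blocks are HybridBlocks

    x = mx.nd.random.uniform(shape=(4, 3, 32, 32))
    print(net(x).shape)                        # (4, 16, 32, 32)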