This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

The following commit(s) were added to refs/heads/master by this push:
     new d048615  temporary solution for instancenorm, will refactor using backend (#9807)
d048615 is described below

commit d04861575f18cd476132e56f4e738e9e6771338b
Author: Hang Zhang <8041160+zhanghang1...@users.noreply.github.com>
AuthorDate: Mon Feb 19 11:47:14 2018 -0800

    temporary solution for instancenorm, will refactor using backend (#9807)

    * temporary solution for instancenorm, will refactor using backend

    * fix typo

    * rm space

    * fix doc

    * fix doc

    * Update conv_layers.py

    * Update basic_layers.py

    * Update conv_layers.py

    * Update basic_layers.py

    * fix typo

    * fix typo
---
 python/mxnet/gluon/nn/basic_layers.py | 19 +++++++++++++++----
 python/mxnet/gluon/nn/conv_layers.py  | 20 ++++++++++++++------
 2 files changed, 29 insertions(+), 10 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 0f38119..b61540d 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -420,6 +420,10 @@ class InstanceNorm(HybridBlock):
 
     Parameters
     ----------
+    axis : int, default 1
+        The axis that should be normalized. This is typically the channels
+        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`.
     epsilon: float, default 1e-5
         Small float added to variance to avoid dividing by zero.
     center: bool, default True
@@ -439,6 +443,7 @@ class InstanceNorm(HybridBlock):
         initialization will be deferred to the first time `forward` is called
         and `in_channels` will be inferred from the shape of input data.
 
+
     Inputs:
         - **data**: input tensor with arbitrary shape.
 
@@ -463,11 +468,13 @@ class InstanceNorm(HybridBlock):
      [[-0.99998319 0.99998361]]]
     <NDArray 2x1x2 @cpu(0)>
     """
-    def __init__(self, epsilon=1e-5, center=True, scale=False,
+    def __init__(self, axis=1, epsilon=1e-5, center=True, scale=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  in_channels=0, **kwargs):
         super(InstanceNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon}
+        self._kwargs = {'eps': epsilon, 'axis': axis}
+        self._axis = axis
+        self._epsilon = epsilon
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                      shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
@@ -476,8 +483,12 @@ class InstanceNorm(HybridBlock):
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, gamma, beta):
-        return F.InstanceNorm(x, gamma, beta,
-                           name='fwd', **self._kwargs)
+        if self._axis == 1:
+            return F.InstanceNorm(x, gamma, beta,
+                                  name='fwd', eps=self._epsilon)
+        x = x.swapaxes(1, self._axis)
+        return F.InstanceNorm(x, gamma, beta, name='fwd',
+                              eps=self._epsilon).swapaxes(1, self._axis)
 
     def __repr__(self):
         s = '{name}({content}'
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index a69cb8a..87a62bc 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -1011,18 +1011,26 @@ class GlobalAvgPool3D(_Pooling):
 
 
 class ReflectionPad2D(HybridBlock):
-    """Pads the input tensor using the reflection of the input boundary.
+    r"""Pads the input tensor using the reflection of the input boundary.
 
     Parameters
     ----------
     padding: int
         An integer padding size
 
-    Shape:
-        - Input: :math:`(N, C, H_{in}, W_{in})`
-        - Output: :math:`(N, C, H_{out}, W_{out})` where
-          :math:`H_{out} = H_{in} + 2 * padding
-          :math:`W_{out} = W_{in} + 2 * padding
+
+    Inputs:
+        - **data**: input tensor with the shape :math:`(N, C, H_{in}, W_{in})`.
+
+    Outputs:
+        - **out**: output tensor with the shape :math:`(N, C, H_{out}, W_{out})`, where
+
+          .. math::
+
+            H_{out} = H_{in} + 2 \cdot padding
+
+            W_{out} = W_{in} + 2 \cdot padding
+
 
     Examples
     --------

-- 
To stop receiving notification emails like this one, please contact
j...@apache.org.