piiswrong closed pull request #9807: temporary solution for instancenorm, will refactor using backend
URL: https://github.com/apache/incubator-mxnet/pull/9807
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index 0f38119af8..b61540dd61 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -420,6 +420,10 @@ class InstanceNorm(HybridBlock):
 
     Parameters
     ----------
+    axis : int, default 1
+        The axis that should be normalized. This is typically the channels
+        (C) axis. For instance, after a `Conv2D` layer with `layout='NCHW'`,
+        set `axis=1` in `InstanceNorm`. If `layout='NHWC'`, then set `axis=3`.
     epsilon: float, default 1e-5
         Small float added to variance to avoid dividing by zero.
     center: bool, default True
@@ -439,6 +443,7 @@ class InstanceNorm(HybridBlock):
         initialization will be deferred to the first time `forward` is called
         and `in_channels` will be inferred from the shape of input data.
 
+
     Inputs:
         - **data**: input tensor with arbitrary shape.
 
@@ -463,11 +468,13 @@ class InstanceNorm(HybridBlock):
      [[-0.99998319  0.99998361]]]
     <NDArray 2x1x2 @cpu(0)>
     """
-    def __init__(self, epsilon=1e-5, center=True, scale=False,
+    def __init__(self, axis=1, epsilon=1e-5, center=True, scale=False,
                  beta_initializer='zeros', gamma_initializer='ones',
                  in_channels=0, **kwargs):
         super(InstanceNorm, self).__init__(**kwargs)
-        self._kwargs = {'eps': epsilon}
+        self._kwargs = {'eps': epsilon, 'axis': axis}
+        self._axis = axis
+        self._epsilon = epsilon
         self.gamma = self.params.get('gamma', grad_req='write' if scale else 'null',
                                      shape=(in_channels,), init=gamma_initializer,
                                      allow_deferred_init=True)
@@ -476,8 +483,12 @@ def __init__(self, epsilon=1e-5, center=True, scale=False,
                                     allow_deferred_init=True)
 
     def hybrid_forward(self, F, x, gamma, beta):
-        return F.InstanceNorm(x, gamma, beta,
-                              name='fwd', **self._kwargs)
+        if self._axis == 1:
+            return F.InstanceNorm(x, gamma, beta,
+                                  name='fwd', eps=self._epsilon)
+        x = x.swapaxes(1, self._axis)
+        return F.InstanceNorm(x, gamma, beta, name='fwd',
+                              eps=self._epsilon).swapaxes(1, self._axis)
 
     def __repr__(self):
         s = '{name}({content}'
diff --git a/python/mxnet/gluon/nn/conv_layers.py b/python/mxnet/gluon/nn/conv_layers.py
index a69cb8a060..87a62bc8c7 100644
--- a/python/mxnet/gluon/nn/conv_layers.py
+++ b/python/mxnet/gluon/nn/conv_layers.py
@@ -1011,18 +1011,26 @@ def __init__(self, layout='NCDHW', **kwargs):
 
 
 class ReflectionPad2D(HybridBlock):
-    """Pads the input tensor using the reflection of the input boundary.
+    r"""Pads the input tensor using the reflection of the input boundary.
 
     Parameters
     ----------
     padding: int
         An integer padding size
 
-    Shape:
-        - Input: :math:`(N, C, H_{in}, W_{in})`
-        - Output: :math:`(N, C, H_{out}, W_{out})` where
-          :math:`H_{out} = H_{in} + 2 * padding
-          :math:`W_{out} = W_{in} + 2 * padding
+
+    Inputs:
+        - **data**: input tensor with the shape :math:`(N, C, H_{in}, W_{in})`.
+
+    Outputs:
+        - **out**: output tensor with the shape :math:`(N, C, H_{out}, W_{out})`, where
+
+          .. math::
+
+            H_{out} = H_{in} + 2 \cdot padding
+
+            W_{out} = W_{in} + 2 \cdot padding
+
 
     Examples
     --------


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to