This is an automated email from the ASF dual-hosted git repository.

sxjscience pushed a commit to branch master
in repository

The following commit(s) were added to refs/heads/master by this push:
     new eec0fb4  Group Normalization (#14959)
eec0fb4 is described below

commit eec0fb4eda40f4fb8222a8d93d8face454aead09
Author: Hao Jin <>
AuthorDate: Thu Jul 18 22:44:34 2019 -0700

    Group Normalization (#14959)
    * GroupNorm
    * add to amp list
    * re-write forward
 python/mxnet/contrib/amp/lists/ |   1 +
 python/mxnet/gluon/nn/    |  91 +++++++-
 src/operator/nn/group_norm-inl.h         | 347 +++++++++++++++++++++++++++++++
 src/operator/nn/            | 131 ++++++++++++
 src/operator/nn/            |  37 ++++
 tests/python/unittest/      |   9 +
 tests/python/unittest/   |  91 ++++++++
 7 files changed, 706 insertions(+), 1 deletion(-)

diff --git a/python/mxnet/contrib/amp/lists/ 
index 9a587df..c6cc3d1 100644
--- a/python/mxnet/contrib/amp/lists/
+++ b/python/mxnet/contrib/amp/lists/
@@ -471,6 +471,7 @@ FP32_FUNCS = [
+    'GroupNorm',
diff --git a/python/mxnet/gluon/nn/ 
index 3d6976c..b1482ce 100644
--- a/python/mxnet/gluon/nn/
+++ b/python/mxnet/gluon/nn/
@@ -19,7 +19,8 @@
 # pylint: disable= arguments-differ
 """Basic neural network layers."""
 __all__ = ['Sequential', 'HybridSequential', 'Dense', 'Dropout', 'Embedding',
-           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'Flatten', 'Lambda', 
+           'BatchNorm', 'InstanceNorm', 'LayerNorm', 'GroupNorm',
+           'Flatten', 'Lambda', 'HybridLambda']
 import warnings
 import numpy as np
@@ -616,6 +617,94 @@ class LayerNorm(HybridBlock):
                                            for k, v in self._kwargs.items()]))
+class GroupNorm(HybridBlock):
+    r"""
+    Applies group normalization to the n-dimensional input array.
+    This operator takes an n-dimensional input array where the leftmost 2 axes are
+    `batch` and `channel` respectively:
+    .. math::
+      x = x.reshape((N, num_groups, C // num_groups, ...))
+      axis = (2, ...)
+      out = \frac{x - mean[x, axis]}{ \sqrt{Var[x, axis] + \epsilon}} * gamma + beta
+    Parameters
+    ----------
+    num_groups: int, default 1
+        Number of groups to separate the channel axis into.
+    epsilon: float, default 1e-5
+        Small float added to variance to avoid dividing by zero.
+    center: bool, default True
+        If True, add offset of `beta` to normalized tensor.
+        If False, `beta` is ignored.
+    scale: bool, default True
+        If True, multiply by `gamma`. If False, `gamma` is not used.
+    beta_initializer: str or `Initializer`, default 'zeros'
+        Initializer for the beta weight.
+    gamma_initializer: str or `Initializer`, default 'ones'
+        Initializer for the gamma weight.
+    Inputs:
+        - **data**: input tensor with shape (N, C, ...).
+    Outputs:
+        - **out**: output tensor with the same shape as `data`.
+    References
+    ----------
+        `Group Normalization
+        <https://arxiv.org/abs/1803.08494>`_
+    Examples
+    --------
+    >>> # Input of shape (2, 3, 4)
+    >>> x = mx.nd.array([[[ 0,  1,  2,  3],
+                          [ 4,  5,  6,  7],
+                          [ 8,  9, 10, 11]],
+                         [[12, 13, 14, 15],
+                          [16, 17, 18, 19],
+                          [20, 21, 22, 23]]])
+    >>> # Group normalization is calculated with the above formula
+    >>> layer = GroupNorm()
+    >>> layer.initialize(ctx=mx.cpu(0))
+    >>> layer(x)
+    [[[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]
+     [[-1.5932543 -1.3035717 -1.0138891 -0.7242065]
+      [-0.4345239 -0.1448413  0.1448413  0.4345239]
+      [ 0.7242065  1.0138891  1.3035717  1.5932543]]]
+    <NDArray 2x3x4 @cpu(0)>
+    """
+    def __init__(self, num_groups=1, epsilon=1e-5, center=True, scale=True,
+                 beta_initializer='zeros', gamma_initializer='ones',
+                 prefix=None, params=None):
+        super(GroupNorm, self).__init__(prefix=prefix, params=params)
+        self._kwargs = {'eps': epsilon, 'num_groups': num_groups, 'center': 
center, 'scale': scale}
+        self._num_groups = num_groups
+        self._epsilon = epsilon
+        self._center = center
+        self._scale = scale
+        self.gamma = self.params.get('gamma', grad_req='write' if scale else 
+                                     shape=(num_groups,), 
+                                     allow_deferred_init=True)
+        self.beta = self.params.get('beta', grad_req='write' if center else 
+                                    shape=(num_groups,), init=beta_initializer,
+                                    allow_deferred_init=True)
+    def hybrid_forward(self, F, data, gamma, beta):
+        norm_data = F.GroupNorm(data, gamma=gamma, beta=beta, 
num_groups=self._num_groups, eps=self._epsilon)
+        return norm_data
+    def __repr__(self):
+        s = '{name}({content})'
+        return s.format(name=self.__class__.__name__,
+                        content=', '.join(['='.join([k, v.__repr__()])
+                                           for k, v in self._kwargs.items()]))
 class Lambda(Block):
     r"""Wraps an operator or an expression as a Block object.
diff --git a/src/operator/nn/group_norm-inl.h b/src/operator/nn/group_norm-inl.h
new file mode 100644
index 0000000..69d5a30
--- /dev/null
+++ b/src/operator/nn/group_norm-inl.h
@@ -0,0 +1,347 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ * Copyright (c) 2019 by Contributors
+ * \file group_norm-inl.h
+ * \brief Implements Group Normalization (
+ * \author Hao Jin
+#include <dmlc/logging.h>
+#include <dmlc/parameter.h>
+#include <mxnet/operator.h>
+#include <mshadow/base.h>
+#include <map>
+#include <algorithm>
+#include <vector>
+#include <string>
+#include <utility>
+#include "./moments-inl.h"
+#include "../mshadow_op.h"
+#include "../operator_common.h"
+#include "../mxnet_op.h"
+#include "../tensor/broadcast_reduce_op.h"
+namespace mxnet {
+namespace op {
+namespace groupnorm {
+enum GroupNormOpInputs {kData, kGamma, kBeta};  // kGamma: scaling parameters, 
kBeta: shift biases
+enum GroupNormOpOutputs {kOut, kMean, kStd};  // req, out_data
+}  // namespace groupnorm
+struct GroupNormParam : public dmlc::Parameter<GroupNormParam> {
+  int num_groups;
+  float eps;
+  bool output_mean_var;
+    DMLC_DECLARE_FIELD(num_groups).set_default(1)
+      .describe("Total number of groups.");
+    DMLC_DECLARE_FIELD(eps).set_default(1e-5f)
+      .describe("An `epsilon` parameter to prevent division by 0.");
+    DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
+      .describe("Output the mean and std calculated along the given axis.");
+  }
+template<typename xpu>
+void GroupNormCompute(const nnvm::NodeAttrs& attrs,
+                      const OpContext& ctx,
+                      const std::vector<TBlob>& inputs,
+                      const std::vector<OpReqType>& req,
+                      const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  const int num_groups = param.num_groups;
+  if (req[0] == kNullOp) return;
+  CHECK_NE(req[0], kAddTo);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  const TBlob& data = inputs[groupnorm::kData];
+  const TBlob& mean = outputs[groupnorm::kMean];
+  const TBlob& std = outputs[groupnorm::kStd];
+  const mxnet::TShape& data_shape = data.shape_;
+  CHECK_GE(data_shape.ndim(), 3U)
+    << "input should have at least 3 dims and "
+    << "the first 2 dims should be batch and channel respectively";
+  CHECK_EQ(data_shape[1] % num_groups, 0)
+    << "number of channel should be divisible by num_groups.";
+  mxnet::TShape temp_data_shape(data_shape.ndim() + 1, 1);
+  temp_data_shape[0] = data_shape[0];
+  temp_data_shape[1] = num_groups;
+  temp_data_shape[2] = data_shape[1] / num_groups;
+  for (int i = 2; i < data_shape.ndim(); ++i) {
+    temp_data_shape[i+1] = data_shape[i];
+  }
+  mxnet::TShape moments_shape(temp_data_shape.ndim(), 1);
+  for (int i = 0; i < data.shape_.ndim(); ++i) {
+    moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1;
+  }
+  mxnet::TShape red_src_shape, red_dst_shape;
+  BroadcastReduceShapeCompact(temp_data_shape, moments_shape, &red_src_shape, 
+  int channel_size = red_src_shape.Size() / red_dst_shape.Size();
+  TBlob data_ = data.reshape(red_src_shape);
+  const TBlob& mean_ = mean.reshape(red_dst_shape);
+  const TBlob& std_ = std.reshape(red_dst_shape);
+  Tensor<xpu, 1, char> workspace;
+  size_t workspace_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      workspace_size =
+        broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape, req[0], 
+    });
+  });
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, 
char>(Shape1(workspace_size), s);
+  // Calculate mean
+  MSHADOW_REAL_TYPE_SWITCH(data.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::identity, 
+        s, mean_, req[0], workspace, data_);
+      Tensor<xpu, 1, DType> mean_data_tensor = mean_.FlatTo1D<xpu, DType>(s);
+      mean_data_tensor /= scalar<DType>(channel_size);
+    });
+  });
+  TBlob data_grp = data.reshape(temp_data_shape);
+  const TBlob& mean_grp = mean.reshape(moments_shape);
+  const TBlob& std_grp = std.reshape(moments_shape);
+  const TBlob& output = outputs[groupnorm::kOut].reshape(temp_data_shape);
+  // Calculate data = data - mean
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                     {data_grp, mean_grp},
+                                                     {kWriteTo}, {output});
+  // Calculate std
+  const TBlob centered_out = outputs[groupnorm::kOut].reshape(red_src_shape);
+  MSHADOW_REAL_TYPE_SWITCH(output.type_flag_, DType, {
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      broadcast::Reduce<mshadow_op::sum, NDim, DType, mshadow_op::square, 
+        s, std_, req[0], workspace, centered_out);
+      Tensor<xpu, 1, DType> std_data_tensor = std_.FlatTo1D<xpu, DType>(s);
+      std_data_tensor = F<mshadow_op::square_root>(std_data_tensor / 
+                        + scalar<DType>(param.eps));
+    });
+  });
+  // Calculate data = data / std
+  BinaryBroadcastCompute<xpu, mshadow_op::div>(attrs, ctx,
+                                               {output, std_grp},
+                                               {kWriteTo}, {output});
+  mxnet::TShape new_param_shape(data_shape.ndim() + 1, 1);
+  new_param_shape[1] = num_groups;
+  const TBlob& gamma = inputs[groupnorm::kGamma].reshape(new_param_shape);
+  const TBlob& beta = inputs[groupnorm::kBeta].reshape(new_param_shape);
+  // Calculate data = data * gamma
+  BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                   {output, gamma},
+                                                   {kWriteTo}, {output});
+  // Calculate data = data + beta
+  BinaryBroadcastCompute<xpu, op::mshadow_op::plus>(attrs, ctx,
+                                                   {output, beta},
+                                                   {kWriteTo}, {output});
+Calculate the gradient of group normalization.
+We have the following gradient for gamma, beta and x:
+\bar{x} = (x - mean) / std
w = og * gamma / std
grad_gamma = sum(\bar{x} * og, exclude_axis)
+grad_beta = sum(og, exclude_axis)
+grad_x = w - mean(w, axis) - \bar{x} * mean(w * \bar{x}, axis)
+template<typename xpu>
+void GroupNormGradCompute(const nnvm::NodeAttrs& attrs,
+                          const OpContext& ctx,
+                          const std::vector<TBlob>& inputs,
+                          const std::vector<OpReqType>& req,
+                          const std::vector<TBlob>& outputs) {
+  using namespace mshadow;
+  using namespace mshadow::expr;
+  using namespace mxnet_op;
+  CHECK_EQ(inputs.size(), 5U);
+  CHECK_EQ(outputs.size(), 3U);
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  const int num_groups = param.num_groups;
+  const TBlob& data = inputs[1];
+  const mxnet::TShape& dshape = data.shape_;
+  mxnet::TShape temp_dshape(dshape.ndim() + 1, 1);
+  temp_dshape[0] = dshape[0];
+  temp_dshape[1] = num_groups;
+  temp_dshape[2] = dshape[1] / num_groups;
+  for (int i = 2; i < dshape.ndim(); ++i) {
+    temp_dshape[i+1] = dshape[i];
+  }
+  const TBlob& data_ = data.reshape(temp_dshape);
+  const TBlob& ograd = inputs[0].reshape(temp_dshape);
+  Stream<xpu> *s = ctx.get_stream<xpu>();
+  // Reshape gamma to be broadcastable
+  mxnet::TShape new_param_shape(dshape.ndim() + 1, 1);
+  new_param_shape[1] = num_groups;
+  const TBlob& gamma = inputs[2].reshape(new_param_shape);
+  const TBlob& mean = inputs[3];
+  const TBlob& std = inputs[4];
+  mxnet::TShape moments_shape(temp_dshape.ndim(), 1);
+  for (int i = 0; i < dshape.ndim(); ++i) {
+    moments_shape[i] = (i < mean.shape_.ndim()) ? mean.shape_[i] : 1;
+  }
+  const TBlob& mean_ = mean.reshape(moments_shape);
+  const TBlob& std_ = std.reshape(moments_shape);
+  // Prepare the necessary shapes for reduction
+  mxnet::TShape red_src_shape, red_dst_shape, red_exclude_src_shape, 
+  BroadcastReduceShapeCompact(temp_dshape, mean_.shape_, &red_src_shape, 
+  BroadcastReduceShapeCompact(temp_dshape, gamma.shape_,
+                              &red_exclude_src_shape, &red_exclude_dst_shape);
+  int N = red_src_shape.Size() / red_dst_shape.Size();
+  // Initialize the workspace + Construct the temporary TBlobs
+  Tensor<xpu, 1, char> workspace;
+  size_t reduce_workspace_size = 0;
+  size_t data_size = 0;
+  size_t red_out_size = 0;
+  MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+    data_size = sizeof(DType) * data.Size();
+    red_out_size = sizeof(DType) * mean.Size();
+    // There are two types of reduction workloads: reduce over axis and reduce
+    // exclude axis.
+    // We take the maximum of the workspace sizes required by these workloads.
+    // Also, we explicitly set the req_type=kAddTo in case we want to use it.
+    BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
+                                                             kAddTo, 
+    });
+    BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+      reduce_workspace_size =
+        std::max(reduce_workspace_size,
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, 
red_exclude_dst_shape, kAddTo,
+    });
+  });
+  workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
+    Shape1(reduce_workspace_size + data_size * 2 + red_out_size), s);
+  const TBlob normalized_data =
+    TBlob(workspace.dptr_ + reduce_workspace_size,
+          data_.shape_, data.dev_mask(), data.type_flag_, data.dev_id());
+  const TBlob ograd_mult = TBlob(workspace.dptr_ + reduce_workspace_size + 
+                                 data_.shape_, ograd.dev_mask(), 
ograd.type_flag_, ograd.dev_id());
+  const TBlob red_out = TBlob(workspace.dptr_ + reduce_workspace_size + 
data_size * 2,
+                              mean_.shape_, mean.dev_mask(), mean.type_flag_, 
+  // Compute normalized_data = (data - mean) / std
+  BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                    {data_, mean_},
+                                                    {kWriteTo}, 
+  BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                   {normalized_data, std_},
+                                                   {kWriteTo}, 
+  // Calculate grad_beta
+  if (req[2] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[2].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<red::sum, NDim, DType, op::mshadow_op::identity, 
+          s, outputs[2].reshape(red_exclude_dst_shape), req[2], workspace,
+          ograd.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+  // Calculate grad_gamma, it will be sum(ograd * normalized_data, exclude_axis)
+  ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, 
{normalized_data, ograd},
+                                                      {kWriteTo}, 
+  if (req[1] != kNullOp) {
+    MSHADOW_REAL_TYPE_SWITCH(outputs[1].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, 
op::mshadow_op::identity, true>(
+          s, outputs[1].reshape(red_exclude_dst_shape), req[1], workspace,
+          ograd_mult.reshape(red_exclude_src_shape));
+      });
+    });
+  }
+  // Calculate grad_data:
+  //   ograd_mult = ograd * gamma / std
+  //   grad_data = ograd_mult - mean(ograd_mult, axis)
+  //               + normalized_data * (-mean(normalized_data * ograd_mult, axis))
+  if (req[0] != kNullOp) {
+    const TBlob output_ = outputs[0].reshape(data_.shape_);
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                    {ograd, gamma},
+                                                    {kWriteTo}, {ograd_mult});
+    BinaryBroadcastCompute<xpu, op::mshadow_op::div>(attrs, ctx,
+                                                    {ograd_mult, std_},
+                                                    {kWriteTo}, {ograd_mult});
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, 
op::mshadow_op::identity, true>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(N);
+    });
+    BinaryBroadcastCompute<xpu, op::mshadow_op::minus>(attrs, ctx,
+                                                      {ograd_mult, red_out},
+                                                      {req[0]}, {output_});
+    ElemwiseBinaryOp::Compute<xpu, op::mshadow_op::mul>(attrs, ctx, 
{ograd_mult, normalized_data},
+                                                        {kWriteTo}, 
+    MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
+      BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
+        broadcast::Reduce<mshadow_op::sum, NDim, DType, 
op::mshadow_op::identity, true>(
+          s, red_out.reshape(red_dst_shape), kWriteTo, workspace,
+          ograd_mult.reshape(red_src_shape));
+      });
+      Tensor<xpu, 1, DType> red_out_tensor = red_out.FlatTo1D<xpu, DType>(s);
+      red_out_tensor /= scalar<DType>(-N);
+    });
+    BinaryBroadcastCompute<xpu, op::mshadow_op::mul>(attrs, ctx,
+                                                     {normalized_data, 
+                                                     {kAddTo}, {output_});
+  }
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/nn/ b/src/operator/nn/
new file mode 100644
index 0000000..b4698ab
--- /dev/null
+++ b/src/operator/nn/
@@ -0,0 +1,131 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ * Copyright (c) 2019 by Contributors
+ * \file
+ * \brief Implements Group Normalization (
+#include "group_norm-inl.h"
+#include <nnvm/op_attr_types.h>
+#include "../elemwise_op_common.h"
+namespace mxnet {
+namespace op {
+static bool GroupNormShape(const nnvm::NodeAttrs& attrs,
+                           mxnet::ShapeVector *in_shape,
+                           mxnet::ShapeVector *out_shape) {
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  using namespace mshadow;
+  CHECK_EQ(in_shape->size(), 3U) << "Input:[data, gamma, beta]";
+  const mxnet::TShape &dshape = in_shape->at(groupnorm::kData);
+  CHECK_GE(dshape.ndim(), 3U);
+  const int num_groups = param.num_groups;
+  CHECK_EQ(dshape[1] % num_groups, 0) << "# of channels must be divisible by # 
of groups";
+  if (!mxnet::ndim_is_known(dshape)) {
+    return false;
+  }
+  in_shape->at(groupnorm::kGamma) = mxnet::TShape(Shape1(num_groups));
+  in_shape->at(groupnorm::kBeta) = mxnet::TShape(Shape1(num_groups));
+  out_shape->clear();
+  out_shape->push_back(dshape);
+  mxnet::TShape moments_shape(2, 1);
+  moments_shape[0] = dshape[0];
+  moments_shape[1] = num_groups;
+  out_shape->push_back(moments_shape);
+  out_shape->push_back(moments_shape);
+  return true;
+.describe(R"code(Group normalization.
+The input channels are separated into ``num_groups`` groups, each containing 
``num_channels / num_groups`` channels.
+The mean and standard-deviation are calculated separately over the each group.
+.. math::
+  data = data.reshape((N, num_groups, C // num_groups, ...))
+  out = \frac{data - mean(data, axis)}{\sqrt{var(data, axis) + \epsilon}} * 
gamma + beta
+Both ``gamma`` and ``beta`` are learnable parameters.
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"data", "gamma", "beta"};
+    [](const NodeAttrs& attrs) {
+  return std::vector<std::string>{"output", "mean", "std"};
+    [](const NodeAttrs& attrs) {
+  const GroupNormParam& param = nnvm::get<GroupNormParam>(attrs.parsed);
+  return param.output_mean_var ? 3 : 1;
+.set_attr<mxnet::FInferShape>("FInferShape", GroupNormShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<3, 3>)
+.set_attr<FCompute>("FCompute<cpu>", GroupNormCompute<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", [](const nnvm::NodePtr& n,
+                                           const std::vector<nnvm::NodeEntry>& 
ograds) {
+  std::vector<nnvm::NodeEntry> heads;
+  heads.push_back(ograds[0]);  // ograd
+  heads.push_back(n->inputs[0]);  // data
+  heads.push_back(n->inputs[1]);  // gamma
+  heads.emplace_back(nnvm::NodeEntry{n, 1, 0});  // mean
+  heads.emplace_back(nnvm::NodeEntry{ n, 2, 0 });  // std
+  return MakeGradNode("_backward_GroupNorm", n, heads, n->attrs.dict);
+  [](const NodeAttrs& attrs) {
+  return std::vector<std::pair<int, int> >{{0, 0}};
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+.add_argument("data", "NDArray-or-Symbol", "Input data")
+.add_argument("gamma", "NDArray-or-Symbol", "gamma array")
+.add_argument("beta", "NDArray-or-Symbol", "beta array")
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute<cpu>", GroupNormGradCompute<cpu>)
+.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
+  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/nn/ b/src/operator/nn/
new file mode 100644
index 0000000..136c333
--- /dev/null
+++ b/src/operator/nn/
@@ -0,0 +1,37 @@
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+ * Copyright (c) 2019 by Contributors
+ * \file
+ * \brief Implements Group Normalization (
+#include "./group_norm-inl.h"
+namespace mxnet {
+namespace op {
+.set_attr<FCompute>("FCompute<gpu>", GroupNormCompute<gpu>);
+.set_attr<FCompute>("FCompute<gpu>", GroupNormGradCompute<gpu>);
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/ 
index d52e7f8..b59ce2d 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -744,6 +744,15 @@ def test_layernorm():
+def test_groupnorm():
+    layer = nn.GroupNorm()
+    check_layer_forward(layer, (2, 10, 10, 10))
+    layer = nn.GroupNorm(num_groups=2)
+    check_layer_forward(layer, (2, 10, 10, 10))
+    layer = nn.GroupNorm(num_groups=5)
+    check_layer_forward(layer, (2, 10, 10, 10))
 def test_reflectionpad():
     layer = nn.ReflectionPad2D(3)
     check_layer_forward(layer, (2, 3, 24, 24))
diff --git a/tests/python/unittest/ 
index aeddc7a..749f0f2 100644
--- a/tests/python/unittest/
+++ b/tests/python/unittest/
@@ -1831,6 +1831,97 @@ def test_batchnorm():
+def test_groupnorm():
+    acc_types = {'float16': 'float32', 'float32': 'float64', 'float64': 
+    def x_hat_helper(x, num_groups, eps):
+        dtype = x.dtype
+        dshape = x.shape
+        assert len(dshape) == 4
+        acc_type = acc_types[str(dtype)]
+        new_shape = (dshape[0], num_groups, int(dshape[1] / num_groups), 
dshape[2], dshape[3])
+        new_moments_shape = (dshape[0], num_groups, 1, 1, 1)
+        data = x.reshape(new_shape)
+        mean = np.mean(data, axis=(2, 3, 4), keepdims=False, 
+        std = np.sqrt(np.var(data, axis=(2, 3, 4), dtype=acc_type, 
keepdims=False).astype(dtype) + eps)
+        x_hat = (data - mean.reshape(new_moments_shape)) / 
+        return x_hat, mean, std
+    def np_groupnorm(data, gamma, beta, num_groups, eps):
+        new_param_shape = (1, num_groups, 1, 1, 1)
+        x_hat, mean, std = x_hat_helper(data, num_groups, eps)
+        out = x_hat * gamma.reshape(new_param_shape) + 
+        return out.reshape(dshape), mean, std
+    def np_groupnorm_grad(ograd, data, gamma, beta, mean, std, num_groups, 
+        x_hat, mean, std = x_hat_helper(data, num_groups, eps)
+        new_shape = x_hat.shape
+        dshape = data.shape
+        dtype = data.dtype
+        new_moments_shape = (new_shape[0], num_groups, 1, 1, 1)
+        new_param_shape = (1, num_groups, 1, 1, 1)
+        acc_type = acc_types[str(dtype)]
+        ograd = ograd.reshape(new_shape)
+        data = data.reshape(new_shape)
+        gamma = gamma.reshape(new_param_shape)
+        beta = beta.reshape(new_param_shape)
+        mean = mean.reshape(new_moments_shape)
+        std = std.reshape(new_moments_shape)
+        beta_grad = np.sum(ograd, axis=(0, 2, 3, 4), dtype=acc_type, 
+        gamma_grad = np.sum(x_hat * ograd, axis=(0, 2, 3, 4), dtype=acc_type, 
+        x_hat_grad = ograd * gamma
+        ograd_mult = x_hat_grad / std
+        red_out = np.mean(ograd_mult, axis=(2, 3, 4), dtype=acc_type, 
+        data_grad = ograd_mult - red_out
+        red_out = np.mean(ograd_mult * x_hat, axis=(2, 3, 4), dtype=acc_type, 
+        data_grad = data_grad - x_hat * red_out
+        return data_grad.reshape(dshape), gamma_grad, beta_grad
+    batch_size = random.randint(1, 8)
+    num_groups = random.randint(2, 3)
+    num_channels = random.randint(2, 3) * num_groups
+    height = random.randint(1, 5)
+    width = random.randint(1, 5)
+    dshape = (batch_size, num_channels, height, width)
+    param_shape = (num_groups,)
+    temp_shape = (batch_size, num_groups, int(num_channels / num_groups), 
height, width)
+    np_data = np.random.uniform(0.2, 1.0, dshape)
+    np_gamma = np.random.uniform(-1.0, 1.0, param_shape)
+    np_beta = np.random.uniform(-1.0, 1.0, param_shape)
+    data_sym = mx.sym.Variable("data")
+    gamma_sym = mx.sym.Variable("gamma")
+    beta_sym = mx.sym.Variable("beta")
+    for dtype in [np.float16, np.float32, np.float64]:
+        eps = 1e-2 if dtype == np.float16 else 1e-5
+        mx_data = mx.nd.array(np_data, dtype=dtype)
+        mx_gamma = mx.nd.array(np_gamma, dtype=dtype)
+        mx_beta = mx.nd.array(np_beta, dtype=dtype)
+        np_out, np_mean, np_std = np_groupnorm(np_data.astype(dtype),
+                                               np_gamma.astype(dtype),
+                                               np_beta.astype(dtype),
+                                               num_groups=num_groups,
+                                               eps=eps)
+        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, 
+                                  num_groups=num_groups, eps=eps, 
+        check_symbolic_forward(mx_sym, [mx_data, mx_gamma, mx_beta], [np_out, 
np_mean, np_std],
+                               rtol=1e-2 if dtype == np.float16 else 1e-3,
+                               atol=5e-3 if dtype == np.float16 else 1e-5, 
+        mx_sym = mx.sym.GroupNorm(data=data_sym, gamma=gamma_sym, 
+                                  num_groups=num_groups, eps=eps, 
+        np_ograd = np.random.uniform(-1.0, 1.0, dshape).astype(dtype)
+        np_data_grad, np_gamma_grad, np_beta_grad = np_groupnorm_grad(np_ograd,
+                                                                      np_mean, 
num_groups, eps)
+        check_symbolic_backward(mx_sym, [mx_data, mx_gamma, mx_beta], 
+                                [np_data_grad, np_gamma_grad, np_beta_grad],
+                                rtol=1e-2 if dtype == np.float16 else 1e-3,
+                                atol=5e-2 if dtype == np.float16 else 1e-5, 
 def test_convolution_grouping():
     for dim in [1, 2, 3]:
         num_filter = 4

Reply via email to