This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new baedf7f04d [TOPI] Group normalization (#14193)
baedf7f04d is described below

commit baedf7f04dd97ed3f2de073c4d11649ec95283e5
Author: Ruihang Lai <[email protected]>
AuthorDate: Sat Mar 4 19:22:11 2023 -0500

    [TOPI] Group normalization (#14193)
    
    As more and more ML models contain the group normalization computation,
    it is beneficial to introduce this op at the TOPI level, which enables
    us to optimize the group normalization operation as a whole.
    
    This PR introduces the group normalization op to TOPI. The group norm
    operation was introduced in https://arxiv.org/abs/1803.08494. The
    implementation uses tuple reduction, the same approach as layer norm.
    With tuple reduction, the generated TIR function can be optimized by
    cross-thread reduction or rfactor through MetaSchedule.
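
    For reference, a minimal NumPy sketch of the computation (illustrative
    only; the names below are assumptions, not part of this patch):

        import numpy as np

        def group_norm_ref(x, num_groups, epsilon=1e-5):
            # x: (N, C, ...); split C into (G, C // G) and normalize over
            # (C // G, spatial...) within each group of each sample.
            n, c = x.shape[0], x.shape[1]
            g = x.reshape(n, num_groups, c // num_groups, *x.shape[2:])
            axes = tuple(range(2, g.ndim))
            mean = g.mean(axis=axes, keepdims=True)
            var = g.var(axis=axes, keepdims=True)
            return ((g - mean) / np.sqrt(var + epsilon)).reshape(x.shape)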
    
    
    Co-authored-by: Bohan Hou <[email protected]>
---
 include/tvm/topi/nn/group_norm.h                   | 151 +++++++++++++++++++++
 python/tvm/topi/nn/__init__.py                     |   1 +
 python/tvm/topi/nn/group_norm.py                   |  52 +++++++
 python/tvm/topi/testing/__init__.py                |   1 +
 python/tvm/topi/testing/group_norm_python.py       |  82 +++++++++++
 src/topi/nn.cc                                     |   7 +
 ..._topi_layer_norm.py => test_topi_group_norm.py} |  30 ++--
 tests/python/topi/python/test_topi_layer_norm.py   |   2 +-
 8 files changed, 312 insertions(+), 14 deletions(-)

diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h
new file mode 100644
index 0000000000..43760bab1f
--- /dev/null
+++ b/include/tvm/topi/nn/group_norm.h
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief group normalization op constructions
+ * \file nn/group_norm.h
+ */
+#ifndef TVM_TOPI_NN_GROUP_NORM_H_
+#define TVM_TOPI_NN_GROUP_NORM_H_
+
+#include <tvm/te/operation.h>
+#include <tvm/topi/tags.h>
+
+#include <algorithm>
+#include <string>
+#include <vector>
+
+namespace tvm {
+namespace topi {
+namespace nn {
+
+using namespace tvm::te;
+
+inline Tensor group_norm(const Tensor& data, const Tensor& gamma, const Tensor& beta,
+                         int num_groups, int channel_axis, const Array<Integer>& axes,
+                         double epsilon, std::string name = "T_group_norm",
+                         std::string tag = kInjective) {
+  // reshape data C -> G, C/G
+  int ndim = data->shape.size();
+  channel_axis = GetRealAxis(ndim, {channel_axis})[0];
+
+  auto shape = data->shape;
+  auto group_size = floordiv(shape[channel_axis], num_groups);
+  auto new_shape = Array<PrimExpr>();
+  for (int i = 0; i < ndim; ++i) {
+    if (i == channel_axis) {
+      new_shape.push_back(num_groups);
+      new_shape.push_back(group_size);
+    } else {
+      new_shape.push_back(shape[i]);
+    }
+  }
+  auto data_reshaped = reshape(data, new_shape);
+  // reshape gamma and beta, C -> G, C/G
+  Tensor gamma_reshaped;
+  if (gamma.defined()) {
+    gamma_reshaped = reshape(gamma, {num_groups, group_size});
+  }
+  Tensor beta_reshaped;
+  if (beta.defined()) {
+    beta_reshaped = reshape(beta, {num_groups, group_size});
+  }
+
+  // get the new axes to normalize after reshape
+  std::vector<int> new_axes{channel_axis + 1};
+  for (auto axis : axes) {
+    int new_axis = GetRealAxis(ndim, {axis})[0];
+    if (new_axis < channel_axis) {
+      new_axes.push_back(new_axis);
+    } else if (new_axis > channel_axis) {
+      new_axes.push_back(new_axis + 1);
+    } else {
+      ICHECK(false) << "axes cannot contain the channel axis";
+    }
+  }
+  std::sort(new_axes.begin(), new_axes.end());
+
+  // sum x and x^2
+  ndim = data_reshaped->shape.size();
+  auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped);
+  auto target_shape =
+      MakeReduceTargetShape(new_axes, data_reshaped, /*keepdims=*/false, /*atleast1d=*/true);
+  auto func = MakeTupleSumReducer();
+
+  auto compute = [ndim, &new_axes, &reduce_axes, &func, &data_reshaped](const Array<Var>& indices) {
+    Array<PrimExpr> eval_range;
+    int arg_counter = 0;
+    int red_counter = 0;
+
+    for (int i = 0; i < ndim; ++i) {
+      if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
+        // new_axes contains i
+        eval_range.push_back(reduce_axes[red_counter]);
+        red_counter++;
+      } else {
+        eval_range.push_back(indices[arg_counter]);
+        arg_counter++;
+      }
+    }
+    auto square = [](const PrimExpr& x) { return x * x; };
+    return func({data_reshaped(eval_range), square(data_reshaped(eval_range))}, reduce_axes,
+                nullptr);
+  };
+
+  auto temp_x_x2 =
+      tvm::te::compute(target_shape, compute, data->op->name + "_red_temp", kCommReduce);
+
+  auto temp_x = temp_x_x2[0];
+  auto temp_x2 = temp_x_x2[1];
+  auto reduce_extent = make_const(data->dtype, 1);
+  for (auto axis : new_axes) {
+    reduce_extent *= data_reshaped->shape[axis];
+  }
+  auto group_norm_func = [&](const Array<Var>& indices) {
+    Array<Var> reduce_indices, non_reduce_indices, gamma_indices;
+    for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
+      if (std::find(new_axes.begin(), new_axes.end(), i) != new_axes.end()) {
+        reduce_indices.push_back(indices[i]);
+      } else {
+        non_reduce_indices.push_back(indices[i]);
+      }
+    }
+    gamma_indices = {indices[channel_axis], indices[channel_axis + 1]};
+    auto mean = temp_x(non_reduce_indices) / reduce_extent;
+    auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
+    auto group_norm =
+        (data_reshaped(indices) - mean) * tvm::rsqrt(var + make_const(data->dtype, epsilon));
+    if (gamma.defined()) {
+      group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices));
+    }
+    if (beta.defined()) {
+      group_norm = topi::add(group_norm, beta_reshaped(gamma_indices));
+    }
+    return group_norm;
+  };
+  auto group_norm_out = tvm::te::compute(data_reshaped->shape, group_norm_func, name, tag);
+  auto group_norm_out_reshaped = reshape(group_norm_out, shape);
+  return group_norm_out_reshaped;
+}
+
+}  // namespace nn
+}  // namespace topi
+}  // namespace tvm
+
+#endif  // TVM_TOPI_NN_GROUP_NORM_H_
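
The header relies on tuple reduction: a single pass computes both sum(x) and
sum(x^2), which MetaSchedule can later rewrite with cross-thread reduction or
rfactor. A minimal sketch of the same pattern on the Python side, using
te.comm_reducer (the names here are illustrative; the C++ code builds the
reducer with MakeTupleSumReducer):

    import tvm
    from tvm import te

    # Tuple-sum reducer: accumulate sum(x) and sum(x*x) in one pass.
    tuple_sum = te.comm_reducer(
        lambda x, y: (x[0] + y[0], x[1] + y[1]),
        lambda t0, t1: (tvm.tir.const(0, t0), tvm.tir.const(0, t1)),
        name="tuple_sum",
    )

    n, m = te.var("n"), te.var("m")
    data = te.placeholder((n, m), name="data")
    k = te.reduce_axis((0, m), name="k")
    # T_x[i] = sum_k data[i, k]; T_x2[i] = sum_k data[i, k]^2
    T_x, T_x2 = te.compute(
        (n,),
        lambda i: tuple_sum((data[i, k], data[i, k] * data[i, k]), axis=k),
        name="sums",
    )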
diff --git a/python/tvm/topi/nn/__init__.py b/python/tvm/topi/nn/__init__.py
index 8f081242fa..80a21e6531 100644
--- a/python/tvm/topi/nn/__init__.py
+++ b/python/tvm/topi/nn/__init__.py
@@ -39,6 +39,7 @@ from .bnn import *
 from .qnn import *
 from .upsampling import *
 from .layer_norm import layer_norm
+from .group_norm import group_norm
 from .local_response_norm import *
 from .bitserial_conv2d import *
 from .bitserial_dense import *
diff --git a/python/tvm/topi/nn/group_norm.py b/python/tvm/topi/nn/group_norm.py
new file mode 100644
index 0000000000..c6358b8bc6
--- /dev/null
+++ b/python/tvm/topi/nn/group_norm.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Layer normalization operator."""
+from .. import cpp
+
+
+def group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5):
+    """Group normalization operator.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    gamma: tvm.te.Tensor
+        1-D with shape (r_0) where r_0 == d_{channel_axis}
+
+    beta: tvm.te.Tensor
+        Optional, 1-D with shape (r_0) where r_0 == d_{channel_axis}
+
+    num_groups : int
+        The number of groups
+
+    channel_axis : int
+        The channel axis
+
+    axes : list of int
+        The axes along which normalization is applied, excluding the channel axis
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : tvm.te.Tensor
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    return cpp.nn.group_norm(data, gamma, beta, num_groups, channel_axis, axes, epsilon)
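
A minimal usage sketch of this wrapper (the shapes and the naive schedule are
illustrative, mirroring the test added below):

    import tvm
    from tvm import te, topi

    data = te.placeholder((2, 4, 16), dtype="float32", name="data")
    gamma = te.placeholder((4,), dtype="float32", name="gamma")
    beta = te.placeholder((4,), dtype="float32", name="beta")
    # 2 groups over channel axis 1; also normalize over axis 2.
    out = topi.nn.group_norm(data, gamma, beta, 2, 1, [2], 1e-5)
    s = te.create_schedule(out.op)
    f = tvm.build(s, [data, gamma, beta, out], "llvm")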
diff --git a/python/tvm/topi/testing/__init__.py b/python/tvm/topi/testing/__init__.py
index 2922c30b50..ef48090583 100644
--- a/python/tvm/topi/testing/__init__.py
+++ b/python/tvm/topi/testing/__init__.py
@@ -44,6 +44,7 @@ from .reorg_python import reorg_python
 from .roi_align_python import roi_align_nchw_python, roi_align_nhwc_python
 from .roi_pool_python import roi_pool_nchw_python
 from .layer_norm_python import layer_norm_python
+from .group_norm_python import group_norm_python
 from .lrn_python import lrn_python
 from .l2_normalize_python import l2_normalize_python
 from .gather_python import gather_python
diff --git a/python/tvm/topi/testing/group_norm_python.py b/python/tvm/topi/testing/group_norm_python.py
new file mode 100644
index 0000000000..d1c0d4a6ab
--- /dev/null
+++ b/python/tvm/topi/testing/group_norm_python.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Group normalization in python"""
+import numpy as np
+
+
+def group_norm_python(data, gamma, beta, num_groups, channel_axis, axes, epsilon=1e-5):
+    """Group normalization operator, implemented with NumPy for testing.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+
+    gamma : numpy.ndarray
+        1-D with shape (r_0) where r_0 == d_{channel_axis}
+
+    beta : numpy.ndarray
+        Optional, 1-D with shape (r_0) where r_0 == d_{channel_axis}
+
+    num_groups : int
+        The number of groups
+
+    channel_axis : int
+        The channel axis
+
+    axes : list of int
+        The axes along which normalization is applied, excluding the channel axis
+
+    epsilon : float
+        The epsilon value to avoid division by zero.
+
+    Returns
+    -------
+    result : numpy.ndarray
+        N-D with shape (d_0, d_1, ..., d_{N-1})
+    """
+    old_shape = data.shape
+    new_shape = list(old_shape)
+    new_shape[channel_axis] = data.shape[channel_axis] // num_groups
+    new_shape.insert(channel_axis, num_groups)
+    data = np.reshape(data, new_shape)
+    new_axes = [channel_axis + 1]
+    for axis in axes:
+        if axis < channel_axis:
+            new_axes.append(axis)
+        else:
+            new_axes.append(axis + 1)
+    mean = np.mean(data, axis=tuple(new_axes), keepdims=True)
+    var = np.var(data, axis=tuple(new_axes), keepdims=True)
+    data = (data - mean) / np.sqrt(var + epsilon)
+    data = np.reshape(data, old_shape)
+
+    gamma_broadcast_shape = [1 for _ in range(len(old_shape))]
+    gamma_broadcast_shape[channel_axis] = gamma.shape[0]
+    gamma = np.reshape(gamma, gamma_broadcast_shape)
+
+    # beta is optional, so guard it before touching its shape.
+    if beta is not None:
+        beta_broadcast_shape = [1 for _ in range(len(old_shape))]
+        beta_broadcast_shape[channel_axis] = beta.shape[0]
+        beta = np.reshape(beta, beta_broadcast_shape)
+
+    data *= gamma
+    if beta is not None:
+        data += beta
+
+    return data
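
For example, the reference can be exercised directly on NumPy inputs (the
shapes and values below are illustrative):

    import numpy as np
    from tvm.topi.testing import group_norm_python

    x = np.random.uniform(size=(2, 4, 16)).astype("float32")
    gamma = np.ones(4, dtype="float32")
    beta = np.zeros(4, dtype="float32")
    # 2 groups over channel axis 1, also normalizing over axis 2.
    y = group_norm_python(x, gamma, beta, 2, 1, (2,), 1e-5)
    assert y.shape == x.shape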
diff --git a/src/topi/nn.cc b/src/topi/nn.cc
index 35dbf3a03e..3b2c11010f 100644
--- a/src/topi/nn.cc
+++ b/src/topi/nn.cc
@@ -29,6 +29,7 @@
 #include <tvm/topi/nn/dense.h>
 #include <tvm/topi/nn/dilate.h>
 #include <tvm/topi/nn/flatten.h>
+#include <tvm/topi/nn/group_norm.h>
 #include <tvm/topi/nn/layer_norm.h>
 #include <tvm/topi/nn/local_response_norm.h>
 #include <tvm/topi/nn/mapping.h>
@@ -163,5 +164,11 @@ TVM_REGISTER_GLOBAL("topi.nn.layer_norm").set_body([](TVMArgs args, TVMRetValue*
   *rv = nn::layer_norm(args[0], args[1], args[2], args[3], static_cast<double>(args[4]));
 });
 
+/* Ops from nn/group_norm.h */
+TVM_REGISTER_GLOBAL("topi.nn.group_norm").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = nn::group_norm(args[0], args[1], args[2], static_cast<int>(args[3]),
+                       static_cast<int>(args[4]), args[5], static_cast<double>(args[6]));
+});
+
 }  // namespace topi
 }  // namespace tvm
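
Once registered, the op is also reachable directly through the
packed-function FFI; a sketch, assuming the usual Python-list-to-Array
conversion of packed-function arguments:

    import tvm
    from tvm import te

    f = tvm.get_global_func("topi.nn.group_norm")
    data = te.placeholder((2, 4, 16), dtype="float32", name="data")
    gamma = te.placeholder((4,), dtype="float32", name="gamma")
    beta = te.placeholder((4,), dtype="float32", name="beta")
    # args: data, gamma, beta, num_groups, channel_axis, axes, epsilon
    out = f(data, gamma, beta, 2, 1, [2], 1e-5)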
diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_group_norm.py
similarity index 62%
copy from tests/python/topi/python/test_topi_layer_norm.py
copy to tests/python/topi/python/test_topi_group_norm.py
index ead05470be..f094423916 100644
--- a/tests/python/topi/python/test_topi_layer_norm.py
+++ b/tests/python/topi/python/test_topi_group_norm.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Test code for layer_norm."""
+"""Test code for group_norm."""
 import numpy as np
 import pytest
 import tvm
@@ -26,28 +26,32 @@ import tvm.topi.testing
 import tvm.testing
 
 
-_layer_norm_schedule = {
+_group_norm_schedule = {
     "generic": topi.generic.schedule_injective,
 }
 
 
 # only test on llvm because schedule is missing
 @tvm.testing.parametrize_targets("llvm")
[email protected]("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1, 2))])
-def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5):
[email protected]("shape, axis", [([2, 4, 16], (2,)), ([2, 4, 4, 16], (2, 3))])
+def test_group_norm(target, dev, shape, axis, epsilon=1e-5, dtype="float32", rtol=1e-5, atol=1e-5):
     data = te.placeholder(shape, dtype=dtype, name="data")
-    scale_shape = [shape[dim] for dim in axis]
-    gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")
-    beta = te.placeholder(scale_shape, dtype=dtype, name="beta")
-    B = topi.nn.layer_norm(data, gamma, beta, axis, episilon)
+    num_groups = 2
+    channel_axis = 1
+    gamma = te.placeholder((shape[channel_axis],), dtype=dtype, name="gamma")
+    beta = te.placeholder((shape[channel_axis],), dtype=dtype, name="beta")
+    B = topi.nn.group_norm(data, gamma, beta, num_groups, channel_axis, axis, epsilon)
 
+    np.random.seed(0)
     data_np = np.random.uniform(size=shape).astype(dtype)
-    gamma_np = np.random.uniform(size=scale_shape).astype(dtype)
-    beta_np = np.random.uniform(size=scale_shape).astype(dtype)
-    b_np = tvm.topi.testing.layer_norm_python(data_np, gamma_np, beta_np, axis, episilon)
+    gamma_np = np.random.uniform(size=(shape[channel_axis],)).astype(dtype)
+    beta_np = np.random.uniform(size=(shape[channel_axis],)).astype(dtype)
+    b_np = tvm.topi.testing.group_norm_python(
+        data_np, gamma_np, beta_np, num_groups, channel_axis, axis, epsilon
+    )
 
     with tvm.target.Target(target):
-        s_func = tvm.topi.testing.dispatch(target, _layer_norm_schedule)
+        s_func = tvm.topi.testing.dispatch(target, _group_norm_schedule)
         s = s_func([B])
     data_tvm = tvm.nd.array(data_np, dev)
     gamma_tvm = tvm.nd.array(gamma_np, dev)
@@ -55,7 +59,7 @@ def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rt
     b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
     f = tvm.build(s, [data, gamma, beta, B], target)
     f(data_tvm, gamma_tvm, beta_tvm, b_tvm)
-    tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol)
+    tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/topi/python/test_topi_layer_norm.py b/tests/python/topi/python/test_topi_layer_norm.py
index ead05470be..f875bb09e2 100644
--- a/tests/python/topi/python/test_topi_layer_norm.py
+++ b/tests/python/topi/python/test_topi_layer_norm.py
@@ -55,7 +55,7 @@ def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32", rt
     b_tvm = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
     f = tvm.build(s, [data, gamma, beta, B], target)
     f(data_tvm, gamma_tvm, beta_tvm, b_tvm)
-    tvm.testing.assert_allclose(b_tvm.asnumpy(), b_np, rtol=rtol, atol=atol)
+    tvm.testing.assert_allclose(b_tvm.numpy(), b_np, rtol=rtol, atol=atol)
 
 
 if __name__ == "__main__":
