This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 86ec016650 [Unity][TOPI] fp16 LayerNorm & GroupNorm (#14264)
86ec016650 is described below
commit 86ec0166509efc60cdf916dd24cf3908fb239df8
Author: Bohan Hou <[email protected]>
AuthorDate: Sat Mar 18 20:47:16 2023 -0400
[Unity][TOPI] fp16 LayerNorm & GroupNorm (#14264)
This PR modifies the topi implementation (which is also the legalizer's
backend of Relax) of LayerNorm and GroupNorm operators to allow them to accept
fp16 inputs, cast to fp32 internally, and produce fp16 outputs.
This can help eliminate unnecessary casts caused by AMP.
---
include/tvm/topi/nn/group_norm.h | 31 +++--
include/tvm/topi/nn/layer_norm.h | 28 +++-
python/tvm/topi/nn/group_norm.py | 2 +
python/tvm/topi/nn/layer_norm.py | 2 +
python/tvm/topi/testing/group_norm_python.py | 5 +-
python/tvm/topi/testing/layer_norm_python.py | 3 +
.../python/relax/test_transform_legalize_ops_nn.py | 144 +++++++++++++++++++++
tests/python/topi/python/test_topi_group_norm.py | 3 +-
tests/python/topi/python/test_topi_layer_norm.py | 3 +-
9 files changed, 206 insertions(+), 15 deletions(-)
diff --git a/include/tvm/topi/nn/group_norm.h b/include/tvm/topi/nn/group_norm.h
index 43760bab1f..5636de11ac 100644
--- a/include/tvm/topi/nn/group_norm.h
+++ b/include/tvm/topi/nn/group_norm.h
@@ -25,7 +25,6 @@
#define TVM_TOPI_NN_GROUP_NORM_H_
#include <tvm/te/operation.h>
-#include <tvm/topi/tags.h>
#include <algorithm>
#include <string>
@@ -41,9 +40,17 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
int num_groups, int channel_axis, const
Array<Integer>& axes,
double epsilon, std::string name = "T_group_norm",
std::string tag = kInjective) {
+ const auto& data_type = data->dtype;
+ const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type;
+ const auto& beta_type = beta.defined() ? beta->dtype : data_type;
+ ICHECK(data_type == gamma_type && data_type == beta_type)
+ << "group_norm: data, gamma and beta must have the same type";
+ ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
+ << "group_norm: only support float32 and float16 for now";
+ bool is_float16 = data_type == DataType::Float(16);
// reshape data C -> G, C/G
int ndim = data->shape.size();
- channel_axis = GetRealAxis(ndim, {channel_axis})[0];
+ channel_axis = GetRealAxis(static_cast<int>(ndim), {channel_axis})[0];
auto shape = data->shape;
auto group_size = floordiv(shape[channel_axis], num_groups);
@@ -56,8 +63,13 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
new_shape.push_back(shape[i]);
}
}
- auto data_reshaped = reshape(data, new_shape);
- // reshape gamma and beta, C -> G, C/G
+ Tensor data_reshaped;
+ if (is_float16) {
+ data_reshaped = cast(reshape(data, new_shape), DataType::Float(32));
+ } else {
+ data_reshaped = reshape(data, new_shape);
+ }
+ // reshape gamma and beta, C -> G, C/G, cast to float32 if float16
Tensor gamma_reshaped;
if (gamma.defined()) {
gamma_reshaped = reshape(gamma, {num_groups, group_size});
@@ -70,7 +82,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
// get the new axes to normalize after reshape
std::vector<int> new_axes{channel_axis + 1};
for (auto axis : axes) {
- int new_axis = GetRealAxis(ndim, {axis})[0];
+ int new_axis = GetRealAxis(static_cast<int>(ndim), {axis})[0];
if (new_axis < channel_axis) {
new_axes.push_back(new_axis);
} else if (new_axis > channel_axis) {
@@ -81,7 +93,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
}
std::sort(new_axes.begin(), new_axes.end());
- // sum x and x^2
+ // sum x and x^2, cast to float32 if float16
ndim = data_reshaped->shape.size();
auto reduce_axes = MakeReduceAxes(new_axes, data_reshaped);
auto target_shape =
@@ -113,7 +125,7 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
auto temp_x = temp_x_x2[0];
auto temp_x2 = temp_x_x2[1];
- auto reduce_extent = make_const(data->dtype, 1);
+ auto reduce_extent = make_const(DataType::Float(32), 1);
for (auto axis : new_axes) {
reduce_extent *= data_reshaped->shape[axis];
}
@@ -129,8 +141,11 @@ inline Tensor group_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
gamma_indices = {indices[channel_axis], indices[channel_axis + 1]};
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
- auto group_norm =
+ PrimExpr group_norm =
(data_reshaped(indices) - mean) * tvm::rsqrt(var +
make_const(data->dtype, epsilon));
+ if (is_float16) {
+ group_norm = Cast(DataType::Float(16), group_norm);
+ }
if (gamma.defined()) {
group_norm = topi::multiply(group_norm, gamma_reshaped(gamma_indices));
}
diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index 93e5582ef1..ee0cba74dd 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -51,6 +51,14 @@ using namespace tvm::te;
inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const
Tensor& beta,
const Array<Integer>& axis, double epsilon,
std::string name = "T_layer_norm", std::string tag =
kInjective) {
+ const auto& data_type = data->dtype;
+ const auto& gamma_type = gamma.defined() ? gamma->dtype : data_type;
+ const auto& beta_type = beta.defined() ? beta->dtype : data_type;
+ ICHECK(data_type == gamma_type && data_type == beta_type)
+ << "layer_norm: data, gamma and beta must have the same type";
+ ICHECK(data_type == DataType::Float(32) || data_type == DataType::Float(16))
+ << "layer_norm: only support float32 and float16 for now";
+ bool is_float16 = data_type == DataType::Float(16);
// sum x and x^2
auto ndim = data->shape.size();
ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
@@ -60,7 +68,8 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
MakeReduceTargetShape(real_axis, data, /*keepdims=*/false,
/*atleast1d=*/true);
auto func = MakeTupleSumReducer();
- auto compute = [ndim, &real_axis, &reduce_axes, &func, &data](const
Array<Var>& indices) {
+ auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
+ &data](const Array<Var>& indices) {
Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;
@@ -75,8 +84,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
arg_counter++;
}
}
- auto square = [](const PrimExpr& x) { return x * x; };
- return func({data(eval_range), square(data(eval_range))}, reduce_axes,
nullptr);
+ auto square = [is_float16](const PrimExpr& x) {
+ if (is_float16) {
+ return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
+ }
+ return x * x;
+ };
+ if (is_float16) {
+ return func({Cast(DataType::Float(32), data(eval_range)),
square(data(eval_range))},
+ reduce_axes, nullptr);
+ } else {
+ return func({data(eval_range), square(data(eval_range))}, reduce_axes,
nullptr);
+ }
};
auto temp_x_x2 =
@@ -101,6 +120,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
auto mean = temp_x(non_reduce_indices) / reduce_extent;
auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var +
make_const(var->dtype, epsilon));
+ if (is_float16) {
+ layer_norm = Cast(DataType::Float(16), layer_norm);
+ }
layer_norm = topi::multiply(layer_norm, gamma(reduce_indices));
if (beta.defined()) {
layer_norm = topi::add(layer_norm, beta(reduce_indices));
diff --git a/python/tvm/topi/nn/group_norm.py b/python/tvm/topi/nn/group_norm.py
index c6358b8bc6..ea9d5da077 100644
--- a/python/tvm/topi/nn/group_norm.py
+++ b/python/tvm/topi/nn/group_norm.py
@@ -20,6 +20,8 @@ from .. import cpp
def group_norm(data, gamma, beta, num_groups, channel_axis, axes,
epsilon=1e-5):
"""Group normalization operator.
+ It accepts fp16 and fp32 as input data type. It will cast the input to fp32
+ to perform the computation. The output will have the same data type as
input.
Parameters
----------
diff --git a/python/tvm/topi/nn/layer_norm.py b/python/tvm/topi/nn/layer_norm.py
index 3bdeaaac61..7363f99c49 100644
--- a/python/tvm/topi/nn/layer_norm.py
+++ b/python/tvm/topi/nn/layer_norm.py
@@ -20,6 +20,8 @@ from .. import cpp
def layer_norm(data, gamma, beta, axis, epsilon=1e-5):
"""Layer normalization operator.
+ It accepts fp16 and fp32 as input data type. It will cast the input to fp32
+ to perform the computation. The output will have the same data type as
input.
Parameters
----------
diff --git a/python/tvm/topi/testing/group_norm_python.py
b/python/tvm/topi/testing/group_norm_python.py
index d1c0d4a6ab..7677348426 100644
--- a/python/tvm/topi/testing/group_norm_python.py
+++ b/python/tvm/topi/testing/group_norm_python.py
@@ -51,10 +51,11 @@ def group_norm_python(data, gamma, beta, num_groups,
channel_axis, axes, epsilon
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
old_shape = data.shape
+ old_dtype = data.dtype
new_shape = list(old_shape)
new_shape[channel_axis] = data.shape[channel_axis] // num_groups
new_shape.insert(channel_axis, num_groups)
- data = np.reshape(data, new_shape)
+ data = np.reshape(data, new_shape).astype("float32")
new_axes = [channel_axis + 1]
for axis in axes:
if axis < channel_axis:
@@ -64,7 +65,7 @@ def group_norm_python(data, gamma, beta, num_groups,
channel_axis, axes, epsilon
mean = np.mean(data, axis=tuple(new_axes), keepdims=True)
var = np.var(data, axis=tuple(new_axes), keepdims=True)
data = (data - mean) / np.sqrt(var + epsilon)
- data = np.reshape(data, old_shape)
+ data = np.reshape(data, old_shape).astype(old_dtype)
gamma_broadcast_shape = [1 for _ in range(len(old_shape))]
gamma_broadcast_shape[channel_axis] = gamma.shape[0]
diff --git a/python/tvm/topi/testing/layer_norm_python.py
b/python/tvm/topi/testing/layer_norm_python.py
index 6b3b001469..662383363b 100644
--- a/python/tvm/topi/testing/layer_norm_python.py
+++ b/python/tvm/topi/testing/layer_norm_python.py
@@ -44,9 +44,12 @@ def layer_norm_python(data, gamma, beta, axis, epsilon=1e-5):
result : np.ndarray
N-D with shape (d_0, d_1, ..., d_{N-1})
"""
+ old_dtype = data.dtype
+ data = data.astype("float32")
mean = np.mean(data, axis, keepdims=True)
var = np.var(data, axis, keepdims=True)
result = (data - mean) / np.sqrt(var + epsilon)
+ result = result.astype(old_dtype)
result *= gamma
if beta is not None:
result += beta
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 63db69ff14..a1fe266d68 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -1740,6 +1740,69 @@ def test_layer_norm():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_layer_norm_fp16():
+ # fmt: off
+ @tvm.script.ir_module
+ class LayerNorm:
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), "float16"), gamma: R.Tensor((4, 5),
"float16"), beta: R.Tensor((4, 5), "float16")) -> R.Tensor((2, 3, 4, 5),
"float16"):
+ gv: R.Tensor((2, 3, 4, 5), "float16") = R.nn.layer_norm(x, gamma,
beta, axes=[-2, -1])
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @T.prim_func
+ def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
+ T.func_attr({"tir.noalias": True})
+ rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2),
T.int64(3), T.int64(4), T.int64(5)), "float16")
+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4),
T.int64(5)), "float16")
+ rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4),
T.int64(5)), "float16")
+ T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2),
T.int64(3), T.int64(4), T.int64(5)), "float16")
+ with T.block("root"):
+ T.reads()
+ T.writes()
+ rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2),
T.int64(3)))
+ rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2),
T.int64(3)))
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for k2 in range(T.int64(4)):
+ for k3 in range(T.int64(5)):
+ with T.block("rxplaceholder_red_temp"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_k2 = T.axis.reduce(T.int64(4), k2)
+ v_k3 = T.axis.reduce(T.int64(5), k3)
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_k2,
v_k3])
+ T.writes(rxplaceholder_red_temp_v0[v_ax0,
v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
+ with T.init():
+ rxplaceholder_red_temp_v0[v_ax0,
v_ax1] = T.float32(0)
+ rxplaceholder_red_temp_v1[v_ax0,
v_ax1] = T.float32(0)
+ v_rxplaceholder_red_temp_v0: T.float32 =
rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
+ v_rxplaceholder_red_temp_v1: T.float32 =
rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
+ rxplaceholder_red_temp_v0[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v0
+ rxplaceholder_red_temp_v1[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v1
+ for ax0 in range(T.int64(2)):
+ for ax1 in range(T.int64(3)):
+ for ax2 in range(T.int64(4)):
+ for ax3 in range(T.int64(5)):
+ with T.block("T_layer_norm"):
+ v_ax0 = T.axis.spatial(T.int64(2), ax0)
+ v_ax1 = T.axis.spatial(T.int64(3), ax1)
+ v_ax2 = T.axis.spatial(T.int64(4), ax2)
+ v_ax3 = T.axis.spatial(T.int64(5), ax3)
+ T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2,
v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1],
rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3],
rxplaceholder_2[v_ax2, v_ax3])
+ T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2,
v_ax3])
+ T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] =
T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])
- rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) *
T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] /
T.Cast("float32", T.float16(4) * T.float16(5)) -
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) *
T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast [...]
+
+ @R.function
+ def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma:
R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) ->
R.Tensor((2, 3, 4, 5), dtype="float16"):
+ gv = R.call_tir(Expected.layer_norm, (x, gamma, beta),
out_sinfo=R.Tensor((2, 3, 4, 5), dtype="float16"))
+ return gv
+ # fmt: on
+ mod = LegalizeOps()(LayerNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_layer_norm_symbolic():
# fmt: off
@tvm.script.ir_module
@@ -1870,6 +1933,87 @@ def test_group_norm():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_group_norm_fp16():
+ # fmt: off
+ @tvm.script.ir_module
+ class GroupNorm:
+ @R.function
+ def main(x: R.Tensor((2, 4, 4, 5), "float16"), gamma: R.Tensor((4,),
"float16"), beta: R.Tensor((4,), "float16")) -> R.Tensor((2, 4, 4, 5),
"float16"):
+ gv: R.Tensor((2, 4, 4, 5), "float16") = R.nn.group_norm(x, gamma,
beta, num_groups=2, channel_axis=1, axes=[2, 3])
+ return gv
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main(x: R.Tensor((2, 4, 4, 5), dtype="float16"), gamma:
R.Tensor((4,), dtype="float16"), beta: R.Tensor((4,), dtype="float16")) ->
R.Tensor((2, 4, 4, 5), dtype="float16"):
+ gv = R.call_tir(Expected.group_norm, (x, gamma, beta),
out_sinfo=R.Tensor((2, 4, 4, 5), dtype="float16"))
+ return gv
+
+ @T.prim_func
+ def group_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(4),
T.int64(4), T.int64(5)), "float16"), rxplaceholder_1: T.Buffer((T.int64(4),),
"float16"), rxplaceholder_2: T.Buffer((T.int64(4),), "float16"), T_reshape:
T.Buffer((T.int64(2), T.int64(4), T.int64(4), T.int64(5)), "float16")):
+ T.func_attr({"tir.noalias": True})
+ # with T.block("root"):
+ T_reshape_1 = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2),
T.int64(4), T.int64(5)), "float16")
+ T_cast = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2),
T.int64(4), T.int64(5)))
+ rxplaceholder_red_temp_v0 = T.alloc_buffer((T.int64(2),
T.int64(2)))
+ rxplaceholder_red_temp_v1 = T.alloc_buffer((T.int64(2),
T.int64(2)))
+ T_reshape_2 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16")
+ T_reshape_3 = T.alloc_buffer((T.int64(2), T.int64(2)), "float16")
+ T_group_norm = T.alloc_buffer((T.int64(2), T.int64(2), T.int64(2),
T.int64(4), T.int64(5)), "float16")
+ for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2),
T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_reshape"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS",
[ax0, ax1, ax2, ax3, ax4])
+ T.reads(rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 //
T.int64(5) + v_ax3) // T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2),
(v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) %
T.int64(4), (v_ax4 // T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)])
+ T.writes(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
rxplaceholder[((v_ax1 * T.int64(2) + (v_ax4 // T.int64(5) + v_ax3) //
T.int64(4) + v_ax2) // T.int64(4) + v_ax0) % T.int64(2), (v_ax1 * T.int64(2) +
(v_ax4 // T.int64(5) + v_ax3) // T.int64(4) + v_ax2) % T.int64(4), (v_ax4 //
T.int64(5) + v_ax3) % T.int64(4), v_ax4 % T.int64(5)]
+ for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2),
T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_cast"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS",
[ax0, ax1, ax2, ax3, ax4])
+ T.reads(T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T.writes(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
T.Cast("float32", T_reshape_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ for ax0, ax1, k2, k3, k4 in T.grid(T.int64(2), T.int64(2),
T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("rxplaceholder_red_temp"):
+ v_ax0, v_ax1, v_k2, v_k3, v_k4 = T.axis.remap("SSRRR",
[ax0, ax1, k2, k3, k4])
+ T.reads(T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4])
+ T.writes(rxplaceholder_red_temp_v0[v_ax0, v_ax1],
rxplaceholder_red_temp_v1[v_ax0, v_ax1])
+ with T.init():
+ rxplaceholder_red_temp_v0[v_ax0, v_ax1] = T.float32(0)
+ rxplaceholder_red_temp_v1[v_ax0, v_ax1] = T.float32(0)
+ v_rxplaceholder_red_temp_v0: T.float32 =
rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]
+ v_rxplaceholder_red_temp_v1: T.float32 =
rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T_cast[v_ax0, v_ax1, v_k2, v_k3,
v_k4] * T_cast[v_ax0, v_ax1, v_k2, v_k3, v_k4]
+ rxplaceholder_red_temp_v0[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v0
+ rxplaceholder_red_temp_v1[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v1
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
+ with T.block("T_reshape_1"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder_1[(v_ax0 * T.int64(2) + v_ax1) %
T.int64(4)])
+ T.writes(T_reshape_2[v_ax0, v_ax1])
+ T_reshape_2[v_ax0, v_ax1] = rxplaceholder_1[(v_ax0 *
T.int64(2) + v_ax1) % T.int64(4)]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(2)):
+ with T.block("T_reshape_2"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(rxplaceholder_2[(v_ax0 * T.int64(2) + v_ax1) %
T.int64(4)])
+ T.writes(T_reshape_3[v_ax0, v_ax1])
+ T_reshape_3[v_ax0, v_ax1] = rxplaceholder_2[(v_ax0 *
T.int64(2) + v_ax1) % T.int64(4)]
+ for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(2),
T.int64(2), T.int64(4), T.int64(5)):
+ with T.block("T_group_norm"):
+ v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS",
[ax0, ax1, ax2, ax3, ax4])
+ T.reads(T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4],
rxplaceholder_red_temp_v0[v_ax0, v_ax1], rxplaceholder_red_temp_v1[v_ax0,
v_ax1], T_reshape_2[v_ax1, v_ax2], T_reshape_3[v_ax1, v_ax2])
+ T.writes(T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
+ T_group_norm[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] =
T.Cast("float16", (T_cast[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] -
rxplaceholder_red_temp_v0[v_ax0, v_ax1] * T.float32(0.025000000000000001)) *
T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] *
T.float32(0.025000000000000001) - rxplaceholder_red_temp_v0[v_ax0, v_ax1] *
T.float32(0.025000000000000001) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] *
T.float32(0.025000000000000001)) + T.float32(1.0000000000000001e-05))) * T_resh
[...]
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(4),
T.int64(4), T.int64(5)):
+ with T.block("T_reshape_3"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) //
T.int64(4) + v_ax1) // T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5)
+ v_ax2) // T.int64(4) + v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 //
T.int64(5) + v_ax2) // T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) +
v_ax2) % T.int64(4), v_ax3 % T.int64(5)])
+ T.writes(T_reshape[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_reshape[v_ax0, v_ax1, v_ax2, v_ax3] =
T_group_norm[(((v_ax3 // T.int64(5) + v_ax2) // T.int64(4) + v_ax1) //
T.int64(4) + v_ax0) % T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) // T.int64(4)
+ v_ax1) % T.int64(4) // T.int64(2), ((v_ax3 // T.int64(5) + v_ax2) //
T.int64(4) + v_ax1) % T.int64(2), (v_ax3 // T.int64(5) + v_ax2) % T.int64(4),
v_ax3 % T.int64(5)]
+ # fmt: on
+
+ mod = LegalizeOps()(GroupNorm)
+ tvm.ir.assert_structural_equal(mod, Expected)
+
+
def test_group_norm_symbolic():
# fmt: off
@tvm.script.ir_module
diff --git a/tests/python/topi/python/test_topi_group_norm.py
b/tests/python/topi/python/test_topi_group_norm.py
index f094423916..8f8ab75b8a 100644
--- a/tests/python/topi/python/test_topi_group_norm.py
+++ b/tests/python/topi/python/test_topi_group_norm.py
@@ -34,7 +34,8 @@ _group_norm_schedule = {
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize("shape, axis", [([2, 4, 16], (2,)), ([2, 4, 4, 16],
(2, 3))])
-def test_group_norm(target, dev, shape, axis, epsilon=1e-5, dtype="float32",
rtol=1e-5, atol=1e-5):
[email protected]("dtype", ["float32", "float16"])
+def test_group_norm(target, dev, shape, axis, dtype, epsilon=1e-5, rtol=1e-5,
atol=1e-5):
data = te.placeholder(shape, dtype=dtype, name="data")
num_groups = 2
channel_axis = 1
diff --git a/tests/python/topi/python/test_topi_layer_norm.py
b/tests/python/topi/python/test_topi_layer_norm.py
index f875bb09e2..ff9eedd4e5 100644
--- a/tests/python/topi/python/test_topi_layer_norm.py
+++ b/tests/python/topi/python/test_topi_layer_norm.py
@@ -34,7 +34,8 @@ _layer_norm_schedule = {
# only test on llvm because schedule is missing
@tvm.testing.parametrize_targets("llvm")
@pytest.mark.parametrize("shape,axis", [([4, 16], (1,)), ([4, 16, 16], (1,
2))])
-def test_layer_norm(target, dev, shape, axis, episilon=1e-5, dtype="float32",
rtol=1e-5, atol=1e-5):
[email protected]("dtype", ["float32", "float16"])
+def test_layer_norm(target, dev, shape, axis, dtype, episilon=1e-5, rtol=5e-4,
atol=5e-4):
data = te.placeholder(shape, dtype=dtype, name="data")
scale_shape = [shape[dim] for dim in axis]
gamma = te.placeholder(scale_shape, dtype=dtype, name="gamma")