This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fca86a6872 [Fix] Stabilize layer_norm variance computation with
two-pass reduction (#19643)
fca86a6872 is described below
commit fca86a6872dacea0978f1d49eb4e22935c2b0820
Author: ConvolutedDog <[email protected]>
AuthorDate: Sun May 31 13:45:46 2026 +0800
[Fix] Stabilize layer_norm variance computation with two-pass reduction
(#19643)
This PR fixes https://github.com/apache/tvm/issues/19592.
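LayerNorm could produce NaN for inputs with large values and small variance,
due to catastrophic cancellation in the one-pass formula var = E[x^2] - E[x]^2:
rounding can make the difference negative, so rsqrt(var + eps) returns NaN.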
Switch to a numerically stable two-pass formulation:
- pass 1 computes the mean as sum(x) / N
- pass 2 computes the variance as sum((x - mean)^2) / N
---
include/tvm/topi/nn/layer_norm.h | 73 ++++---
tests/python/relax/test_frontend_onnx.py | 38 ++++
.../python/relax/test_transform_legalize_ops_nn.py | 224 ++++++++++++---------
3 files changed, 211 insertions(+), 124 deletions(-)
diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index 873a5fd1b2..d74bbce23f 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -25,6 +25,7 @@
#define TVM_TOPI_NN_LAYER_NORM_H_
#include <tvm/te/operation.h>
+#include <tvm/topi/reduction.h>
#include <tvm/topi/tags.h>
#include <string>
@@ -59,17 +60,18 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
TVM_FFI_ICHECK(data_type == DataType::Float(32) || data_type ==
DataType::Float(16))
<< "layer_norm: only support float32 and float16 for now";
bool is_float16 = data_type == DataType::Float(16);
- // sum x and x^2
+ // Two-pass algorithm for improved numerical stability:
+ // pass1: mean = E[x]
+ // pass2: var = E[(x - mean)^2]
auto ndim = data->shape.size();
TVM_FFI_ICHECK_NE(ndim, 0) << "Cannot reduce a 0 dim Tensor";
auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
auto reduce_axes = MakeReduceAxes(real_axis, data);
auto target_shape =
MakeReduceTargetShape(real_axis, data, /*keepdims=*/false,
/*atleast1d=*/false);
- auto func = MakeTupleSumReducer();
- auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
- &data](const ffi::Array<Var>& indices) {
+ auto make_eval_range = [&real_axis, &reduce_axes,
+ ndim](const ffi::Array<Var>& non_reduce_indices) {
ffi::Array<PrimExpr> eval_range;
int arg_counter = 0;
int red_counter = 0;
@@ -80,34 +82,51 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
eval_range.push_back(reduce_axes[red_counter]);
red_counter++;
} else {
- eval_range.push_back(indices[arg_counter]);
+ eval_range.push_back(non_reduce_indices[arg_counter]);
arg_counter++;
}
}
- auto square = [is_float16](const PrimExpr& x) {
- if (is_float16) {
- return Cast(DataType::Float(32), x) * Cast(DataType::Float(32), x);
- }
- return x * x;
- };
- if (is_float16) {
- return func({Cast(DataType::Float(32), data(eval_range)),
square(data(eval_range))},
- reduce_axes, nullptr);
- } else {
- return func({data(eval_range), square(data(eval_range))}, reduce_axes,
nullptr);
- }
+ return eval_range;
};
- auto temp_x_x2 =
- tvm::te::compute(target_shape, compute, data->op->name + "_red_temp",
kCommReduce);
+ Tensor temp_sum = te::compute(
+ target_shape,
+ [is_float16, &data, &reduce_axes, &make_eval_range](const
ffi::Array<Var>& indices) {
+ auto eval_range = make_eval_range(indices);
+ PrimExpr x = data(eval_range);
+ if (is_float16) {
+ x = Cast(DataType::Float(32), x);
+ }
+ return sum(x, reduce_axes);
+ },
+ data->op->name + "_sum", kCommReduce);
- auto temp_x = temp_x_x2[0];
- auto temp_x2 = temp_x_x2[1];
-
- auto reduce_extent = make_const(data->dtype, 1);
+ DataType reduce_dtype = is_float16 ? DataType::Float(32) : data->dtype;
+ PrimExpr reduce_extent = make_const(reduce_dtype, 1);
for (int i : real_axis) {
reduce_extent *= data->shape[i];
}
+ Tensor temp_mean = te::compute(
+ target_shape,
+ [&temp_sum, &reduce_extent](const ffi::Array<Var>& indices) {
+ return temp_sum(indices) / reduce_extent;
+ },
+ data->op->name + "_mean", kInjective);
+
+ Tensor temp_var_sum = te::compute(
+ target_shape,
+ [is_float16, &data, &reduce_axes, &make_eval_range,
+ &temp_mean](const ffi::Array<Var>& indices) {
+ auto eval_range = make_eval_range(indices);
+ PrimExpr x = data(eval_range);
+ if (is_float16) {
+ x = Cast(DataType::Float(32), x);
+ }
+ PrimExpr diff = x - temp_mean(indices);
+ return sum(diff * diff, reduce_axes);
+ },
+ data->op->name + "_var_sum", kCommReduce);
+
auto layer_norm_func = [&](const ffi::Array<Var>& indices) {
ffi::Array<Var> reduce_indices, non_reduce_indices;
for (int i = 0, n = static_cast<int>(indices.size()); i < n; ++i) {
@@ -117,9 +136,9 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
non_reduce_indices.push_back(indices[i]);
}
}
- auto mean = temp_x(non_reduce_indices) / reduce_extent;
- auto var = temp_x2(non_reduce_indices) / reduce_extent - mean * mean;
- auto layer_norm = (data(indices) - mean) * tvm::rsqrt(var +
make_const(var->dtype, epsilon));
+ auto mean = temp_mean(non_reduce_indices);
+ auto var = temp_var_sum(non_reduce_indices) / reduce_extent;
+ auto layer_norm = (data(indices) - mean) * rsqrt(var +
make_const(var->dtype, epsilon));
if (is_float16) {
layer_norm = Cast(DataType::Float(16), layer_norm);
}
@@ -129,7 +148,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor&
gamma, const Tensor&
}
return layer_norm;
};
- return tvm::te::compute(data->shape, layer_norm_func, name, tag);
+ return te::compute(data->shape, layer_norm_func, name, tag);
}
} // namespace nn
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index 4278812436..7ee10993a4 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2309,6 +2309,44 @@ def test_layer_norm_with_nd_gamma_beta():
check_correctness(model)
+def test_layer_norm_numerical_stability():
+ """Numerical stability test for
https://github.com/apache/tvm/issues/19592."""
+ layer_norm_node = helper.make_node(
+ "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=-1,
epsilon=1e-5
+ )
+ graph = helper.make_graph(
+ [layer_norm_node],
+ "layer_norm_numerical_stability",
+ inputs=[
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 4]),
+ helper.make_tensor_value_info("scale", TensorProto.FLOAT, [4]),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT, [4]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 4]),
+ ],
+ )
+ model = helper.make_model(graph,
producer_name="layer_norm_numerical_stability")
+
+ input_array = np.array([[80000.0, 80001.0, 80002.0, 80003.0]],
dtype=np.float32)
+ scale_array = np.ones(4, dtype=np.float32)
+ bias_array = np.zeros(4, dtype=np.float32)
+ inputs = {"input": input_array, "scale": scale_array, "bias": bias_array}
+
+ # ONNXRuntime also returns NaN for Large-value, small-variance inputs, so
we here
+ # compare against a two-pass reference instead of ORT.
+ mean = input_array.mean(axis=-1, keepdims=True)
+ var = ((input_array - mean) ** 2).mean(axis=-1, keepdims=True)
+ expected = ((input_array - mean) / np.sqrt(var + 1e-5) * scale_array +
bias_array).astype(
+ np.float32
+ )
+
+ tvm_output = run_in_tvm(model, inputs=inputs, ir_version=9, opset=17)
+
+ assert np.isfinite(tvm_output.numpy()).all()
+ tvm.testing.assert_allclose(tvm_output.numpy(), expected)
+
+
def test_rms_norm():
# Basic test: default axis=-1
rms_norm_node = helper.make_node("RMSNormalization", ["input", "scale"],
["Y"], epsilon=1e-05)
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py
b/tests/python/relax/test_transform_legalize_ops_nn.py
index 6badc7fc33..4a708b5da1 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2734,28 +2734,40 @@ def test_layer_norm():
return gv
@T.prim_func(private=True, s_tir=True)
- def layer_norm(rxplaceholder: T.Buffer((T.int64(2), T.int64(3),
T.int64(4), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(4),
T.int64(5)), "float32"), rxplaceholder_2: T.Buffer((T.int64(4), T.int64(5)),
"float32"), T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32")):
+ def layer_norm(x: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float32"), gamma: T.Buffer((T.int64(4), T.int64(5)), "float32"),
beta: T.Buffer((T.int64(4), T.int64(5)), "float32"), T_layer_norm:
T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)), "float32")):
T.func_attr({"tirx.noalias": True})
- rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([T.int64(2),
T.int64(3)], dtype="float32")
- rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([T.int64(2),
T.int64(3)], dtype="float32")
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
- with T.sblock("rxplaceholder_red_temp"):
- ax0, ax1, k2, k3 = T.axis.remap("SSRR", [i0, i1, i2, i3])
- T.reads(rxplaceholder[ax0, ax1, k2, k3])
- T.writes(rxplaceholder_red_temp_v0[ax0, ax1],
rxplaceholder_red_temp_v1[ax0, ax1])
+ # with T.sblock("root"):
+ x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.sblock("x_sum"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
+ T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(x_sum[v_ax0, v_ax1])
+ with T.init():
+ x_sum[v_ax0, v_ax1] = T.float32(0.0)
+ x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] + x[v_ax0,
v_ax1, v_k2, v_k3]
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.sblock("x_mean"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(x_sum[v_ax0, v_ax1])
+ T.writes(x_mean[v_ax0, v_ax1])
+ x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] /
T.float32(20.0)
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.sblock("x_var_sum"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
+ T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1])
+ T.writes(x_var_sum[v_ax0, v_ax1])
with T.init():
- rxplaceholder_red_temp_v0[ax0, ax1] = T.float32(0)
- rxplaceholder_red_temp_v1[ax0, ax1] = T.float32(0)
- v_rxplaceholder_red_temp_v0: T.let[T.float32] =
rxplaceholder_red_temp_v0[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3]
- v_rxplaceholder_red_temp_v1: T.let[T.float32] =
rxplaceholder_red_temp_v1[ax0, ax1] + rxplaceholder[ax0, ax1, k2, k3] *
rxplaceholder[ax0, ax1, k2, k3]
- rxplaceholder_red_temp_v0[ax0, ax1] =
v_rxplaceholder_red_temp_v0
- rxplaceholder_red_temp_v1[ax0, ax1] =
v_rxplaceholder_red_temp_v1
- for i0, i1, i2, i3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ x_var_sum[v_ax0, v_ax1] = T.float32(0.0)
+ x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] +
(x[v_ax0, v_ax1, v_k2, v_k3] - x_mean[v_ax0, v_ax1]) * (x[v_ax0, v_ax1, v_k2,
v_k3] - x_mean[v_ax0, v_ax1])
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
with T.sblock("T_layer_norm"):
- ax0, ax1, ax2, ax3 = T.axis.remap("SSSS", [i0, i1, i2, i3])
- T.reads(rxplaceholder[ax0, ax1, ax2, ax3],
rxplaceholder_red_temp_v0[ax0, ax1], rxplaceholder_red_temp_v1[ax0, ax1],
rxplaceholder_1[ax2, ax3], rxplaceholder_2[ax2, ax3])
- T.writes(T_layer_norm[ax0, ax1, ax2, ax3])
- T_layer_norm[ax0, ax1, ax2, ax3] = (rxplaceholder[ax0,
ax1, ax2, ax3] - rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) *
T.rsqrt(rxplaceholder_red_temp_v1[ax0, ax1] / T.float32(20) -
rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20) *
(rxplaceholder_red_temp_v0[ax0, ax1] / T.float32(20)) + T.float32(1e-05),
dtype="float32") * rxplaceholder_1[ax2, ax3] + rxplaceholder_2[ax2, ax3]
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0,
v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
+ T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] = (x[v_ax0,
v_ax1, v_ax2, v_ax3] - x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1]
/ T.float32(20.0) + T.float32(1.0000000000000001e-05)) * gamma[v_ax2, v_ax3] +
beta[v_ax2, v_ax3]
# fmt: on
mod = LegalizeOps()(LayerNorm)
tvm.ir.assert_structural_equal(mod, Expected)
@@ -2780,26 +2792,36 @@ def test_layer_norm_1d():
def layer_norm(x: T.Buffer((T.int64(3),), "float32"),
layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias:
T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),),
"float32")):
T.func_attr({"tirx.noalias": True})
# with T.sblock("root"):
- x_red_temp_v0 = T.sblock_alloc_buffer(())
- x_red_temp_v1 = T.sblock_alloc_buffer(())
+ x_sum = T.sblock_alloc_buffer(())
+ x_mean = T.sblock_alloc_buffer(())
+ x_var_sum = T.sblock_alloc_buffer(())
for k0 in range(T.int64(3)):
- with T.sblock("x_red_temp"):
+ with T.sblock("x_sum"):
v_k0 = T.axis.reduce(T.int64(3), k0)
T.reads(x[v_k0])
- T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
+ T.writes(x_sum[()])
+ with T.init():
+ x_sum[()] = T.float32(0.0)
+ x_sum[()] = x_sum[()] + x[v_k0]
+ with T.sblock("x_mean"):
+ vi = T.axis.spatial(1, T.int64(0))
+ T.reads(x_sum[()])
+ T.writes(x_mean[()])
+ x_mean[()] = x_sum[()] / T.float32(3.0)
+ for k0 in range(T.int64(3)):
+ with T.sblock("x_var_sum"):
+ v_k0 = T.axis.reduce(T.int64(3), k0)
+ T.reads(x[v_k0], x_mean[()])
+ T.writes(x_var_sum[()])
with T.init():
- x_red_temp_v0[()] = T.float32(0.0)
- x_red_temp_v1[()] = T.float32(0.0)
- v_x_red_temp_v0: T.let[T.float32] = x_red_temp_v0[()] +
x[v_k0]
- v_x_red_temp_v1: T.let[T.float32] = x_red_temp_v1[()] +
x[v_k0] * x[v_k0]
- x_red_temp_v0[()] = v_x_red_temp_v0
- x_red_temp_v1[()] = v_x_red_temp_v1
+ x_var_sum[()] = T.float32(0.0)
+ x_var_sum[()] = x_var_sum[()] + (x[v_k0] - x_mean[()]) *
(x[v_k0] - x_mean[()])
for ax0 in range(T.int64(3)):
with T.sblock("T_layer_norm"):
v_ax0 = T.axis.spatial(T.int64(3), ax0)
- T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()],
layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
+ T.reads(x[v_ax0], x_mean[()], x_var_sum[()],
layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
T.writes(T_layer_norm[v_ax0])
- T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] /
T.float32(3)) * T.rsqrt(x_red_temp_v1[()] / T.float32(3) - x_red_temp_v0[()] /
T.float32(3) * (x_red_temp_v0[()] / T.float32(3)) +
T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] +
layer_norm_bias[v_ax0]
+ T_layer_norm[v_ax0] = (x[v_ax0] - x_mean[()]) *
T.rsqrt(x_var_sum[()] / T.float32(3.0) + T.float32(1.0000000000000001e-05)) *
layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
@R.function
def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight:
R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,),
dtype="float32")) -> R.Tensor((3,), dtype="float32"):
@@ -2827,47 +2849,45 @@ def test_layer_norm_fp16():
@I.ir_module(s_tir=True)
class Expected:
@T.prim_func(private=True, s_tir=True)
- def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
+ def layer_norm(
+ x: T.Buffer((T.int64(2), T.int64(3), T.int64(4), T.int64(5)),
"float16"),
+ gamma: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+ beta: T.Buffer((T.int64(4), T.int64(5)), "float16"),
+ T_layer_norm: T.Buffer((T.int64(2), T.int64(3), T.int64(4),
T.int64(5)), "float16"),
+ ):
T.func_attr({"tirx.noalias": True})
- rxplaceholder = T.match_buffer(var_rxplaceholder, (T.int64(2),
T.int64(3), T.int64(4), T.int64(5)), "float16")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, (T.int64(4),
T.int64(5)), "float16")
- rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, (T.int64(4),
T.int64(5)), "float16")
- T_layer_norm = T.match_buffer(var_T_layer_norm, (T.int64(2),
T.int64(3), T.int64(4), T.int64(5)), "float16")
- with T.sblock("root"):
- T.reads()
- T.writes()
- rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer((T.int64(2),
T.int64(3)))
- rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer((T.int64(2),
T.int64(3)))
- for ax0 in range(T.int64(2)):
- for ax1 in range(T.int64(3)):
- for k2 in range(T.int64(4)):
- for k3 in range(T.int64(5)):
- with T.sblock("rxplaceholder_red_temp"):
- v_ax0 = T.axis.spatial(T.int64(2), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_k2 = T.axis.reduce(T.int64(4), k2)
- v_k3 = T.axis.reduce(T.int64(5), k3)
- T.reads(rxplaceholder[v_ax0, v_ax1, v_k2,
v_k3])
- T.writes(rxplaceholder_red_temp_v0[v_ax0,
v_ax1], rxplaceholder_red_temp_v1[v_ax0, v_ax1])
- with T.init():
- rxplaceholder_red_temp_v0[v_ax0,
v_ax1] = T.float32(0)
- rxplaceholder_red_temp_v1[v_ax0,
v_ax1] = T.float32(0)
- v_rxplaceholder_red_temp_v0:
T.let[T.float32] = rxplaceholder_red_temp_v0[v_ax0, v_ax1] + T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
- v_rxplaceholder_red_temp_v1:
T.let[T.float32] = rxplaceholder_red_temp_v1[v_ax0, v_ax1] + T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3]) * T.Cast("float32",
rxplaceholder[v_ax0, v_ax1, v_k2, v_k3])
- rxplaceholder_red_temp_v0[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v0
- rxplaceholder_red_temp_v1[v_ax0, v_ax1] =
v_rxplaceholder_red_temp_v1
- for ax0 in range(T.int64(2)):
- for ax1 in range(T.int64(3)):
- for ax2 in range(T.int64(4)):
- for ax3 in range(T.int64(5)):
- with T.sblock("T_layer_norm"):
- v_ax0 = T.axis.spatial(T.int64(2), ax0)
- v_ax1 = T.axis.spatial(T.int64(3), ax1)
- v_ax2 = T.axis.spatial(T.int64(4), ax2)
- v_ax3 = T.axis.spatial(T.int64(5), ax3)
- T.reads(rxplaceholder[v_ax0, v_ax1, v_ax2,
v_ax3], rxplaceholder_red_temp_v0[v_ax0, v_ax1],
rxplaceholder_red_temp_v1[v_ax0, v_ax1], rxplaceholder_1[v_ax2, v_ax3],
rxplaceholder_2[v_ax2, v_ax3])
- T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2,
v_ax3])
- T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] =
T.Cast("float16", (T.Cast("float32", rxplaceholder[v_ax0, v_ax1, v_ax2, v_ax3])
- rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) *
T.float16(5))) * T.rsqrt(rxplaceholder_red_temp_v1[v_ax0, v_ax1] /
T.Cast("float32", T.float16(4) * T.float16(5)) -
rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast("float32", T.float16(4) *
T.float16(5)) * (rxplaceholder_red_temp_v0[v_ax0, v_ax1] / T.Cast [...]
+ # with T.sblock("root"):
+ x_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ x_mean = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ x_var_sum = T.sblock_alloc_buffer((T.int64(2), T.int64(3)))
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.sblock("x_sum"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
+ T.reads(x[v_ax0, v_ax1, v_k2, v_k3])
+ T.writes(x_sum[v_ax0, v_ax1])
+ with T.init():
+ x_sum[v_ax0, v_ax1] = T.float32(0.0)
+ x_sum[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] +
T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3])
+ for ax0, ax1 in T.grid(T.int64(2), T.int64(3)):
+ with T.sblock("x_mean"):
+ v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(x_sum[v_ax0, v_ax1])
+ T.writes(x_mean[v_ax0, v_ax1])
+ x_mean[v_ax0, v_ax1] = x_sum[v_ax0, v_ax1] /
T.float32(20.0)
+ for ax0, ax1, k2, k3 in T.grid(T.int64(2), T.int64(3), T.int64(4),
T.int64(5)):
+ with T.sblock("x_var_sum"):
+ v_ax0, v_ax1, v_k2, v_k3 = T.axis.remap("SSRR", [ax0, ax1,
k2, k3])
+ T.reads(x[v_ax0, v_ax1, v_k2, v_k3], x_mean[v_ax0, v_ax1])
+ T.writes(x_var_sum[v_ax0, v_ax1])
+ with T.init():
+ x_var_sum[v_ax0, v_ax1] = T.float32(0.0)
+ x_var_sum[v_ax0, v_ax1] = x_var_sum[v_ax0, v_ax1] +
(T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1]) *
(T.Cast("float32", x[v_ax0, v_ax1, v_k2, v_k3]) - x_mean[v_ax0, v_ax1])
+ for ax0, ax1, ax2, ax3 in T.grid(T.int64(2), T.int64(3),
T.int64(4), T.int64(5)):
+ with T.sblock("T_layer_norm"):
+ v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0,
ax1, ax2, ax3])
+ T.reads(x[v_ax0, v_ax1, v_ax2, v_ax3], x_mean[v_ax0,
v_ax1], x_var_sum[v_ax0, v_ax1], gamma[v_ax2, v_ax3], beta[v_ax2, v_ax3])
+ T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3])
+ T_layer_norm[v_ax0, v_ax1, v_ax2, v_ax3] =
T.Cast("float16", (T.Cast("float32", x[v_ax0, v_ax1, v_ax2, v_ax3]) -
x_mean[v_ax0, v_ax1]) * T.rsqrt(x_var_sum[v_ax0, v_ax1] / T.float32(20.0) +
T.float32(1.0000000000000001e-05))) * gamma[v_ax2, v_ax3] + beta[v_ax2, v_ax3]
@R.function
def main(x: R.Tensor((2, 3, 4, 5), dtype="float16"), gamma:
R.Tensor((4, 5), dtype="float16"), beta: R.Tensor((4, 5), dtype="float16")) ->
R.Tensor((2, 3, 4, 5), dtype="float16"):
@@ -2901,35 +2921,45 @@ def test_layer_norm_symbolic():
return gv
@T.prim_func(private=True, s_tir=True)
- def layer_norm(var_rxplaceholder: T.handle, var_rxplaceholder_1:
T.handle, var_rxplaceholder_2: T.handle, var_T_layer_norm: T.handle):
+ def layer_norm(var_x: T.handle, var_gamma: T.handle, var_beta:
T.handle, var_T_layer_norm: T.handle):
T.func_attr({"tirx.noalias": True})
- f = T.int64()
- n = T.int64()
- s = T.int64()
- rxplaceholder = T.match_buffer(var_rxplaceholder, [n, s, f],
dtype="float32")
- rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [s, f],
dtype="float32")
- rxplaceholder_2 = T.match_buffer(var_rxplaceholder_2, [s, f],
dtype="float32")
- T_layer_norm = T.match_buffer(var_T_layer_norm, [n, s, f],
dtype="float32")
- rxplaceholder_red_temp_v0 = T.sblock_alloc_buffer([n],
dtype="float32")
- rxplaceholder_red_temp_v1 = T.sblock_alloc_buffer([n],
dtype="float32")
- for i0, i1, i2 in T.grid(n, s, f):
- with T.sblock("rxplaceholder_red_temp"):
- ax0, k1, k2 = T.axis.remap("SRR", [i0, i1, i2])
- T.reads(rxplaceholder[ax0, k1, k2])
- T.writes(rxplaceholder_red_temp_v0[ax0],
rxplaceholder_red_temp_v1[ax0])
+ n, s, f = T.int64(), T.int64(), T.int64()
+ x = T.match_buffer(var_x, (n, s, f))
+ gamma = T.match_buffer(var_gamma, (s, f))
+ beta = T.match_buffer(var_beta, (s, f))
+ T_layer_norm = T.match_buffer(var_T_layer_norm, (n, s, f))
+ # with T.sblock("root"):
+ x_sum = T.sblock_alloc_buffer((n,))
+ x_mean = T.sblock_alloc_buffer((n,))
+ x_var_sum = T.sblock_alloc_buffer((n,))
+ for ax0, k1, k2 in T.grid(n, s, f):
+ with T.sblock("x_sum"):
+ v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+ T.reads(x[v_ax0, v_k1, v_k2])
+ T.writes(x_sum[v_ax0])
with T.init():
- rxplaceholder_red_temp_v0[ax0] = T.float32(0)
- rxplaceholder_red_temp_v1[ax0] = T.float32(0)
- v_rxplaceholder_red_temp_v0: T.let[T.float32] =
rxplaceholder_red_temp_v0[ax0] + rxplaceholder[ax0, k1, k2]
- v_rxplaceholder_red_temp_v1: T.let[T.float32] =
rxplaceholder_red_temp_v1[ax0] + rxplaceholder[ax0, k1, k2] *
rxplaceholder[ax0, k1, k2]
- rxplaceholder_red_temp_v0[ax0] =
v_rxplaceholder_red_temp_v0
- rxplaceholder_red_temp_v1[ax0] =
v_rxplaceholder_red_temp_v1
- for i0, i1, i2 in T.grid(n, s, f):
+ x_sum[v_ax0] = T.float32(0.0)
+ x_sum[v_ax0] = x_sum[v_ax0] + x[v_ax0, v_k1, v_k2]
+ for ax0 in range(n):
+ with T.sblock("x_mean"):
+ v_ax0 = T.axis.spatial(n, ax0)
+ T.reads(x_sum[v_ax0])
+ T.writes(x_mean[v_ax0])
+ x_mean[v_ax0] = x_sum[v_ax0] / (T.Cast("float32", s) *
T.Cast("float32", f))
+ for ax0, k1, k2 in T.grid(n, s, f):
+ with T.sblock("x_var_sum"):
+ v_ax0, v_k1, v_k2 = T.axis.remap("SRR", [ax0, k1, k2])
+ T.reads(x[v_ax0, v_k1, v_k2], x_mean[v_ax0])
+ T.writes(x_var_sum[v_ax0])
+ with T.init():
+ x_var_sum[v_ax0] = T.float32(0.0)
+ x_var_sum[v_ax0] = x_var_sum[v_ax0] + (x[v_ax0, v_k1,
v_k2] - x_mean[v_ax0]) * (x[v_ax0, v_k1, v_k2] - x_mean[v_ax0])
+ for ax0, ax1, ax2 in T.grid(n, s, f):
with T.sblock("T_layer_norm"):
- ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
- T.reads(rxplaceholder[ax0, ax1, ax2],
rxplaceholder_red_temp_v0[ax0], rxplaceholder_red_temp_v1[ax0],
rxplaceholder_1[ax1, ax2], rxplaceholder_2[ax1, ax2])
- T.writes(T_layer_norm[ax0, ax1, ax2])
- T_layer_norm[ax0, ax1, ax2] = (rxplaceholder[ax0, ax1,
ax2] - rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) *
T.Cast("float32", f))) * T.rsqrt(rxplaceholder_red_temp_v1[ax0] /
(T.Cast("float32", s) * T.Cast("float32", f)) - rxplaceholder_red_temp_v0[ax0]
/ (T.Cast("float32", s) * T.Cast("float32", f)) *
(rxplaceholder_red_temp_v0[ax0] / (T.Cast("float32", s) * T.Cast("float32",
f))) + T.float32(1e-05), dtype="float32") * rxplaceholder_1[ax1, ax2] +
rxplacehol [...]
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(x[v_ax0, v_ax1, v_ax2], x_mean[v_ax0],
x_var_sum[v_ax0], gamma[v_ax1, v_ax2], beta[v_ax1, v_ax2])
+ T.writes(T_layer_norm[v_ax0, v_ax1, v_ax2])
+ T_layer_norm[v_ax0, v_ax1, v_ax2] = (x[v_ax0, v_ax1,
v_ax2] - x_mean[v_ax0]) * T.rsqrt(x_var_sum[v_ax0] / (T.Cast("float32", s) *
T.Cast("float32", f)) + T.float32(1.0000000000000001e-05)) * gamma[v_ax1,
v_ax2] + beta[v_ax1, v_ax2]
# fmt: on
mod = LegalizeOps()(LayerNorm)
tvm.ir.assert_structural_equal(mod, Expected)