This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b5d352fb6c [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes (#18063)
b5d352fb6c is described below
commit b5d352fb6c40c48732e84c2d91ca1608021e438f
Author: Ruxiao Yin <[email protected]>
AuthorDate: Wed Jun 18 16:37:21 2025 +0800
[TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes (#18063)
* [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes
* change code style
* rechange code style
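
For context, a minimal repro sketch of the case this patch fixes: legalizing R.nn.layer_norm on a 1-D tensor whose only axis is the normalization axis, so the mean/variance reduction has no remaining (non-reduced) dimensions. The module below is only an illustration written against the public tvm.relax APIs and mirrors the regression test added further down; names like LayerNorm1D are ad hoc and not part of the patch.

from tvm.relax.transform import LegalizeOps
from tvm.script import ir as I
from tvm.script import relax as R


@I.ir_module
class LayerNorm1D:
    @R.function
    def forward(
        x: R.Tensor((3,), dtype="float32"),
        weight: R.Tensor((3,), dtype="float32"),
        bias: R.Tensor((3,), dtype="float32"),
    ) -> R.Tensor((3,), dtype="float32"):
        with R.dataflow():
            # axes=[-1] reduces the only dimension, so mean and variance are scalars.
            y = R.nn.layer_norm(x, weight, bias, axes=[-1])
            R.output(y)
        return y


# Before this commit, legalizing this reduce-only case failed; with the fix
# it lowers to the scalar-reduction TIR shown in the new test below.
mod = LegalizeOps()(LayerNorm1D)
print(mod)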
---
include/tvm/topi/nn/layer_norm.h | 2 +-
.../python/relax/test_transform_legalize_ops_nn.py | 54 ++++++++++++++++++++++
2 files changed, 55 insertions(+), 1 deletion(-)
diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index ee0cba74dd..f1b0e4ac9e 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -65,7 +65,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
   auto target_shape =
-      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
+      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false);
   auto func = MakeTupleSumReducer();
   auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
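
The one-line change above only takes effect when every axis is reduced: with atleast1d=true the reduce target shape of a fully reduced tensor is padded to (1,), while with atleast1d=false it stays the rank-0 shape (), which is what the scalar T.alloc_buffer(()) temporaries in the expected TIR of the new test rely on. A rough sketch of that shape rule follows; reduce_target_shape is a hypothetical stand-in, not the actual MakeReduceTargetShape implementation.

# Hypothetical illustration of the reduce target shape for a fully reduced
# 1-D input, depending on the atleast1d flag flipped by this patch.
def reduce_target_shape(shape, axes, keepdims=False, atleast1d=True):
    if keepdims:
        out = [1 if i in axes else d for i, d in enumerate(shape)]
    else:
        out = [d for i, d in enumerate(shape) if i not in axes]
    if atleast1d and not out:
        out = [1]  # pad an empty (scalar) shape up to rank 1
    return tuple(out)

print(reduce_target_shape((3,), {0}, atleast1d=True))   # (1,)  old behaviour
print(reduce_target_shape((3,), {0}, atleast1d=False))  # ()    after this patch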
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 52986feef3..575f4a0fb0 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2599,6 +2599,60 @@ def test_layer_norm():
tvm.ir.assert_structural_equal(mod, Expected)
+def test_layer_norm_1d():
+    # fmt: off
+    @I.ir_module
+    class LayerNorm_1D:
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                layer_norm: R.Tensor((3,), dtype="float32") = R.nn.layer_norm(x, layer_norm_weight, layer_norm_bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class LayerNorm_1D_Expected:
+        @T.prim_func(private=True)
+        def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            x_red_temp_v0 = T.alloc_buffer(())
+            x_red_temp_v1 = T.alloc_buffer(())
+            for k0 in range(T.int64(3)):
+                with T.block("x_red_temp"):
+                    v_k0 = T.axis.reduce(T.int64(3), k0)
+                    T.reads(x[v_k0])
+                    T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
+                    with T.init():
+                        x_red_temp_v0[()] = T.float32(0.0)
+                        x_red_temp_v1[()] = T.float32(0.0)
+                    v_x_red_temp_v0: T.float32 = x_red_temp_v0[()] + x[v_k0]
+                    v_x_red_temp_v1: T.float32 = x_red_temp_v1[()] + x[v_k0] * x[v_k0]
+                    x_red_temp_v0[()] = v_x_red_temp_v0
+                    x_red_temp_v1[()] = v_x_red_temp_v1
+            for ax0 in range(T.int64(3)):
+                with T.block("T_layer_norm"):
+                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                    T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
+                    T.writes(T_layer_norm[v_ax0])
+                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
+
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = LayerNorm_1D_Expected
+            with R.dataflow():
+                layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32"))
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(LayerNorm_1D)
+    tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected)
+
+
def test_layer_norm_fp16():
# fmt: off
@tvm.script.ir_module