This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b5d352fb6c [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes (#18063)
b5d352fb6c is described below

commit b5d352fb6c40c48732e84c2d91ca1608021e438f
Author: Ruxiao Yin <[email protected]>
AuthorDate: Wed Jun 18 16:37:21 2025 +0800

    [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes (#18063)
    
    * [TOPI][NN][Layer_Norm] Fix layer_norm error with reduce-only axes
    
    * change code style
    
    * rechange code style
---
 include/tvm/topi/nn/layer_norm.h                   |  2 +-
 .../python/relax/test_transform_legalize_ops_nn.py | 54 ++++++++++++++++++++++
 2 files changed, 55 insertions(+), 1 deletion(-)
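
(Editorial note, not part of the commit.) With /*atleast1d=*/true, an all-axes
reduction gets a (1,) target shape instead of a scalar, which is what broke
layer_norm over a 1-D input. A minimal sketch of the case the new test covers,
assuming the standard relax TVMScript front end; the module and variable names
below are illustrative, not taken from the patch:

    from tvm.relax.transform import LegalizeOps
    from tvm.script import ir as I, relax as R

    @I.ir_module
    class Mod:
        @R.function
        def forward(
            x: R.Tensor((3,), "float32"),
            gamma: R.Tensor((3,), "float32"),
            beta: R.Tensor((3,), "float32"),
        ) -> R.Tensor((3,), "float32"):
            with R.dataflow():
                # axes=[-1] on a 1-D tensor: every axis is reduced.
                y = R.nn.layer_norm(x, gamma, beta, axes=[-1], epsilon=1e-5)
                R.output(y)
            return y

    # Lowers the relax op to TIR; with the patch applied the reduction
    # target shape is a scalar and legalization succeeds.
    mod = LegalizeOps()(Mod)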

diff --git a/include/tvm/topi/nn/layer_norm.h b/include/tvm/topi/nn/layer_norm.h
index ee0cba74dd..f1b0e4ac9e 100644
--- a/include/tvm/topi/nn/layer_norm.h
+++ b/include/tvm/topi/nn/layer_norm.h
@@ -65,7 +65,7 @@ inline Tensor layer_norm(const Tensor& data, const Tensor& gamma, const Tensor&
   auto real_axis = GetRealAxis(static_cast<int>(ndim), axis);
   auto reduce_axes = MakeReduceAxes(real_axis, data);
   auto target_shape =
-      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/true);
+      MakeReduceTargetShape(real_axis, data, /*keepdims=*/false, /*atleast1d=*/false);
   auto func = MakeTupleSumReducer();
 
   auto compute = [ndim, is_float16, &real_axis, &reduce_axes, &func,
diff --git a/tests/python/relax/test_transform_legalize_ops_nn.py b/tests/python/relax/test_transform_legalize_ops_nn.py
index 52986feef3..575f4a0fb0 100644
--- a/tests/python/relax/test_transform_legalize_ops_nn.py
+++ b/tests/python/relax/test_transform_legalize_ops_nn.py
@@ -2599,6 +2599,60 @@ def test_layer_norm():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_layer_norm_1d():
+    # fmt: off
+    @I.ir_module
+    class LayerNorm_1D:
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                layer_norm: R.Tensor((3,), dtype="float32") = R.nn.layer_norm(x, layer_norm_weight, layer_norm_bias, axes=[-1], epsilon=1.0000000000000001e-05, center=True, scale=True)
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+
+    @I.ir_module
+    class LayerNorm_1D_Expected:
+        @T.prim_func(private=True)
+        def layer_norm(x: T.Buffer((T.int64(3),), "float32"), layer_norm_weight: T.Buffer((T.int64(3),), "float32"), layer_norm_bias: T.Buffer((T.int64(3),), "float32"), T_layer_norm: T.Buffer((T.int64(3),), "float32")):
+            T.func_attr({"tir.noalias": True})
+            # with T.block("root"):
+            x_red_temp_v0 = T.alloc_buffer(())
+            x_red_temp_v1 = T.alloc_buffer(())
+            for k0 in range(T.int64(3)):
+                with T.block("x_red_temp"):
+                    v_k0 = T.axis.reduce(T.int64(3), k0)
+                    T.reads(x[v_k0])
+                    T.writes(x_red_temp_v0[()], x_red_temp_v1[()])
+                    with T.init():
+                        x_red_temp_v0[()] = T.float32(0.0)
+                        x_red_temp_v1[()] = T.float32(0.0)
+                    v_x_red_temp_v0: T.float32 = x_red_temp_v0[()] + x[v_k0]
+                    v_x_red_temp_v1: T.float32 = x_red_temp_v1[()] + x[v_k0] * x[v_k0]
+                    x_red_temp_v0[()] = v_x_red_temp_v0
+                    x_red_temp_v1[()] = v_x_red_temp_v1
+            for ax0 in range(T.int64(3)):
+                with T.block("T_layer_norm"):
+                    v_ax0 = T.axis.spatial(T.int64(3), ax0)
+                    T.reads(x[v_ax0], x_red_temp_v0[()], x_red_temp_v1[()], layer_norm_weight[v_ax0], layer_norm_bias[v_ax0])
+                    T.writes(T_layer_norm[v_ax0])
+                    T_layer_norm[v_ax0] = (x[v_ax0] - x_red_temp_v0[()] * T.float32(0.33333333333333331)) * T.rsqrt(x_red_temp_v1[()] * T.float32(0.33333333333333331) - x_red_temp_v0[()] * T.float32(0.33333333333333331) * (x_red_temp_v0[()] * T.float32(0.33333333333333331)) + T.float32(1.0000000000000001e-05)) * layer_norm_weight[v_ax0] + layer_norm_bias[v_ax0]
+
+        @R.function
+        def forward(x: R.Tensor((3,), dtype="float32"), layer_norm_weight: R.Tensor((3,), dtype="float32"), layer_norm_bias: R.Tensor((3,), dtype="float32")) -> R.Tensor((3,), dtype="float32"):
+            R.func_attr({"num_input": 1})
+            cls = LayerNorm_1D_Expected
+            with R.dataflow():
+                layer_norm = R.call_tir(cls.layer_norm, (x, layer_norm_weight, layer_norm_bias), out_sinfo=R.Tensor((3,), dtype="float32"))
+                gv: R.Tensor((3,), dtype="float32") = layer_norm
+                R.output(gv)
+            return gv
+    # fmt: on
+    mod = LegalizeOps()(LayerNorm_1D)
+    tvm.ir.assert_structural_equal(mod, LayerNorm_1D_Expected)
+
+
 def test_layer_norm_fp16():
     # fmt: off
     @tvm.script.ir_module
