This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 359e05e1bb [FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)
359e05e1bb is described below

commit 359e05e1bb45ae235fddf3fbe314ad2807f9202d
Author: ConvolutedDog <yangjiancha...@nudt.edu.cn>
AuthorDate: Mon Jul 14 14:45:21 2025 +0800

    [FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)
    
    * Extend axes for layer_norm when gamma/beta are multi-dimensional
    
    * Add testcase for layer_norm when gamma/beta are multi-dimensional
    
    * fix lint
    
    * fix lint
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 11 +++++
 tests/python/relax/test_frontend_onnx.py        | 64 +++++++++++++++++++++----
 2 files changed, 66 insertions(+), 9 deletions(-)
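
For readers skimming the patch, here is a minimal standalone sketch of the axis-extension rule this commit introduces (the helper name extend_axes and the example shapes are illustrative only; gamma_shape and axis mirror the names used in the diff below):

    # Sketch of the rule added to LayerNormalization below: extend a
    # scalar ONNX `axis` so that one normalization axis is produced per
    # dimension of an N-D gamma/beta.
    def extend_axes(axis, gamma_shape):
        axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]
        if len(axis) < len(gamma_shape):
            # Append the axes that immediately follow the last given one.
            axis.extend(range(axis[-1] + 1, axis[-1] + 1 + len(gamma_shape) - len(axis)))
        return axis

    # Matches the new test: input [1, 3, 4, 4], gamma/beta [3, 4, 4], axis=1
    assert extend_axes(1, (3, 4, 4)) == [1, 2, 3]
    # A 1-D gamma/beta keeps the old single-axis behavior.
    assert extend_axes(-1, (32,)) == [-1]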

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5cf324086e..926da7f022 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -48,6 +48,7 @@ from tvm import TVMError, relax, tir, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
+from tvm.topi.utils import get_const_tuple
 
 from ..common import autopad
 
@@ -2500,9 +2501,19 @@ class LayerNormalization(OnnxOpConverter):
         axis = attr.get("axis", -1)
         epsilon = attr.get("epsilon", 1e-05)
 
+        gamma_shape = get_const_tuple(scale.struct_info.shape)
+
         if bias is None:
             seq_len = data.struct_info.shape[1].value
             bias = relax.const([0.0] * seq_len, dtype="float32")
+        else:
+            beta_shape = get_const_tuple(bias.struct_info.shape)
+            if gamma_shape != beta_shape:
+                raise ValueError("gamma and beta shapes do not match")
+
+        axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]
+        if len(axis) < len(gamma_shape):
+            axis.extend(range(axis[-1] + 1, axis[-1] + 1 + len(gamma_shape) - len(axis)))
 
         output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
         # Onnx layernorm has 3 outputs but only the first is used.
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4afe7cd0aa..d93d662e36 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1282,18 +1282,20 @@ def test_mean_variance_norm():
 
 
 def test_layer_norm():
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
-            helper.make_tensor_value_info("c", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 
32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1301,17 +1303,19 @@ def test_layer_norm():
     check_correctness(model)
 
     # Test case with no bias that is an optional input
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 
32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1319,6 +1323,48 @@ def test_layer_norm():
     check_correctness(model)
 
 
+def test_layer_norm_with_nd_gamma_beta():
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=1, 
epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 
4, 4]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [3, 4, 
4]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3, 4, 
4]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 
4]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+    # Test case with no bias that is an optional input
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 
32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+
 # TODO Enable dynamism
 @pytest.mark.parametrize("dynamic", [False])
 def test_skiplayernormalization(dynamic):
