This is an automated email from the ASF dual-hosted git repository.
mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 359e05e1bb [FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)
359e05e1bb is described below
commit 359e05e1bb45ae235fddf3fbe314ad2807f9202d
Author: ConvolutedDog <[email protected]>
AuthorDate: Mon Jul 14 14:45:21 2025 +0800
[FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)
* Extend axes for layer_norm when gamma/beta are multi-dimensional
* Add testcase for layer_norm when gamma/beta are multi-dimensional
* fix lint
* fix lint
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 11 +++++
tests/python/relax/test_frontend_onnx.py | 64 +++++++++++++++++++++----
2 files changed, 66 insertions(+), 9 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5cf324086e..926da7f022 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -48,6 +48,7 @@ from tvm import TVMError, relax, tir, topi
from tvm.ir import IRModule
from tvm.ir.supply import NameSupply
from tvm.tir.generic import cast
+from tvm.topi.utils import get_const_tuple
from ..common import autopad
@@ -2500,9 +2501,19 @@ class LayerNormalization(OnnxOpConverter):
axis = attr.get("axis", -1)
epsilon = attr.get("epsilon", 1e-05)
+ gamma_shape = get_const_tuple(scale.struct_info.shape)
+
if bias is None:
seq_len = data.struct_info.shape[1].value
bias = relax.const([0.0] * seq_len, dtype="float32")
+ else:
+ beta_shape = get_const_tuple(bias.struct_info.shape)
+ if gamma_shape != beta_shape:
+ raise ValueError("gamma and beta shapes do not match")
+
+ axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]
+ if len(axis) < len(gamma_shape):
+ axis.extend(range(axis[-1] + 1, axis[-1] + 1 + len(gamma_shape) - len(axis)))
output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
# Onnx layernorm has 3 outputs but only the first is used.
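[Editor's note] The core of this hunk is the axis handling: ONNX's LayerNormalization normalizes over dims [axis, ..., rank - 1], and gamma/beta span exactly those dims, so a scalar axis attribute plus a multi-dimensional gamma implies a contiguous run of trailing axes. A minimal standalone sketch of that arithmetic (the helper name extend_layer_norm_axes is illustrative, not part of the patch):

    def extend_layer_norm_axes(axis, gamma_shape):
        # Normalize the attribute to a list, mirroring the converter.
        axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
        # gamma spans dims [axis, ..., rank - 1], so append the missing
        # trailing axes until there is one axis per gamma dimension.
        if len(axes) < len(gamma_shape):
            axes.extend(range(axes[-1] + 1, axes[-1] + 1 + len(gamma_shape) - len(axes)))
        return axes

    # e.g. the shapes used in the new test: axis=1 with gamma of shape
    # (3, 4, 4) yields the full normalization axes [1, 2, 3]
    assert extend_layer_norm_axes(1, (3, 4, 4)) == [1, 2, 3]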
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4afe7cd0aa..d93d662e36 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1282,18 +1282,20 @@ def test_mean_variance_norm():
def test_layer_norm():
- layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12)
+ layer_norm_node = helper.make_node(
+ "LayerNormalization", ["input", "scale", "bias"], ["Y"], epsilon=1e-12
+ )
graph = helper.make_graph(
[layer_norm_node],
"layer_norm_test",
inputs=[
- helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
- helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
- helper.make_tensor_value_info("c", TensorProto.FLOAT, [32]),
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, [32,
32]),
+ helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
],
outputs=[
- helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
],
)
@@ -1301,17 +1303,19 @@ def test_layer_norm():
check_correctness(model)
# Test case with no bias that is an optional input
- layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12)
+ layer_norm_node = helper.make_node(
+ "LayerNormalization", ["input", "scale"], ["Y"], epsilon=1e-12
+ )
graph = helper.make_graph(
[layer_norm_node],
"layer_norm_test",
inputs=[
- helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
- helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, [32,
32]),
+ helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
],
outputs=[
- helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
],
)
@@ -1319,6 +1323,48 @@ def test_layer_norm():
check_correctness(model)
+def test_layer_norm_with_nd_gamma_beta():
+ layer_norm_node = helper.make_node(
+ "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=1,
epsilon=1e-12
+ )
+
+ graph = helper.make_graph(
+ [layer_norm_node],
+ "layer_norm_with_nd_gamma_beta_test",
+ inputs=[
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3,
4, 4]),
+ helper.make_tensor_value_info("scale", TensorProto.FLOAT, [3, 4,
4]),
+ helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3, 4,
4]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4,
4]),
+ ],
+ )
+
+ model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+ check_correctness(model)
+
+ # Test case with no bias that is an optional input
+ layer_norm_node = helper.make_node(
+ "LayerNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-12
+ )
+
+ graph = helper.make_graph(
+ [layer_norm_node],
+ "layer_norm_with_nd_gamma_beta_test",
+ inputs=[
+ helper.make_tensor_value_info("input", TensorProto.FLOAT, [32,
32]),
+ helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+ ],
+ outputs=[
+ helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
+ ],
+ )
+
+ model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+ check_correctness(model)
+
+
# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_skiplayernormalization(dynamic):
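[Editor's note] The new test added above normalizes a (1, 3, 4, 4) input over dims 1..3 with (3, 4, 4) gamma/beta. As a sanity check on what the converter now computes, here is a NumPy sketch of the reference computation, assuming check_correctness compares against standard ONNX LayerNormalization semantics (layer_norm_ref is an illustrative name, not part of the test suite):

    import numpy as np

    def layer_norm_ref(x, gamma, beta, axis=1, eps=1e-12):
        # ONNX LayerNormalization reduces over dims [axis, ..., rank - 1];
        # gamma and beta broadcast over the leading, unreduced dims.
        dims = tuple(range(axis, x.ndim))
        mean = x.mean(axis=dims, keepdims=True)
        var = x.var(axis=dims, keepdims=True)
        return (x - mean) / np.sqrt(var + eps) * gamma + beta

    x = np.random.rand(1, 3, 4, 4).astype("float32")
    gamma = np.random.rand(3, 4, 4).astype("float32")
    beta = np.random.rand(3, 4, 4).astype("float32")
    y = layer_norm_ref(x, gamma, beta)  # same shape as x: (1, 3, 4, 4)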