This is an automated email from the ASF dual-hosted git repository.

mshr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 359e05e1bb [FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)
359e05e1bb is described below

commit 359e05e1bb45ae235fddf3fbe314ad2807f9202d
Author: ConvolutedDog <yangjiancha...@nudt.edu.cn>
AuthorDate: Mon Jul 14 14:45:21 2025 +0800

    [FRONTEND][ONNX] Extend axes for layer_norm when gamma/beta are multi-dimensional (#18143)

    * Extend axes for layer_norm when gamma/beta are multi-dimensional

    * Add testcase for layer_norm when gamma/beta are multi-dimensional

    * fix lint

    * fix lint
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 11 +++++
 tests/python/relax/test_frontend_onnx.py        | 64 +++++++++++++++++++++----
 2 files changed, 66 insertions(+), 9 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 5cf324086e..926da7f022 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -48,6 +48,7 @@ from tvm import TVMError, relax, tir, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
 from tvm.tir.generic import cast
+from tvm.topi.utils import get_const_tuple
 
 from ..common import autopad
 
@@ -2500,9 +2501,19 @@ class LayerNormalization(OnnxOpConverter):
         axis = attr.get("axis", -1)
         epsilon = attr.get("epsilon", 1e-05)
 
+        gamma_shape = get_const_tuple(scale.struct_info.shape)
+
         if bias is None:
             seq_len = data.struct_info.shape[1].value
             bias = relax.const([0.0] * seq_len, dtype="float32")
+        else:
+            beta_shape = get_const_tuple(bias.struct_info.shape)
+            if gamma_shape != beta_shape:
+                raise ValueError("gamma and beta shapes do not match")
+
+        axis = list(axis) if isinstance(axis, (list, tuple)) else [axis]
+        if len(axis) < len(gamma_shape):
+            axis.extend(range(axis[-1] + 1, axis[-1] + 1 + len(gamma_shape) - len(axis)))
 
         output = relax.op.nn.layer_norm(data, scale, bias, axis, epsilon)
         # Onnx layernorm has 3 outputs but only the first is used.
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 4afe7cd0aa..d93d662e36 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -1282,18 +1282,20 @@ def test_mean_variance_norm():
 
 
 def test_layer_norm():
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b", "c"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
-            helper.make_tensor_value_info("c", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1301,17 +1303,19 @@ def test_layer_norm():
     check_correctness(model)
 
     # Test case with no bias that is an optional input
-    layer_norm_node = helper.make_node("LayerNormalization", ["a", "b"], ["d"], epsilon=1e-12)
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], epsilon=1e-12
+    )
 
     graph = helper.make_graph(
         [layer_norm_node],
         "layer_norm_test",
         inputs=[
-            helper.make_tensor_value_info("a", TensorProto.FLOAT, [32, 32]),
-            helper.make_tensor_value_info("b", TensorProto.FLOAT, [32]),
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
         ],
         outputs=[
-            helper.make_tensor_value_info("d", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
         ],
     )
 
@@ -1319,6 +1323,48 @@ def test_layer_norm():
     check_correctness(model)
 
 
+def test_layer_norm_with_nd_gamma_beta():
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale", "bias"], ["Y"], axis=1, epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [1, 3, 4, 4]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [3, 4, 4]),
+            helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3, 4, 4]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [1, 3, 4, 4]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+    # Test case with no bias that is an optional input
+    layer_norm_node = helper.make_node(
+        "LayerNormalization", ["input", "scale"], ["Y"], axis=1, epsilon=1e-12
+    )
+
+    graph = helper.make_graph(
+        [layer_norm_node],
+        "layer_norm_with_nd_gamma_beta_test",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [32, 32]),
+            helper.make_tensor_value_info("scale", TensorProto.FLOAT, [32]),
+        ],
+        outputs=[
+            helper.make_tensor_value_info("Y", TensorProto.FLOAT, [32, 32]),
+        ],
+    )
+
+    model = helper.make_model(graph, producer_name="layer_norm_with_nd_gamma_beta_test")
+    check_correctness(model)
+
+
 # TODO Enable dynamism
 @pytest.mark.parametrize("dynamic", [False])
 def test_skiplayernormalization(dynamic):
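
For readers skimming the change: the snippet below is not part of the patch. It is a minimal standalone sketch of the axis-extension behaviour added to the LayerNormalization converter above, using the shapes from the new test; the helper name `extend_axes` is purely illustrative.

```python
# Illustrative only: mirrors the axis-extension logic added to the
# LayerNormalization converter in this commit (not the actual TVM code path).
def extend_axes(axis, gamma_shape):
    # Normalize the ONNX `axis` attribute to a list.
    axes = list(axis) if isinstance(axis, (list, tuple)) else [axis]
    # When gamma/beta are multi-dimensional, normalize over one axis per
    # gamma dimension by appending the axes that follow the given one.
    if len(axes) < len(gamma_shape):
        axes.extend(range(axes[-1] + 1, axes[-1] + 1 + len(gamma_shape) - len(axes)))
    return axes

# Shapes from test_layer_norm_with_nd_gamma_beta:
# input [1, 3, 4, 4], scale/bias [3, 4, 4], axis=1.
print(extend_axes(1, (3, 4, 4)))   # -> [1, 2, 3]
# 1-D gamma (existing behaviour) is left unchanged:
print(extend_axes(-1, (32,)))      # -> [-1]
```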