This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a1e4cd82fe [Relay/ONNX] Add RMSNormalization converter for ONNX opset 23 (#19590)
a1e4cd82fe is described below
commit a1e4cd82fecbf95f711c18b4509316174d73cb40
Author: hh <[email protected]>
AuthorDate: Thu May 21 12:45:55 2026 +0800
[Relay/ONNX] Add RMSNormalization converter for ONNX opset 23 (#19590)
Add support for the ONNX RMSNormalization operator (opset 23) in the
Relax ONNX frontend. This operator is essential for importing LLM models
(LLaMA, Gemma, etc.) that use RMS normalization.
The implementation:
- Maps ONNX RMSNormalization to relax.op.nn.rms_norm
- Supports the axis, epsilon, and stash_type attributes
- Handles float16 inputs with stash_type=1 (compute in float32)
- Includes unit tests comparing against ONNX Runtime
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 37 +++++++++++++++
tests/python/relax/test_frontend_onnx.py | 62 +++++++++++++++++++++++++
2 files changed, 99 insertions(+)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 6624110241..1a224e431b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -3787,6 +3787,42 @@ class LayerNormalization(OnnxOpConverter):
return relax.Tuple([output, placeholder, placeholder])
+class RMSNormalization(OnnxOpConverter):
+    """Converts an onnx RMSNormalization node into an equivalent Relax expression.
+
+    The node is lowered to ``relax.op.nn.rms_norm`` applied over every axis
+    from ``axis`` through the last dimension of the input.
+    """
+
+    @classmethod
+    def _impl_v23(cls, bb, inputs, attr, params):
+        """Build the Relax expression for an opset-23 RMSNormalization node.
+
+        Parameters
+        ----------
+        bb
+            Not used by this converter.
+        inputs
+            ``inputs[0]`` is the tensor to normalize, ``inputs[1]`` is the scale.
+        attr
+            Node attributes. Reads ``axis`` (default -1), ``epsilon``
+            (default 1e-05) and ``stash_type`` (default 1).
+        params
+            Not used by this converter.
+
+        Returns
+        -------
+        The ``rms_norm`` expression. When the computation was moved to
+        float32, the result is cast back to the dtype of ``inputs[0]``.
+
+        Raises
+        ------
+        ValueError
+            If the rank of ``inputs[0]`` is not statically known.
+        """
+        x = inputs[0]
+        weight = inputs[1]
+        first_axis = attr.get("axis", -1)
+        eps = attr.get("epsilon", 1e-05)
+        stash_type = attr.get("stash_type", 1)
+
+        rank = _get_known_tensor_rank(x)
+        if rank is None:
+            raise ValueError("RMSNormalization requires a statically known input rank.")
+        first_axis = _normalize_constant_axes([first_axis], rank, "RMSNormalization")[0]
+        # Normalization covers `axis` and every dimension after it.
+        norm_axes = list(range(first_axis, rank))
+
+        # stash_type == 1 asks for float32 computation; casting is only needed
+        # when the input is stored in some other dtype. Any other stash_type
+        # leaves the computation in the input dtype.
+        orig_dtype = x.struct_info.dtype
+        upcast = stash_type == 1 and orig_dtype != "float32"
+        if not upcast:
+            return relax.op.nn.rms_norm(x, weight, norm_axes, eps)
+
+        result = relax.op.nn.rms_norm(
+            relax.op.astype(x, "float32"),
+            relax.op.astype(weight, "float32"),
+            norm_axes,
+            eps,
+        )
+        # Hand the result back in the dtype the input arrived in.
+        return relax.op.astype(result, orig_dtype)
+
+
class ReduceMax(OnnxOpConverter):
"""Converts an onnx ReduceMax node into an equivalent Relax expression."""
@@ -5129,6 +5165,7 @@ def _get_convert_map():
# Normalization
"BatchNormalization": BatchNormalization,
"LayerNormalization": LayerNormalization,
+ "RMSNormalization": RMSNormalization,
"SkipLayerNormalization": SkipLayerNormalization,
"EmbedLayerNormalization": EmbedLayerNormalization,
"InstanceNormalization": InstanceNormalization,
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index b658a2aaba..2b0194f085 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -2309,6 +2309,68 @@ def test_layer_norm_with_nd_gamma_beta():
check_correctness(model)
+def test_rms_norm():
+    """Run check_correctness on three single-node RMSNormalization models.
+
+    Each case builds a graph with inputs ``input`` and ``scale`` and one
+    output ``Y`` that has the same element type and shape as ``input``.
+    """
+    # (graph/producer name, element type, input shape, scale shape,
+    #  node attributes, extra check_correctness keyword arguments)
+    cases = [
+        # Basic case: `axis` left at its default.
+        (
+            "rms_norm_test",
+            TensorProto.FLOAT,
+            [2, 8, 32],
+            [32],
+            {"epsilon": 1e-05},
+            {},
+        ),
+        # Explicit axis=1: normalize over the last two dimensions.
+        (
+            "rms_norm_axis_test",
+            TensorProto.FLOAT,
+            [4, 8, 16],
+            [8, 16],
+            {"axis": 1, "epsilon": 1e-06},
+            {},
+        ),
+        # float16 tensors with stash_type=1 (compute in float32); looser tolerances.
+        (
+            "rms_norm_fp16_test",
+            TensorProto.FLOAT16,
+            [2, 8, 32],
+            [32],
+            {"epsilon": 1e-05, "stash_type": 1},
+            {"rtol": 1e-2, "atol": 1e-2},
+        ),
+    ]
+
+    for name, elem_type, data_shape, scale_shape, node_attrs, tolerances in cases:
+        node = helper.make_node("RMSNormalization", ["input", "scale"], ["Y"], **node_attrs)
+        graph = helper.make_graph(
+            [node],
+            name,
+            inputs=[
+                helper.make_tensor_value_info("input", elem_type, data_shape),
+                helper.make_tensor_value_info("scale", elem_type, scale_shape),
+            ],
+            outputs=[
+                helper.make_tensor_value_info("Y", elem_type, data_shape),
+            ],
+        )
+        model = helper.make_model(graph, producer_name=name)
+        check_correctness(model, opset=23, **tolerances)
+
+
# TODO Enable dynamism
@pytest.mark.parametrize("dynamic", [False])
def test_skiplayernormalization(dynamic):