gemini-code-assist[bot] commented on code in PR #19590:
URL: https://github.com/apache/tvm/pull/19590#discussion_r3270912813
##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3787,6 +3787,41 @@ def _impl_v17(cls, bb, inputs, attr, params):
         return relax.Tuple([output, placeholder, placeholder])
+class RMSNormalization(OnnxOpConverter):
+    """Converts an onnx RMSNormalization node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v23(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        scale = inputs[1]
+        axis = attr.get("axis", -1)
+        epsilon = attr.get("epsilon", 1e-05)
+        stash_type = attr.get("stash_type", 1)
+
+        # Determine normalization axes: from `axis` to the last dimension
+        ndim = len(data.struct_info.shape)
+        if axis < 0:
+            axis = ndim + axis
+        axes = list(range(axis, ndim))
+
+        # If stash_type requires float32 computation and input is not float32, cast
+        input_dtype = data.struct_info.dtype
+        if stash_type == 1 and input_dtype != "float32":
+            data_compute = relax.op.astype(data, "float32")
+            scale_compute = relax.op.astype(scale, "float32")
+        else:
+            data_compute = data
+            scale_compute = scale
+
+        output = relax.op.nn.rms_norm(data_compute, scale_compute, axes, epsilon)
+
+        # Cast back to original dtype if needed
+        if stash_type == 1 and input_dtype != "float32":
+            output = relax.op.astype(output, input_dtype)
+
+        return output
Review Comment:

The ONNX `RMSNormalization` operator (opset 23) specifies two outputs: the
normalized tensor `Y` and an optional stashed inverse standard deviation
`inv_std_dev`. To ensure compatibility with ONNX models that might request both
outputs, the converter should return a `relax.Tuple`. Following the pattern
used in other normalization converters in this frontend, a placeholder can be
provided for the unused second output.
```suggestion
        # ONNX RMSNormalization has 2 outputs: Y and inv_std_dev.
        # We return a placeholder for the second output.
        placeholder = relax.const(0, dtype="float32" if stash_type == 1 else input_dtype)
        return relax.Tuple([output, placeholder])
```
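For checking the lowering against `relax.op.nn.rms_norm`, the value of `Y` the
converter must reproduce is `Y = X / sqrt(mean(X**2 over axes) + epsilon) * scale`,
reduced over axes `[axis, ndim)`. The following is a minimal pure-Python reference
sketch, assuming a row-major flat buffer and a `scale` whose size equals the
normalized trailing dimensions (no broadcasting). `rms_norm_reference` is a
hypothetical helper, not part of TVM or ONNX.

```python
import math
from typing import List, Sequence


def rms_norm_reference(
    x: Sequence[float],
    shape: Sequence[int],
    scale: Sequence[float],
    axis: int = -1,
    epsilon: float = 1e-5,
) -> List[float]:
    """Hypothetical reference RMSNorm over axes [axis, ndim) of a row-major flat tensor.

    Assumption: `scale` has exactly as many elements as the normalized
    trailing dimensions; broadcasting is not handled in this sketch.
    """
    ndim = len(shape)
    if axis < 0:
        axis += ndim
    if not 0 <= axis < ndim:
        raise ValueError(f"axis {axis} out of range for rank {ndim}")
    # Each normalized group holds prod(shape[axis:]) contiguous elements.
    group = math.prod(shape[axis:])
    if len(scale) != group:
        raise ValueError("scale must match the normalized trailing shape")
    out: List[float] = []
    for start in range(0, len(x), group):
        chunk = x[start:start + group]
        # Root mean square of the group, stabilized by epsilon.
        mean_sq = sum(v * v for v in chunk) / group
        inv_rms = 1.0 / math.sqrt(mean_sq + epsilon)
        out.extend(v * inv_rms * s for v, s in zip(chunk, scale))
    return out
```

The converter under review first casts to float32 when `stash_type == 1`; this
sketch computes in Python floats (double precision) throughout, so it only
approximates the float16/bfloat16 paths.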
##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3787,6 +3787,41 @@ def _impl_v17(cls, bb, inputs, attr, params):
         return relax.Tuple([output, placeholder, placeholder])
+class RMSNormalization(OnnxOpConverter):
+    """Converts an onnx RMSNormalization node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v23(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        scale = inputs[1]
+        axis = attr.get("axis", -1)
+        epsilon = attr.get("epsilon", 1e-05)
+        stash_type = attr.get("stash_type", 1)
+
+        # Determine normalization axes: from `axis` to the last dimension
+        ndim = len(data.struct_info.shape)
+        if axis < 0:
+            axis = ndim + axis
+        axes = list(range(axis, ndim))
Review Comment:

Accessing `data.struct_info.shape` directly is unsafe when the input shape is
not statically known: `shape` is then `None` (for example, when the rank is
unknown, i.e., `ndim == -1`), and `len(None)` raises a `TypeError`. It is
recommended to use `_get_known_tensor_rank` to safely retrieve the rank and to
raise a descriptive error if it is unavailable. Additionally,
`_normalize_constant_axes` is a more robust way to handle negative axis values
and to perform bounds checking.
```suggestion
        ndim = _get_known_tensor_rank(data)
        if ndim is None:
            raise ValueError("RMSNormalization requires a statically known input rank.")
        axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0]
        axes = list(range(axis, ndim))
```
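To illustrate the checks this suggestion adds, here is a standalone sketch of the
rank guard plus axis normalization. `rms_norm_axes` is a hypothetical helper that
mirrors the intent of `_get_known_tensor_rank` and `_normalize_constant_axes`; it
is not their actual implementation, and it assumes an unknown rank arrives as
`None` or as Relax's `ndim == -1`.

```python
from typing import List, Optional


def rms_norm_axes(axis: int, ndim: Optional[int]) -> List[int]:
    """Hypothetical helper: return the reduction axes [axis, ndim) for RMSNormalization."""
    if ndim is None or ndim < 0:
        # Assumption: an unknown rank is reported as None or as ndim == -1.
        raise ValueError("RMSNormalization requires a statically known input rank.")
    # Bounds check before normalizing, so that out-of-range values are rejected
    # instead of silently producing an empty or wrapped axis list.
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for rank {ndim}")
    if axis < 0:
        axis += ndim
    return list(range(axis, ndim))
```

The explicit range check matters because the original `if axis < 0: axis = ndim + axis`
accepts `axis = 3` on a rank-3 input and yields an empty `axes` list, and turns
`axis = -5` into `-2`, producing the wrapped list `[-2, -1, 0, 1, 2]`.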
--
This is an automated message from the Apache Git Service.