gemini-code-assist[bot] commented on code in PR #19590:
URL: https://github.com/apache/tvm/pull/19590#discussion_r3270912813


##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3787,6 +3787,41 @@ def _impl_v17(cls, bb, inputs, attr, params):
         return relax.Tuple([output, placeholder, placeholder])
 
 
+class RMSNormalization(OnnxOpConverter):
+    """Converts an onnx RMSNormalization node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v23(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        scale = inputs[1]
+        axis = attr.get("axis", -1)
+        epsilon = attr.get("epsilon", 1e-05)
+        stash_type = attr.get("stash_type", 1)
+
+        # Determine normalization axes: from `axis` to the last dimension
+        ndim = len(data.struct_info.shape)
+        if axis < 0:
+            axis = ndim + axis
+        axes = list(range(axis, ndim))
+
+        # If stash_type requires float32 computation and input is not float32, cast
+        input_dtype = data.struct_info.dtype
+        if stash_type == 1 and input_dtype != "float32":
+            data_compute = relax.op.astype(data, "float32")
+            scale_compute = relax.op.astype(scale, "float32")
+        else:
+            data_compute = data
+            scale_compute = scale
+
+        output = relax.op.nn.rms_norm(data_compute, scale_compute, axes, epsilon)
+
+        # Cast back to original dtype if needed
+        if stash_type == 1 and input_dtype != "float32":
+            output = relax.op.astype(output, input_dtype)
+
+        return output

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The ONNX `RMSNormalization` operator (opset 23) specifies two outputs: the 
normalized tensor `Y` and an optional stashed inverse standard deviation 
`inv_std_dev`. To ensure compatibility with ONNX models that might request both 
outputs, the converter should return a `relax.Tuple`. Following the pattern 
used in other normalization converters in this frontend, a placeholder can be 
provided for the unused second output.
   
   ```suggestion
           # ONNX RMSNormalization has 2 outputs: Y and inv_std_dev.
           # We return a placeholder for the second output.
           placeholder = relax.const(0, dtype="float32" if stash_type == 1 else input_dtype)
           return relax.Tuple([output, placeholder])
   ```
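For checking the converter's output numerically, a plain-Python reference of the computation handed to `relax.op.nn.rms_norm` (square, mean over the trailing axes starting at `axis`, add `epsilon`, square root, divide, multiply by `scale`) might look like the sketch below. It is an illustration under stated assumptions, not TVM or ONNX code: `rms_norm_reference` is a hypothetical helper, the input is a flat row-major buffer, `scale` has exactly the normalized shape (no broadcasting), everything is computed in Python floats (so `stash_type` casting is not modeled), and only `Y` is produced.

```python
import math
from typing import List, Sequence


def rms_norm_reference(
    data: Sequence[float],
    shape: Sequence[int],
    scale: Sequence[float],
    axis: int = -1,
    epsilon: float = 1e-5,
) -> List[float]:
    """Hypothetical reference: RMS-normalize `data` over axes [axis, ndim).

    `data` is a flat, row-major buffer with the given `shape`. Because the
    normalized axes are the trailing ones, each normalized group is one
    contiguous chunk of prod(shape[axis:]) elements. `scale` is assumed to
    hold exactly that many elements (no broadcasting).
    """
    ndim = len(shape)
    if not -ndim <= axis < ndim:
        raise ValueError(f"axis {axis} is out of range for rank {ndim}")
    if axis < 0:
        axis += ndim
    group = math.prod(shape[axis:])  # elements per normalized group
    if len(data) != math.prod(shape) or len(scale) != group:
        raise ValueError("data or scale does not match the given shape")

    out: List[float] = []
    for start in range(0, len(data), group):
        chunk = data[start:start + group]
        # Mean of squares over the group, then divide by sqrt(mean + epsilon).
        mean_sq = sum(x * x for x in chunk) / group
        rms = math.sqrt(mean_sq + epsilon)
        # Apply the per-element scale after normalizing.
        out.extend(x / rms * s for x, s in zip(chunk, scale))
    return out


# Usage: a 2x2 input normalized over the last axis, epsilon disabled.
y = rms_norm_reference([1.0, 1.0, 3.0, 4.0], shape=(2, 2), scale=[1.0, 2.0], epsilon=0.0)
print(y)
```

Since the normalized axes always run from `axis` to the last dimension, they are contiguous in row-major memory, which is why the sketch can walk the buffer in fixed-size chunks instead of performing a general multi-axis reduction.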



##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -3787,6 +3787,41 @@ def _impl_v17(cls, bb, inputs, attr, params):
         return relax.Tuple([output, placeholder, placeholder])
 
 
+class RMSNormalization(OnnxOpConverter):
+    """Converts an onnx RMSNormalization node into an equivalent Relax expression."""
+
+    @classmethod
+    def _impl_v23(cls, bb, inputs, attr, params):
+        data = inputs[0]
+        scale = inputs[1]
+        axis = attr.get("axis", -1)
+        epsilon = attr.get("epsilon", 1e-05)
+        stash_type = attr.get("stash_type", 1)
+
+        # Determine normalization axes: from `axis` to the last dimension
+        ndim = len(data.struct_info.shape)
+        if axis < 0:
+            axis = ndim + axis
+        axes = list(range(axis, ndim))

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Accessing `data.struct_info.shape` directly fails if the input rank is
unknown (i.e., `ndim == -1`): `shape` is then `None`, so `len()` raises a
`TypeError`. It is recommended to use `_get_known_tensor_rank` to retrieve the
rank safely and to raise a descriptive error if it is unavailable.
Additionally, `_normalize_constant_axes` is a more robust way to handle
negative axis values and perform bounds checking.
   
   ```suggestion
           ndim = _get_known_tensor_rank(data)
           if ndim is None:
               raise ValueError("RMSNormalization requires a statically known input rank.")
           axis = _normalize_constant_axes([axis], ndim, "RMSNormalization")[0]
           axes = list(range(axis, ndim))
   ```
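The axis handling requested above can be sketched without TVM. `trailing_norm_axes` is a hypothetical stand-in that only illustrates the intended behavior (fail early on an unknown rank, map a negative `axis` into range, reject out-of-range values, return the trailing axes); it is not the signature or implementation of `_get_known_tensor_rank` or `_normalize_constant_axes`.

```python
from typing import List, Optional


def trailing_norm_axes(axis: int, ndim: Optional[int], op_name: str = "RMSNormalization") -> List[int]:
    """Hypothetical helper: return the normalized axes [axis, ndim).

    Rejects an unknown rank (None, or -1 as Relax reports it), maps a
    negative `axis` into [0, ndim), and bounds-checks the result.
    """
    if ndim is None or ndim < 0:
        raise ValueError(f"{op_name} requires a statically known input rank.")
    # Valid axes for rank ndim are -ndim .. ndim - 1.
    if not -ndim <= axis < ndim:
        raise ValueError(f"{op_name}: axis {axis} is out of range for rank {ndim}.")
    if axis < 0:
        axis += ndim
    return list(range(axis, ndim))


print(trailing_norm_axes(-1, 3))
print(trailing_norm_axes(1, 3))
```

The bounds check matters for the code under review: with `ndim = 3`, `axis = 5` yields `range(5, 3)`, an empty axis list, and `axis = -5` becomes `-2`, so `range(-2, 3)` produces negative axes that alias axes 1 and 2. Neither case raises an error without an explicit check.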



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
