gemini-code-assist[bot] commented on code in PR #19535:
URL: https://github.com/apache/tvm/pull/19535#discussion_r3218267649


##########
python/tvm/relax/frontend/onnx/onnx_frontend.py:
##########
@@ -52,12 +52,26 @@
 from tvm import TVMError, relax, tirx, topi
 from tvm.ir import IRModule
 from tvm.ir.supply import NameSupply
+from tvm.runtime import DataType, DataTypeCode
 from tvm.tirx.generic import cast
 from tvm.topi.utils import get_const_tuple
 
 from ..common import autopad
 
 
+def _relax_dtype_is_floating_point(dtype: str) -> bool:
+    """Whether a Relax dtype string is a floating point type."""
+    try:
+        code = DataType(dtype).type_code
+    except (ValueError, TypeError, TVMError):
+        return False
+    return (
+        code == DataTypeCode.FLOAT
+        or code == DataTypeCode.BFLOAT
+        or (code >= DataTypeCode.Float8E3M4 and code <= DataTypeCode.Float4E2M1FN)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The casing for `DataTypeCode` members looks inconsistent. `FLOAT` and
`BFLOAT` are in all caps, while the newer FP8/FP4 types such as `Float8E3M4`
and `Float4E2M1FN` are written in CamelCase. Please confirm these names against
the `DataTypeCode` definition in the TVM version this code targets. If the
members are defined in all caps (e.g., `FLOAT8E4M3FN`), the CamelCase spelling
will raise an `AttributeError` at runtime and the suggestion below applies. If
they are defined in CamelCase (e.g., `Float8E4M3FN`), the code as written is
correct.
   
   ```python
       return (
           code == DataTypeCode.FLOAT
           or code == DataTypeCode.BFLOAT
           or (code >= DataTypeCode.FLOAT8E3M4 and code <= DataTypeCode.FLOAT4E2M1FN)
       )
   ```



##########
tests/python/relax/test_frontend_onnx.py:
##########
@@ -1435,6 +1453,50 @@ def test_clip_v6(max, min):
     check_correctness(model, opset=10)
 
 
[email protected](
+    "min,max",
+    [
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(6.0, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(np.nan, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+    ],
+)
+def test_clip_v13(min, max):
+    # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved.
+    clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
+    graph = helper.make_graph(
+        [clip_node],
+        "clip_v13_nan_max",
+        inputs=[
+            helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]),
+            helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
+            helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
+        ],
+        outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])],
+    )
+    model = helper.make_model(graph, producer_name="clip_v13_nan_max")
+    input = np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32)
+    check_correctness(
+        model,
+        inputs={"input": input, "min": min, "max": max},
+        opset=13,
+        equal_nan=True,
+    )

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The PR description mentions that it fixes the preservation of NaNs from the 
input tensor, but the added test case `test_clip_v13` only includes NaN values 
in the bounds (`min`/`max`). It would be beneficial to include at least one 
`np.nan` in the `input` array to verify that input NaNs are indeed preserved 
through the operator.
   
   ```python
   def test_clip_v13(min, max):
       # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved.
       clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"])
       graph = helper.make_graph(
           [clip_node],
           "clip_v13_nan_max",
           inputs=[
               helper.make_tensor_value_info("input", TensorProto.FLOAT, [6]),
               helper.make_tensor_value_info("min", TensorProto.FLOAT, []),
               helper.make_tensor_value_info("max", TensorProto.FLOAT, []),
           ],
           outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [6])],
       )
       model = helper.make_model(graph, producer_name="clip_v13_nan_max")
       input = np.array([0.5, -3.0, 4.5, 11.0, 7.0, np.nan], dtype=np.float32)
       check_correctness(
           model,
           inputs={"input": input, "min": min, "max": max},
           opset=13,
           equal_nan=True,
       )
   ```



##########
tests/python/relax/test_frontend_onnx.py:
##########
@@ -177,16 +180,31 @@ def _check_output(tvm_out, ort_out):
         elif isinstance(tvm_out, tvm.runtime.Tensor) and isinstance(ort_out, np.ndarray):
             if check_dtypes:
                 assert tvm_out.numpy().dtype == ort_out.dtype
-            tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
+            if equal_nan:
+                np.testing.assert_allclose(
+                    tvm_out.numpy(), ort_out, rtol=rtol, atol=atol, equal_nan=True
+                )
+            else:
+                tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, tvm_ffi.Shape) and isinstance(ort_out, np.ndarray):
             shape_out = tvm.runtime.tensor([int(i) for i in tvm_out])
             if check_dtypes:
                 assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
-            tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
+            if equal_nan:
+                np.testing.assert_allclose(
+                    shape_out.numpy(), ort_out, rtol=rtol, atol=atol, equal_nan=True
+                )
+            else:
+                tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, int | float | bool) and isinstance(ort_out, np.ndarray):
             if check_dtypes:
                 assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
-            tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
+            if equal_nan:
+                np.testing.assert_allclose(
+                    np.array(tvm_out), ort_out, rtol=rtol, atol=atol, equal_nan=True
+                )
+            else:
+                tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
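   The logic for handling `equal_nan` comparisons is duplicated across three
different type-check branches in `_check_output`. To remove the repetition and
improve maintainability, each branch could extract only the `actual` value, and
a single comparison could then run at the end of the function. It may also be
worth checking whether the separate `equal_nan` branch is needed at all.
`np.testing.assert_allclose` already defaults to `equal_nan=True`. If
`tvm.testing.assert_allclose` forwards to it without overriding that argument,
NaNs in matching positions would already compare equal.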



##########
tests/python/relax/test_frontend_onnx.py:
##########
@@ -1435,6 +1453,50 @@ def test_clip_v6(max, min):
     check_correctness(model, opset=10)
 
 
[email protected](
+    "min,max",
+    [
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),
+        pytest.param(
+            np.array(0.0, dtype=np.float32),
+            np.array(np.nan, dtype=np.float32),
+        ),

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   These two `pytest.param` entries are identical, so the same case runs twice.
Remove the duplicate, or replace it with the distinct case it was presumably
meant to cover, to keep the test suite concise.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to