gemini-code-assist[bot] commented on code in PR #19535: URL: https://github.com/apache/tvm/pull/19535#discussion_r3218267649
########## python/tvm/relax/frontend/onnx/onnx_frontend.py: ########## @@ -52,12 +52,26 @@ from tvm import TVMError, relax, tirx, topi from tvm.ir import IRModule from tvm.ir.supply import NameSupply +from tvm.runtime import DataType, DataTypeCode from tvm.tirx.generic import cast from tvm.topi.utils import get_const_tuple from ..common import autopad +def _relax_dtype_is_floating_point(dtype: str) -> bool: + """Whether a Relax dtype string is a floating point type.""" + try: + code = DataType(dtype).type_code + except (ValueError, TypeError, TVMError): + return False + return ( + code == DataTypeCode.FLOAT + or code == DataTypeCode.BFLOAT + or (code >= DataTypeCode.Float8E3M4 and code <= DataTypeCode.Float4E2M1FN) Review Comment:  The casing for `DataTypeCode` members is inconsistent. While `FLOAT` and `BFLOAT` are correctly in all caps (following the `IntEnum` convention in TVM), the newer FP8/FP4 types like `Float8E3M4` and `Float4E2M1FN` are written in CamelCase. In the TVM Python API, these are typically defined as all-caps members (e.g., `FLOAT8E4M3FN`). Using the wrong casing will result in an `AttributeError` at runtime. ```python return ( code == DataTypeCode.FLOAT or code == DataTypeCode.BFLOAT or (code >= DataTypeCode.FLOAT8E3M4 and code <= DataTypeCode.FLOAT4E2M1FN) ) ``` ########## tests/python/relax/test_frontend_onnx.py: ########## @@ -1435,6 +1453,50 @@ def test_clip_v6(max, min): check_correctness(model, opset=10) +@pytest.mark.parametrize( + "min,max", + [ + pytest.param( + np.array(0.0, dtype=np.float32), + np.array(np.nan, dtype=np.float32), + ), + pytest.param( + np.array(0.0, dtype=np.float32), + np.array(np.nan, dtype=np.float32), + ), + pytest.param( + np.array(np.nan, dtype=np.float32), + np.array(6.0, dtype=np.float32), + ), + pytest.param( + np.array(np.nan, dtype=np.float32), + np.array(np.nan, dtype=np.float32), + ), + ], +) +def test_clip_v13(min, max): + # Opset 13: tensor min/max.
NaN bound => unbounded on that side (ORT); input NaN preserved. + clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"]) + graph = helper.make_graph( + [clip_node], + "clip_v13_nan_max", + inputs=[ + helper.make_tensor_value_info("input", TensorProto.FLOAT, [5]), + helper.make_tensor_value_info("min", TensorProto.FLOAT, []), + helper.make_tensor_value_info("max", TensorProto.FLOAT, []), + ], + outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])], + ) + model = helper.make_model(graph, producer_name="clip_v13_nan_max") + input = np.array([0.5, -3.0, 4.5, 11.0, 7.0], dtype=np.float32) + check_correctness( + model, + inputs={"input": input, "min": min, "max": max}, + opset=13, + equal_nan=True, + ) Review Comment:  The PR description mentions that it fixes the preservation of NaNs from the input tensor, but the added test case `test_clip_v13` only includes NaN values in the bounds (`min`/`max`). It would be beneficial to include at least one `np.nan` in the `input` array to verify that input NaNs are indeed preserved through the operator. ```python def test_clip_v13(min, max): # Opset 13: tensor min/max. NaN bound => unbounded on that side (ORT); input NaN preserved. 
clip_node = helper.make_node("Clip", ["input", "min", "max"], ["output"]) graph = helper.make_graph( [clip_node], "clip_v13_nan_max", inputs=[ helper.make_tensor_value_info("input", TensorProto.FLOAT, [6]), helper.make_tensor_value_info("min", TensorProto.FLOAT, []), helper.make_tensor_value_info("max", TensorProto.FLOAT, []), ], outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, [6])], ) model = helper.make_model(graph, producer_name="clip_v13_nan_max") input = np.array([0.5, -3.0, 4.5, 11.0, 7.0, np.nan], dtype=np.float32) check_correctness( model, inputs={"input": input, "min": min, "max": max}, opset=13, equal_nan=True, ) ``` ########## tests/python/relax/test_frontend_onnx.py: ########## @@ -177,16 +180,31 @@ def _check_output(tvm_out, ort_out): elif isinstance(tvm_out, tvm.runtime.Tensor) and isinstance(ort_out, np.ndarray): if check_dtypes: assert tvm_out.numpy().dtype == ort_out.dtype - tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) + if equal_nan: + np.testing.assert_allclose( + tvm_out.numpy(), ort_out, rtol=rtol, atol=atol, equal_nan=True + ) + else: + tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol) elif isinstance(tvm_out, tvm_ffi.Shape) and isinstance(ort_out, np.ndarray): shape_out = tvm.runtime.tensor([int(i) for i in tvm_out]) if check_dtypes: assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out) - tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) + if equal_nan: + np.testing.assert_allclose( + shape_out.numpy(), ort_out, rtol=rtol, atol=atol, equal_nan=True + ) + else: + tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol) elif isinstance(tvm_out, int | float | bool) and isinstance(ort_out, np.ndarray): if check_dtypes: assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out) - tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol) + if equal_nan: + 
np.testing.assert_allclose( + np.array(tvm_out), ort_out, rtol=rtol, atol=atol, equal_nan=True + ) + else: + tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol) Review Comment:  The logic for handling `equal_nan` comparisons is duplicated across three different type-check branches in `_check_output`. This could be refactored to avoid repetition and improve maintainability by extracting the `actual` value first and then performing a single comparison check at the end of the function. ########## tests/python/relax/test_frontend_onnx.py: ########## @@ -1435,6 +1453,50 @@ def test_clip_v6(max, min): check_correctness(model, opset=10) +@pytest.mark.parametrize( + "min,max", + [ + pytest.param( + np.array(0.0, dtype=np.float32), + np.array(np.nan, dtype=np.float32), + ), + pytest.param( + np.array(0.0, dtype=np.float32), + np.array(np.nan, dtype=np.float32), + ), Review Comment:  These two `pytest.param` entries are identical. The duplicate should be removed to keep the test suite concise. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
