This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d9d129c93 [Relax][ONNX] Fix Cast operator float->int NaN/Inf handling (#19626)
4d9d129c93 is described below

commit 4d9d129c93a0ac93e1c2643b3f35a67b05c0b451
Author: Neo Chien <[email protected]>
AuthorDate: Fri Jun 5 07:51:55 2026 +0800

    [Relax][ONNX] Fix Cast operator float->int NaN/Inf handling (#19626)
    
    Hi Committers,
    
    This PR fixes issue #19542. Any suggestions would be
    appreciated.
    
    ### Root cause:
    The result of lowering a float-to-integer cast is implementation-defined,
    or undefined behavior (UB), for NaN, ±Inf, and out-of-range inputs, so it
    varies by backend and can differ from ONNX Runtime.
    
    ### Solution:
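    Apply a minimal, deterministic frontend sanitization for float-to-integer
    Casts: map NaN and ±Inf to 0.0 before astype. This keeps NaN/Inf from
    reaching the backend fptosi/fptoui lowerings and yields stable behavior
    across targets. For integer targets narrower than 64 bits, the sanitized
    value is first cast to a wider intermediate integer (int32, or int64 for
    32-bit targets) and then wrapped modulo 2^bits, e.g. 300.0 -> 44 for int8;
    64-bit targets are cast directly.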
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 57 +++++++++++++++++++++++++
 tests/python/relax/test_frontend_onnx.py        | 31 ++++++++++++++
 2 files changed, 88 insertions(+)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index b82fceff1d..3a2a0fdaf2 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1105,6 +1105,63 @@ class Cast(OnnxOpConverter):
             return relax.const(output, to_type)
         if isinstance(inputs[0], relax.PrimValue):
             return relax.PrimValue(inputs[0].value.astype(to_type))
+
+        try:
+            np_dst = _np.dtype(str(to_type))
+        except Exception:
+            return relax.op.astype(inputs[0], to_type)
+
+        if np_dst.kind in ("i", "u"):
+            src = inputs[0]
+            src_dtype = getattr(getattr(src, "struct_info", None), "dtype", None) or getattr(
+                src, "dtype", None
+            )
+            if src_dtype is not None and _relax_dtype_is_floating_point(src_dtype):
+                x_sanitized = bb.emit(
+                    relax.op.where(
+                        relax.op.logical_not(relax.op.isfinite(src)),
+                        relax.const(0.0, src_dtype),
+                        src,
+                    )
+                )
+                dst_str = str(to_type)
+                if dst_str.startswith("uint"):
+                    signed = False
+                    bits = int(dst_str[4:])
+                elif dst_str.startswith("int"):
+                    signed = True
+                    bits = int(dst_str[3:])
+                else:
+                    return relax.op.astype(x_sanitized, to_type)
+
+                if bits == 64:
+                    return relax.op.astype(x_sanitized, to_type)
+
+                temp_dtype = "int64" if bits >= 32 else "int32"
+                t = relax.op.astype(x_sanitized, temp_dtype)
+                if bits == 32:
+                    two_pow = relax.const(1 << bits, temp_dtype)
+                    uw = relax.op.floor_mod(t, two_pow)
+                else:
+                    mask_val = (1 << bits) - 1
+                    mask = relax.const(mask_val, temp_dtype)
+                    uw = relax.op.bitwise_and(t, mask)
+                if signed:
+                    half = 1 << (bits - 1)
+                    half_c = relax.const(half, temp_dtype)
+                    if bits == 32:
+                        two_pow = relax.const(1 << bits, temp_dtype)
+                    else:
+                        two_pow = relax.op.add(mask, relax.const(1, temp_dtype))
+                    wrapped = relax.op.where(
+                        relax.op.greater_equal(uw, half_c),
+                        relax.op.subtract(uw, two_pow),
+                        uw,
+                    )
+                else:
+                    wrapped = uw
+                return relax.op.astype(wrapped, to_type)
+
         return relax.op.astype(inputs[0], to_type)
 
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 7ee10993a4..9a644c4a3a 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -863,6 +863,37 @@ def test_cast(from_type, to_type):
     check_correctness(model, opset=13)
 
 
+@pytest.mark.parametrize("to_type", [TensorProto.INT64, TensorProto.UINT64])
+def test_cast_float_to_64bit_int_dynamic(to_type):
+    cast_node = helper.make_node("Cast", ["a"], ["b"], to=to_type)
+    graph = helper.make_graph(
+        [cast_node],
+        "cast_float_to_64bit_int_dynamic_test",
+        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, [1, 8])],
+        outputs=[helper.make_tensor_value_info("b", to_type, [1, 8])],
+    )
+    model = helper.make_model(graph, producer_name="cast_float_to_64bit_int_dynamic_test")
+    inputs = {"a": np.array([[0.0, 1.2, 2.8, 7.9, 15.1, 31.7, 63.4, 127.9]], dtype=np.float32)}
+    check_correctness(model, inputs=inputs, opset=13, check_dtypes=True)
+
+
+def test_cast_nan_inf_to_int8():
+    vals = np.array([300.0, np.nan, np.inf, -np.inf, 50.0, -50.0], dtype=np.float32)
+    node = helper.make_node("Cast", inputs=["a"], outputs=["b"], to=TensorProto.INT8)
+    graph = helper.make_graph(
+        [node],
+        "cast_nan_inf_test",
+        inputs=[helper.make_tensor_value_info("a", TensorProto.FLOAT, list(vals.shape))],
+        outputs=[helper.make_tensor_value_info("b", TensorProto.INT8, list(vals.shape))],
+    )
+    model = helper.make_model(graph, producer_name="cast_nan_inf_test")
+    tvm_output = run_in_tvm(model, inputs={"a": vals}, opset=13)
+    out_np = tvm_output.numpy()
+    expected = np.array([44, 0, 0, 0, 50, -50], dtype=np.int8)
+    assert out_np.dtype == np.int8
+    np.testing.assert_array_equal(out_np, expected)
+
+
 def test_gather():
     def _verify_gather(data_shape, indices, out_shape, axis=0):
         gather_node = helper.make_node("Gather", ["data", "indices"], ["y"], axis=axis)

Reply via email to