This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new aa59644072 [Relax][ONNX] Preserve NaN in Sign to align with ONNX Runtime (#19674)
aa59644072 is described below
commit aa596440727bf768604aac0f347333097852759a
Author: Neo Chien <[email protected]>
AuthorDate: Sat Jun 6 19:22:25 2026 +0800
[Relax][ONNX] Preserve NaN in Sign to align with ONNX Runtime (#19674)
Hi Committers,
This PR fixes issue https://github.com/apache/tvm/issues/19543. Any
suggestions would be appreciated when you have time.
### Root cause:
The ONNX frontend `Sign` converter returned `relax.op.sign(x)` directly.
After legalization this maps to `topi.sign`, which is implemented with
comparisons (`x < 0 ? -1 : (x > 0 ? 1 : 0)`). For `NaN`, both comparisons
are false, so TVM produced 0, whereas ONNX Runtime returns NaN. This
caused a semantic mismatch between TVM and ONNX Runtime for imported
ONNX models.
### Solution:
Apply a minimal fix confined to the ONNX frontend (`onnx_frontend.py`):
- For floating-point inputs, lower `Sign` as `where(isnan(x), x, sign(x))`.
- Leave non-floating-point inputs unchanged (`sign(x)`).
---------
Co-authored-by: cchung100m <[email protected]>
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 7 +++++-
tests/python/relax/test_frontend_onnx.py | 29 +++++++++++++++++++++++++
2 files changed, 35 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 3a2a0fdaf2..0e3ccef08c 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -4588,7 +4588,12 @@ class Sign(OnnxOpConverter):
 
     @classmethod
     def _impl_v9(cls, bb, inputs, attr, params):
-        return relax.op.sign(inputs[0])
+        x = inputs[0]
+        x_dtype = x.struct_info.dtype if isinstance(x.struct_info, relax.TensorStructInfo) else None
+        y = relax.op.sign(x)
+        if x_dtype is not None and _relax_dtype_is_floating_point(x_dtype):
+            return relax.op.where(relax.op.isnan(x), x, y)
+        return y
 
 
 class Not(OnnxOpConverter):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 9a644c4a3a..471186589e 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -771,6 +771,35 @@ def test_unary(op_name: str):
     verify_unary(op_name, [8, 8, 8], input_dtype=input_dtype, output_dtype=output_dtype)
 
 
+def test_sign_nan_preserve():
+    sign_node = helper.make_node("Sign", ["x"], ["y"])
+    graph = helper.make_graph(
+        [sign_node],
+        "sign_nan_test",
+        inputs=[helper.make_tensor_value_info("x", TensorProto.FLOAT, [4])],
+        outputs=[helper.make_tensor_value_info("y", TensorProto.FLOAT, [4])],
+    )
+    model = helper.make_model(graph, producer_name="sign_nan_test")
+    model.ir_version = 8
+    for opset_import in model.opset_import:
+        if opset_import.domain in ["", "ai.onnx"]:
+            opset_import.version = 18
+            break
+    x = np.array([np.nan, 9.0, -9.0, np.nan], dtype=np.float32)
+
+    ort_out = onnxruntime.InferenceSession(
+        model.SerializeToString(), providers=["CPUExecutionProvider"]
+    ).run([], {"x": x})[0]
+
+    tvm_out = run_in_tvm(model, inputs={"x": x}, opset=18)
+    out_np = (tvm_out[0] if isinstance(tvm_out, list | tuple) else tvm_out).numpy()
+
+    np.testing.assert_array_equal(np.isnan(out_np), np.isnan(ort_out))
+    np.testing.assert_allclose(
+        out_np[~np.isnan(ort_out)], ort_out[~np.isnan(ort_out)], rtol=1e-7, atol=1e-5
+    )
+
+
 @pytest.mark.parametrize("op_name", ["Softmax", "LogSoftmax", "Hardmax"])
 def test_softmax_family_opset11_default_axis_semantics(op_name: str):
     verify_unary(op_name, [2, 3, 4], opset=11)