This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new fcd266e8db [Fix][ONNX] No precision widening for numpy binary operations (#18207)
fcd266e8db is described below
commit fcd266e8db609f0d4e1380f9b40e098ecc91c2a2
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Aug 18 05:29:12 2025 +0300
[Fix][ONNX] No precision widening for numpy binary operations (#18207)
---
python/tvm/relax/frontend/onnx/onnx_frontend.py | 16 ++++++-----
tests/python/relax/test_frontend_onnx.py | 37 +++++++++++++++++++------
2 files changed, 38 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index ee80436a8a..b91106e64a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -336,15 +336,17 @@ class BinaryBase(OnnxOpConverter):
"""Base implementation for binary operations."""
if cls.numpy_op is None or cls.relax_op is None:
raise ValueError("Numpy and Relax operators must be defined for
BinaryBase.")
- if all([isinstance(inp, relax.Constant) for inp in inputs]):
- output = cls.numpy_op( # pylint: disable=not-callable
- inputs[0].data.numpy(), inputs[1].data.numpy()
- )
- return relax.const(output, inputs[0].struct_info.dtype)
- if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+ if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in
inputs]):
x = _to_numpy(inputs[0])
y = _to_numpy(inputs[1])
- return relax.PrimValue(cls.numpy_op(x, y)) # pylint:
disable=not-callable
+ output = cls.numpy_op(x, y) # pylint: disable=not-callable
+ if x.dtype == y.dtype:
+ # no numpy precision widening
+ output = output.astype(x.dtype)
+ if all([isinstance(inp, relax.Constant) for inp in inputs]):
+ return relax.const(output, output.dtype) # pylint:
disable=not-callable
+ if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+ return relax.PrimValue(output.item()) # pylint:
disable=not-callable
return cls.relax_op(inputs[0], inputs[1]) # pylint:
disable=not-callable
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3d112c2f3b..b55489a623 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -27,7 +27,7 @@ import numpy as np
import onnx
import onnxruntime
import pytest
-from onnx import ModelProto, TensorProto, helper, mapping
+from onnx import ModelProto, TensorProto, helper
import tvm
import tvm.testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
def generate_random_value(shape, elem_type) -> np.ndarray:
# Extract datatype for the input.
if elem_type:
- dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+ dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
else:
dtype = "float32"
@@ -87,6 +87,7 @@ def check_correctness(
opset: int = 14,
rtol: float = 1e-7,
atol: float = 1e-5,
+ check_dtypes: bool = False,
) -> None:
"""Run an onnx model in both onnxruntime and TVM through our importer
confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
atol: float
Set the tolerance of correctness checking. Some ops may be show more
arithmetic variance than others.
+ check_dtypes: bool
+ Check if data types are the same.
"""
# Configure model format.
if ir_version is not None:
@@ -152,17 +155,35 @@ def check_correctness(
# while the ONNX output number is one, which is a list
tvm_output = [tvm_output]
+ def _get_numpy_subdtype(narray):
+ if np.issubdtype(narray.dtype, np.integer):
+ return "integer"
+ elif np.issubdtype(narray.dtype, np.floating):
+ return "floating"
+ elif np.issubdtype(narray.dtype, np.bool_):
+ return "bool"
+ elif np.issubdtype(narray.dtype, np.complexfloating):
+ return "complexfloating"
+ else:
+ return "other"
+
def _check_output(tvm_out, ort_out):
if isinstance(tvm_out, tuple) and isinstance(ort_out,
(tvm.runtime.ShapeTuple, list)):
assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
_check_output(tvm_out_i, ort_out_i)
elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out,
np.ndarray):
+ if check_dtypes:
+ assert tvm_out.numpy().dtype == ort_out.dtype
tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol,
atol=atol)
elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and
isinstance(ort_out, np.ndarray):
shape_out = tvm.nd.array([int(i) for i in tvm_out])
+ if check_dtypes:
+ assert _get_numpy_subdtype(shape_out.numpy()) ==
_get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol,
atol=atol)
elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out,
np.ndarray):
+ if check_dtypes:
+ assert _get_numpy_subdtype(np.array(tvm_out)) ==
_get_numpy_subdtype(ort_out)
tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol,
atol=atol)
else:
raise ValueError(f"Unsupported types: {type(tvm_out)},
{type(ort_out)}")
@@ -267,7 +288,7 @@ def verify_binary(
)
model = helper.make_model(graph, producer_name="binary_test")
- check_correctness(model, opset=opset)
+ check_correctness(model, opset=opset, check_dtypes=True)
def verify_binary_scalar(op_name, attrs={}, domain=None,
dtype=TensorProto.INT32, opset=14):
@@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
)
model = helper.make_model(graph, producer_name="binary_test")
- check_correctness(model, opset=opset)
+ check_correctness(model, opset=opset, check_dtypes=True)
def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -1897,7 +1918,7 @@ def test_constantofshape():
["input"],
["output"],
value=helper.make_tensor(
- "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
(1,), (value,)
+ "value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)),
(1,), (value,)
),
)
@@ -1917,7 +1938,7 @@ def test_constantofshape():
],
outputs=[
helper.make_tensor_value_info(
- "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)],
input_dim
+ "output",
helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
)
],
)
@@ -2299,7 +2320,7 @@ def test_split(fp_arith, dynamic):
inputs = [
helper.make_tensor_value_info(
- "input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
indata_shape
+ "input", helper.np_dtype_to_tensor_dtype(indata.dtype),
indata_shape
)
]
@@ -2333,7 +2354,7 @@ def test_split(fp_arith, dynamic):
outputs=[
helper.make_tensor_value_info(
f"output_{i}",
- mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
+ helper.np_dtype_to_tensor_dtype(indata.dtype),
list(outdata_shapes[i]),
)
for i in range(len(split_index))