This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new fcd266e8db [Fix][ONNX] No precision widening for numpy binary operations (#18207)
fcd266e8db is described below

commit fcd266e8db609f0d4e1380f9b40e098ecc91c2a2
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Aug 18 05:29:12 2025 +0300

    [Fix][ONNX] No precision widening for numpy binary operations (#18207)
---
 python/tvm/relax/frontend/onnx/onnx_frontend.py | 16 ++++++-----
 tests/python/relax/test_frontend_onnx.py        | 37 +++++++++++++++++++------
 2 files changed, 38 insertions(+), 15 deletions(-)

diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index ee80436a8a..b91106e64a 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -336,15 +336,17 @@ class BinaryBase(OnnxOpConverter):
         """Base implementation for binary operations."""
         if cls.numpy_op is None or cls.relax_op is None:
             raise ValueError("Numpy and Relax operators must be defined for BinaryBase.")
-        if all([isinstance(inp, relax.Constant) for inp in inputs]):
-            output = cls.numpy_op(  # pylint: disable=not-callable
-                inputs[0].data.numpy(), inputs[1].data.numpy()
-            )
-            return relax.const(output, inputs[0].struct_info.dtype)
-        if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+        if all([not isinstance(inp, (relax.expr.Call, relax.Var)) for inp in inputs]):
             x = _to_numpy(inputs[0])
             y = _to_numpy(inputs[1])
-            return relax.PrimValue(cls.numpy_op(x, y))  # pylint: disable=not-callable
+            output = cls.numpy_op(x, y)  # pylint: disable=not-callable
+            if x.dtype == y.dtype:
+                # no numpy precision widening
+                output = output.astype(x.dtype)
+            if all([isinstance(inp, relax.Constant) for inp in inputs]):
+                return relax.const(output, output.dtype)  # pylint: disable=not-callable
+            if any([isinstance(inp, relax.PrimValue) for inp in inputs]):
+                return relax.PrimValue(output.item())  # pylint: disable=not-callable
 
         return cls.relax_op(inputs[0], inputs[1])  # pylint: disable=not-callable
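
For context, a standalone sketch (not part of the patch) of the NumPy behavior this change guards
against: NumPy can produce a wider result dtype than either operand, e.g. true division promotes
int32 operands to float64, while the ONNX result of a binary op on same-typed inputs should keep
that type. The fix above folds the result back whenever both input dtypes match:

    import numpy as np

    x = np.array(7, dtype=np.int32)
    y = np.array(2, dtype=np.int32)

    out = np.divide(x, y)
    print(out.dtype)       # float64 -- NumPy widened the result

    if x.dtype == y.dtype:
        # Mirror the patch: cast back to the shared input dtype.
        out = out.astype(x.dtype)
    print(out.dtype, out)  # int32 3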
 
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index 3d112c2f3b..b55489a623 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -27,7 +27,7 @@ import numpy as np
 import onnx
 import onnxruntime
 import pytest
-from onnx import ModelProto, TensorProto, helper, mapping
+from onnx import ModelProto, TensorProto, helper
 
 import tvm
 import tvm.testing
@@ -62,7 +62,7 @@ def generate_random_inputs(
 def generate_random_value(shape, elem_type) -> np.ndarray:
     # Extract datatype for the input.
     if elem_type:
-        dtype = str(onnx.mapping.TENSOR_TYPE_TO_NP_TYPE[elem_type])
+        dtype = str(helper.tensor_dtype_to_np_dtype(elem_type))
     else:
         dtype = "float32"
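
The test changes in this file (here and below) replace the deprecated onnx.mapping lookup tables
with the equivalent onnx.helper conversion functions. A quick standalone check (assuming
onnx >= 1.13, where these helpers exist) that the two conversions behave as the old tables did:

    import numpy as np
    from onnx import TensorProto, helper

    # TensorProto enum -> numpy dtype, replacing mapping.TENSOR_TYPE_TO_NP_TYPE
    assert str(helper.tensor_dtype_to_np_dtype(TensorProto.FLOAT)) == "float32"

    # numpy dtype -> TensorProto enum, replacing mapping.NP_TYPE_TO_TENSOR_TYPE
    assert helper.np_dtype_to_tensor_dtype(np.dtype("float32")) == TensorProto.FLOAT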
 
@@ -87,6 +87,7 @@ def check_correctness(
     opset: int = 14,
     rtol: float = 1e-7,
     atol: float = 1e-5,
+    check_dtypes: bool = False,
 ) -> None:
     """Run an onnx model in both onnxruntime and TVM through our importer
        confirm that the results match. Otherwise, an exception will be raised.
@@ -104,6 +105,8 @@ def check_correctness(
     atol: float
         Set the tolerance of correctness checking. Some ops may show more
         arithmetic variance than others.
+    check_dtypes: bool
+        Check if data types are the same.
     """
     # Configure model format.
     if ir_version is not None:
@@ -152,17 +155,35 @@ def check_correctness(
         # while the ONNX output number is one, which is a list
         tvm_output = [tvm_output]
 
+    def _get_numpy_subdtype(narray):
+        if np.issubdtype(narray.dtype, np.integer):
+            return "integer"
+        elif np.issubdtype(narray.dtype, np.floating):
+            return "floating"
+        elif np.issubdtype(narray.dtype, np.bool_):
+            return "bool"
+        elif np.issubdtype(narray.dtype, np.complexfloating):
+            return "complexfloating"
+        else:
+            return "other"
+
     def _check_output(tvm_out, ort_out):
         if isinstance(tvm_out, tuple) and isinstance(ort_out, (tvm.runtime.ShapeTuple, list)):
             assert len(tvm_out) == len(ort_out), "Unequal number of outputs"
             for tvm_out_i, ort_out_i in zip(tvm_out, ort_out):
                 _check_output(tvm_out_i, ort_out_i)
         elif isinstance(tvm_out, tvm.nd.NDArray) and isinstance(ort_out, np.ndarray):
+            if check_dtypes:
+                assert tvm_out.numpy().dtype == ort_out.dtype
             tvm.testing.assert_allclose(tvm_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, tvm.runtime.ShapeTuple) and isinstance(ort_out, np.ndarray):
             shape_out = tvm.nd.array([int(i) for i in tvm_out])
+            if check_dtypes:
+                assert _get_numpy_subdtype(shape_out.numpy()) == _get_numpy_subdtype(ort_out)
             tvm.testing.assert_allclose(shape_out.numpy(), ort_out, rtol=rtol, atol=atol)
         elif isinstance(tvm_out, (int, float, bool)) and isinstance(ort_out, np.ndarray):
+            if check_dtypes:
+                assert _get_numpy_subdtype(np.array(tvm_out)) == _get_numpy_subdtype(ort_out)
             tvm.testing.assert_allclose(np.array(tvm_out), ort_out, rtol=rtol, atol=atol)
         else:
             raise ValueError(f"Unsupported types: {type(tvm_out)}, 
{type(ort_out)}")
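
Note that for ShapeTuple and plain Python scalar outputs the exact bit width is not recoverable,
so _check_output only compares dtype families via _get_numpy_subdtype. A standalone sketch (not
TVM code) of why a family-level check is the right granularity there:

    import numpy as np

    # int64 shape values and an int32 ORT output agree at the family level,
    # even though their exact widths differ.
    assert np.issubdtype(np.dtype("int64"), np.integer)
    assert np.issubdtype(np.dtype("int32"), np.integer)
    # bool is not an integer subtype in NumPy, so it lands in its own bucket.
    assert not np.issubdtype(np.dtype("bool"), np.integer)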
@@ -267,7 +288,7 @@ def verify_binary(
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    check_correctness(model, opset=opset)
+    check_correctness(model, opset=opset, check_dtypes=True)
 
 
 def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32, opset=14):
@@ -282,7 +303,7 @@ def verify_binary_scalar(op_name, attrs={}, domain=None, dtype=TensorProto.INT32
     )
 
     model = helper.make_model(graph, producer_name="binary_test")
-    check_correctness(model, opset=opset)
+    check_correctness(model, opset=opset, check_dtypes=True)
 
 
 def verify_compare(op_name, shape, attrs={}, domain=None):
@@ -1897,7 +1918,7 @@ def test_constantofshape():
             ["input"],
             ["output"],
             value=helper.make_tensor(
-                "value", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
(1,), (value,)
+                "value", helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), 
(1,), (value,)
             ),
         )
 
@@ -1917,7 +1938,7 @@ def test_constantofshape():
             ],
             outputs=[
                 helper.make_tensor_value_info(
-                    "output", mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)], 
input_dim
+                    "output", 
helper.np_dtype_to_tensor_dtype(np.dtype(dtype)), input_dim
                 )
             ],
         )
@@ -2299,7 +2320,7 @@ def test_split(fp_arith, dynamic):
 
         inputs = [
             helper.make_tensor_value_info(
-                "input", mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype], 
indata_shape
+                "input", helper.np_dtype_to_tensor_dtype(indata.dtype), 
indata_shape
             )
         ]
 
@@ -2333,7 +2354,7 @@ def test_split(fp_arith, dynamic):
             outputs=[
                 helper.make_tensor_value_info(
                     f"output_{i}",
-                    mapping.NP_TYPE_TO_TENSOR_TYPE[indata.dtype],
+                    helper.np_dtype_to_tensor_dtype(indata.dtype),
                     list(outdata_shapes[i]),
                 )
                 for i in range(len(split_index))
