This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2e6ee08eaf [BugFix] Align `tir.round` to ties-to-even across all backends (#19368)
2e6ee08eaf is described below
commit 2e6ee08eafc328b82a965e49d106d61828c1d623
Author: Soowon Jeong <[email protected]>
AuthorDate: Thu Apr 9 03:35:22 2026 +0900
[BugFix] Align `tir.round` to ties-to-even across all backends (#19368)
## Problem
`tir.round` constant-folds using `std::nearbyint`, which rounds ties to even
under the default IEEE 754 rounding mode. However, every backend lowers it to
the platform `round()`, which rounds ties away from zero. As a result, compiled
code can produce different results from constant-folded code for midpoint values:
| Input | Constant-fold (ties-to-even) | Compiled (ties-away) |
|-------|-----|------|
| 0.5 | 0.0 | 1.0 |
| 2.5 | 2.0 | 3.0 |
| -0.5 | 0.0 | -1.0 |
This was identified as a follow-up to #19367 — see [this
comment](https://github.com/apache/tvm/pull/19367#issuecomment-4201800320).
## Fix
Align all backends to use ties-to-even intrinsics, matching the
constant-folding behavior:
| Backend | Before | After |
|---------|--------|-------|
| LLVM/ROCm/Hexagon | `llvm::Intrinsic::round` | `llvm::Intrinsic::nearbyint` |
| NVPTX | `__nv_round[f]` | `__nv_nearbyint[f]` |
| CUDA | `round`/`roundf` | `nearbyint`/`nearbyintf` (f16/bf16 already used `hrint`) |
| Metal/OpenCL | `round` | `rint` |
| Vulkan/SPIR-V | `GLSLstd450Round` | `GLSLstd450RoundEven` |
Also fixes OpenCL codegen where `tir.nearbyint` was incorrectly mapped
to OpenCL `round()` instead of `rint()`.
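Updates `op.h` documentation to explicitly state ties-to-even semantics
for both `round()` and `nearbyint()`.
Also updates `topi.vision.roi_pool` and its NumPy reference
(`topi/testing/roi_pool_python.py`) to map ROI coordinates with an explicit
`floor(x + 0.5)` instead of `round`. Their results therefore no longer depend on
`tir.round` semantics. Note that `floor(x + 0.5)` rounds halves upward. It matches
ties-away-from-zero (`std::round`) everywhere except at negative midpoints;
for example, -2.5 maps to -2 rather than -3.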
## Testing
```
python -m pytest tests/python/tirx-base/test_tir_intrin.py -xvs
```
The new `test_round_ties_to_even` checks that the midpoint inputs
`[0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5]` produce ties-to-even results on
the LLVM backend. Of the 12 tests in the file, 10 passed and 2 were skipped
(CUDA-only).
---------
Co-authored-by: Claude Opus 4.6 (1M context) <[email protected]>
---
include/tvm/tirx/op.h | 12 +++++++++---
python/tvm/topi/testing/roi_pool_python.py | 10 ++++++----
python/tvm/topi/vision/roi_pool.py | 15 +++++++++++----
src/target/llvm/intrin_rule_hexagon.cc | 2 +-
src/target/llvm/intrin_rule_llvm.cc | 2 +-
src/target/llvm/intrin_rule_nvptx.cc | 10 +++++++++-
src/target/llvm/intrin_rule_rocm.cc | 2 +-
src/target/source/codegen_opencl.cc | 2 +-
src/target/source/intrin_rule_cuda.cc | 3 +++
src/target/source/intrin_rule_metal.cc | 11 ++++++++++-
src/target/source/intrin_rule_opencl.cc | 11 ++++++++++-
src/target/spirv/intrin_rule_spirv.cc | 6 ++++--
tests/python/tirx-base/test_tir_intrin.py | 21 +++++++++++++++++++++
13 files changed, 87 insertions(+), 20 deletions(-)
diff --git a/include/tvm/tirx/op.h b/include/tvm/tirx/op.h
index 66d9d932b3..c953f12e38 100644
--- a/include/tvm/tirx/op.h
+++ b/include/tvm/tirx/op.h
@@ -654,7 +654,11 @@ TVM_DLL PrimExpr floor(PrimExpr x, Span span = Span());
TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
/*!
- * \brief Calculate round(x)
+ * \brief Round x to the nearest integer, ties to even.
+ *
+ * Uses IEEE 754 default rounding mode (ties-to-even / banker's rounding).
+ * Constant-folding and all backends consistently use std::nearbyint semantics.
+ *
* \param x The input expression.
* \param span The location of this operation in the source.
* \return The result expression.
@@ -662,11 +666,13 @@ TVM_DLL PrimExpr ceil(PrimExpr x, Span span = Span());
TVM_DLL PrimExpr round(PrimExpr x, Span span = Span());
/*!
- * \brief Calculates std::nearbyint(x)
+ * \brief Round x to the nearest integer, ties to even.
+ *
+ * Equivalent to round(). Both use IEEE 754 default rounding mode
(ties-to-even).
+ *
* \param x The input expression.
* \param span The location of this operation in the source.
* \return The result expression.
- * This is a faster alternate to round.
*/
TVM_DLL PrimExpr nearbyint(PrimExpr x, Span span = Span());
diff --git a/python/tvm/topi/testing/roi_pool_python.py
b/python/tvm/topi/testing/roi_pool_python.py
index 0f7120b466..583800e982 100644
--- a/python/tvm/topi/testing/roi_pool_python.py
+++ b/python/tvm/topi/testing/roi_pool_python.py
@@ -36,10 +36,12 @@ def roi_pool_nchw_python(a_np, rois_np, pooled_size,
spatial_scale):
for i in range(num_roi):
roi = rois_np[i]
batch_index = int(roi[0])
- roi_start_w = round(roi[1] * spatial_scale)
- roi_start_h = round(roi[2] * spatial_scale)
- roi_end_w = round(roi[3] * spatial_scale)
- roi_end_h = round(roi[4] * spatial_scale)
+ # Use ties-away-from-zero rounding to match ONNX runtime (std::round
semantics).
+ # Python's built-in round() uses ties-to-even, so use floor(x + 0.5)
explicitly.
+ roi_start_w = math.floor(roi[1] * spatial_scale + 0.5)
+ roi_start_h = math.floor(roi[2] * spatial_scale + 0.5)
+ roi_end_w = math.floor(roi[3] * spatial_scale + 0.5)
+ roi_end_h = math.floor(roi[4] * spatial_scale + 0.5)
roi_h = max(roi_end_h - roi_start_h + 1, 1)
roi_w = max(roi_end_w - roi_start_w + 1, 1)
diff --git a/python/tvm/topi/vision/roi_pool.py
b/python/tvm/topi/vision/roi_pool.py
index 54a4aeba50..2e86066c5b 100644
--- a/python/tvm/topi/vision/roi_pool.py
+++ b/python/tvm/topi/vision/roi_pool.py
@@ -36,12 +36,19 @@ def roi_pool_nchw(data, rois, pooled_size, spatial_scale):
neg_inf = tvm.tirx.const(float("-inf"), data.dtype)
+ def _round_away(x):
+ # ONNX MaxRoiPool spec uses ties-away-from-zero rounding for coordinate
+ # mapping (matching std::round semantics in the reference
implementation).
+ # Use floor(x + 0.5) to be explicit and independent of tir.round
semantics.
+ half = tvm.tirx.const(0.5, roi_dtype)
+ return te.floor(x + half)
+
def _bin_bounds(i, ph, pw):
roi = rois[i]
- roi_start_w = te.round(roi[1] * spatial_scale).astype("int32")
- roi_start_h = te.round(roi[2] * spatial_scale).astype("int32")
- roi_end_w = te.round(roi[3] * spatial_scale).astype("int32")
- roi_end_h = te.round(roi[4] * spatial_scale).astype("int32")
+ roi_start_w = _round_away(roi[1] * spatial_scale).astype("int32")
+ roi_start_h = _round_away(roi[2] * spatial_scale).astype("int32")
+ roi_end_w = _round_away(roi[3] * spatial_scale).astype("int32")
+ roi_end_h = _round_away(roi[4] * spatial_scale).astype("int32")
roi_h = te.max(roi_end_h - roi_start_h + 1, tvm.tirx.const(1, "int32"))
roi_w = te.max(roi_end_w - roi_start_w + 1, tvm.tirx.const(1, "int32"))
diff --git a/src/target/llvm/intrin_rule_hexagon.cc
b/src/target/llvm/intrin_rule_hexagon.cc
index 79e91c20a3..e330dba4e1 100644
--- a/src/target/llvm/intrin_rule_hexagon.cc
+++ b/src/target/llvm/intrin_rule_hexagon.cc
@@ -93,7 +93,7 @@ TVM_REGISTER_OP("tirx.fabs")
TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
TVM_REGISTER_OP("tirx.ctpop")
.set_attr<FLowerIntrinsic>("hexagon.FLowerIntrinsic",
diff --git a/src/target/llvm/intrin_rule_llvm.cc
b/src/target/llvm/intrin_rule_llvm.cc
index 468f0fb7b5..3244deab87 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -90,7 +90,7 @@ TVM_REGISTER_OP("tirx.fabs")
TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("llvm.FLowerIntrinsic",
diff --git a/src/target/llvm/intrin_rule_nvptx.cc
b/src/target/llvm/intrin_rule_nvptx.cc
index 4560205a60..0707a9a787 100644
--- a/src/target/llvm/intrin_rule_nvptx.cc
+++ b/src/target/llvm/intrin_rule_nvptx.cc
@@ -66,7 +66,15 @@ TVM_REGISTER_OP("tirx.ceil")
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
DispatchPureExternLibDevice);
TVM_REGISTER_OP("tirx.round")
- .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
DispatchPureExternLibDevice);
+ .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", [](const PrimExpr& e)
-> PrimExpr {
+ // Redirect to nearbyint (ties-to-even) to match constant-folding
semantics.
+ using namespace tirx;
+ const CallNode* call = e.as<CallNode>();
+ TVM_FFI_ICHECK(call != nullptr);
+ auto nearbyint_op = Op::Get("tirx.nearbyint");
+ auto new_call = Call(call->dtype, nearbyint_op, call->args);
+ return DispatchPureExternLibDevice(new_call);
+ });
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
DispatchPureExternLibDevice);
diff --git a/src/target/llvm/intrin_rule_rocm.cc
b/src/target/llvm/intrin_rule_rocm.cc
index 6d72c77783..4d542c1299 100644
--- a/src/target/llvm/intrin_rule_rocm.cc
+++ b/src/target/llvm/intrin_rule_rocm.cc
@@ -132,7 +132,7 @@ TVM_REGISTER_OP("tirx.ceil")
TVM_REGISTER_OP("tirx.round")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
-
DispatchLLVMPureIntrin<::llvm::Intrinsic::round, 1>);
+
DispatchLLVMPureIntrin<::llvm::Intrinsic::nearbyint, 1>);
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
diff --git a/src/target/source/codegen_opencl.cc
b/src/target/source/codegen_opencl.cc
index 5d9135ef22..b2f78c2dbd 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -526,7 +526,7 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op,
std::ostream& os) {
this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)),
"atomic_add_float_emu", op->args,
true, os);
} else if (func->value == "nearbyint") {
- this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "round",
op->args, true, os);
+ this->PrintCallExtern(GetType(ffi::GetRef<PrimExpr>(op)), "rint",
op->args, true, os);
} else {
if (func->value == "atomic_add") {
enable_atomics_ = true;
diff --git a/src/target/source/intrin_rule_cuda.cc
b/src/target/source/intrin_rule_cuda.cc
index bcd158432b..d38db9fe83 100644
--- a/src/target/source/intrin_rule_cuda.cc
+++ b/src/target/source/intrin_rule_cuda.cc
@@ -37,8 +37,11 @@ struct CUDAMath {
if (t.is_float()) {
switch (t.bits()) {
case 64:
+ // Use nearbyint (ties-to-even) for round to match constant-folding
semantics.
+ if (name == "round") return "nearbyint";
return name;
case 32:
+ if (name == "round") return "nearbyintf";
return name + 'f';
case 16: {
if (name == "fabs") {
diff --git a/src/target/source/intrin_rule_metal.cc
b/src/target/source/intrin_rule_metal.cc
index d61bf1256f..cea19519ca 100644
--- a/src/target/source/intrin_rule_metal.cc
+++ b/src/target/source/intrin_rule_metal.cc
@@ -68,7 +68,16 @@ TVM_REGISTER_OP("tirx.fabs")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.round")
- .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
DispatchPureExtern<Direct>);
+ .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", [](const PrimExpr& e)
-> PrimExpr {
+ // Metal's rint() uses ties-to-even, matching constant-folding semantics.
+ const tirx::CallNode* call = e.as<tirx::CallNode>();
+ TVM_FFI_ICHECK(call != nullptr);
+ ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
+ for (auto arg : call->args) {
+ new_args.push_back(arg);
+ }
+ return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(),
new_args);
+ });
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
DispatchPureExtern<Direct>);
diff --git a/src/target/source/intrin_rule_opencl.cc
b/src/target/source/intrin_rule_opencl.cc
index 85084b1a16..ba1873bde6 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -47,7 +47,16 @@ TVM_REGISTER_OP("tirx.fabs")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);
TVM_REGISTER_OP("tirx.round")
- .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);
+ .set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic", [](const PrimExpr& e)
-> PrimExpr {
+ // OpenCL's rint() uses ties-to-even, matching constant-folding
semantics.
+ const tirx::CallNode* call = e.as<tirx::CallNode>();
+ TVM_FFI_ICHECK(call != nullptr);
+ ffi::Array<PrimExpr> new_args = {tirx::StringImm("rint")};
+ for (auto arg : call->args) {
+ new_args.push_back(arg);
+ }
+ return tirx::Call(call->dtype, tirx::builtin::call_pure_extern(),
new_args);
+ });
TVM_REGISTER_OP("tirx.nearbyint")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);
diff --git a/src/target/spirv/intrin_rule_spirv.cc
b/src/target/spirv/intrin_rule_spirv.cc
index cde1e0165f..4b1ffc4b6d 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -68,10 +68,12 @@ TVM_REGISTER_OP("tirx.ceil")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Ceil>);
TVM_REGISTER_OP("tirx.round")
- .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Round>);
+ .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+ DispatchGLSLPureIntrin<GLSLstd450RoundEven>);
TVM_REGISTER_OP("tirx.nearbyint")
- .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Round>);
+ .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+ DispatchGLSLPureIntrin<GLSLstd450RoundEven>);
TVM_REGISTER_OP("tirx.trunc")
.set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
DispatchGLSLPureIntrin<GLSLstd450Trunc>);
diff --git a/tests/python/tirx-base/test_tir_intrin.py
b/tests/python/tirx-base/test_tir_intrin.py
index 0dd06dee93..30676715b8 100644
--- a/tests/python/tirx-base/test_tir_intrin.py
+++ b/tests/python/tirx-base/test_tir_intrin.py
@@ -56,6 +56,27 @@ def test_nearbyint():
tvm.testing.assert_allclose(a_rounded.numpy(), np.rint(a.numpy()))
+def test_round_ties_to_even():
+ """Test that tir.round uses ties-to-even (banker's rounding) semantics."""
+ m = te.var("m")
+ A = te.placeholder((m,), name="A")
+ A_rounded = te.compute((m,), lambda *i: tvm.tirx.round(A(*i)), name="A")
+
+ mod = te.create_prim_func([A, A_rounded])
+ sch = tvm.s_tir.Schedule(mod)
+ func = tvm.compile(sch.mod, target="llvm")
+
+ dev = tvm.cpu(0)
+ # Midpoint values where ties-to-even and ties-away differ
+ test_values = np.array([0.5, 1.5, 2.5, 3.5, -0.5, -1.5, -2.5, -3.5],
dtype="float32")
+ expected = np.array([0.0, 2.0, 2.0, 4.0, 0.0, -2.0, -2.0, -4.0],
dtype="float32")
+
+ a = tvm.runtime.tensor(test_values, dev)
+ a_rounded = tvm.runtime.tensor(np.zeros(len(test_values),
dtype="float32"), dev)
+ func(a, a_rounded)
+ tvm.testing.assert_allclose(a_rounded.numpy(), expected)
+
+
def test_round_intrinsics_on_int():
i = tvm.tirx.Var("i", "int32")
for op in [tvm.tirx.round, tvm.tirx.trunc, tvm.tirx.ceil, tvm.tirx.floor,
tvm.tirx.nearbyint]: