This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 13ce30b988 [BugFix][Target][LLVM] Route sinh/cosh/atan/asinh/erf
through libm extern (#19568)
13ce30b988 is described below
commit 13ce30b988cbadef110eff7daafecdf498266e14
Author: Soowon Jeong <[email protected]>
AuthorDate: Tue May 19 15:00:53 2026 +0900
[BugFix][Target][LLVM] Route sinh/cosh/atan/asinh/erf through libm extern
(#19568)
## Summary
Six LLVM legalize rules in `src/target/llvm/intrin_rule_llvm.cc` use
inline mathematical identities that return wrong results for some
representable inputs: an intermediate computation overflows or cancels
catastrophically, even though the true result is within `float32` range:
| Op | Inline form | Failure | True result |
|---|---|---|---|
| `sinh`/`cosh` (#19559) | `(exp(x) ± exp(-x)) / 2` | `exp(89) > FLT_MAX`, so the intermediate is `inf` | `sinh(89) ≈ 2.24e38` |
| `atan` (#19560) | `asin(x / sqrt(x²+1))` | `x²` overflows for `\|x\| > 1.84e19`, then `x/inf = 0` and `asin(0) = 0` | `±π/2` |
| `asinh` (#19561) | `log(x + sqrt(x²+1))` | same `x²` overflow → `log(inf) = inf` | `asinh(3e22) ≈ 52.45` |
| `erf` (#19562) | A&S `1 − poly(t)·exp(−x²)` | `poly·exp(−x²) ≈ 1` for tiny `\|x\|`, so the subtraction cancels to 0 | `erf(3e-12) ≈ 3.4e-12` |
| `acosh` (no issue filed) | `log(x + sqrt(x²−1))` | same `x²` overflow → `inf` | `acosh(3e22) ≈ 52.45` |
`acosh` was not in the original issue cluster, but it exhibits the same
bug pattern as `asinh`; folding it in keeps this PR's scope consistent
("naive math identity → libm extern"). Happy to split it out if
reviewers prefer.
## Fix
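Delete the six LLVM-specific `llvm.FLegalize` overrides so these ops are
routed through the existing `DispatchPureExtern<FloatSuffix>` helper,
i.e. to the libm functions `sinhf`, `coshf`, `atanf`, `asinhf`, `acoshf`
and `erff` for `float32`. This is the same pattern `asin`/`acos` use
after #19567. It gives ULP-level accuracy across the reported input
ranges.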
```
sinh(89.0): ORT=2.244806e+38 TVM=2.244806e+38 (was inf)
atan(3e22): ORT=1.5707964 TVM=1.5707963 (was 0.0)
asinh(3e22): ORT=52.44863 TVM=52.44863 (was inf)
acosh(3e22): ORT=52.44863 TVM=52.44863 (was inf)
erf(3e-12): ORT=3.385e-12 TVM=3.385e-12 (was 0.0)
```
`Atan` is re-enabled in `test_unary` now that the overflow that
previously broke it is fixed.
## Notes for reviewers
**Inline-vs-extern decision.** If the inline identities were a
deliberate fast path (e.g. for autovectorization or to avoid extern-call
overhead in tight loops), please flag it and I'll switch to numerically
stable inline forms instead:
- `exp(x − ln 2) ± exp(−x − ln 2)` for sinh/cosh;
- the large-`|x|` asymptotic forms `sign(x)·log(2|x|)` for asinh and
  `log(2x)` for acosh;
- a small-`|x|` Taylor branch (`erf(x) ≈ 2x/√π`) for erf;
- etc.

I could not find evidence of such intent in the git history:
- sinh/cosh: original commit;
- atan/asinh/acosh: #17945 / #17969 follow-ups;
- erf: #18104 was framed as "more precise than tanh-approx", not
  "fast inline".
Fixes #19559.
Fixes #19560.
Fixes #19561.
Fixes #19562.
---
src/target/llvm/intrin_rule_llvm.cc | 82 --------------------------------
tests/python/relax/test_frontend_onnx.py | 2 +-
2 files changed, 1 insertion(+), 83 deletions(-)
diff --git a/src/target/llvm/intrin_rule_llvm.cc
b/src/target/llvm/intrin_rule_llvm.cc
index ae57e8d9a6..4a2246c4b1 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -141,36 +141,6 @@ TVM_REGISTER_OP("tirx.tan")
return tan_x;
});
-TVM_REGISTER_OP("tirx.cosh")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- using tirx::make_zero;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr two = make_const(x.dtype(), 2);
- PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = exp(neg_one * x);
- PrimExpr exp_posx = exp(x);
- PrimExpr ret = (exp_posx + exp_negx) / two;
- return ret;
- });
-
-TVM_REGISTER_OP("tirx.sinh")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- using tirx::make_zero;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr);
- const PrimExpr& x = call->args[0];
- PrimExpr two = make_const(x.dtype(), 2);
- PrimExpr neg_one = make_const(x.dtype(), -1);
- PrimExpr exp_negx = exp(neg_one * x);
- PrimExpr exp_posx = exp(x);
- PrimExpr ret = (exp_posx - exp_negx) / two;
- return ret;
- });
-
TVM_REGISTER_OP("tirx.asin")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using namespace intrin;
@@ -187,39 +157,6 @@ TVM_REGISTER_OP("tirx.acos")
return
::tvm::codegen::intrin::DispatchPureExtern<::tvm::codegen::intrin::FloatSuffix>(e);
});
-TVM_REGISTER_OP("tirx.atan")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in atan
legalization";
- const PrimExpr& x = call->args[0];
- PrimExpr one = make_const(x.dtype(), 1.0);
- PrimExpr denom = sqrt(x * x + one);
- return asin(x / denom);
- });
-
-TVM_REGISTER_OP("tirx.asinh")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in asinh
legalization";
- const PrimExpr& x = call->args[0];
- PrimExpr one = make_const(x.dtype(), 1.0);
- PrimExpr sqrt_val = sqrt(x * x + one);
- return log(x + sqrt_val);
- });
-
-TVM_REGISTER_OP("tirx.acosh")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in acosh
legalization";
- const PrimExpr& x = call->args[0];
- PrimExpr one = make_const(x.dtype(), 1.0);
- PrimExpr sqrt_val = sqrt(x * x - one);
- return log(x + sqrt_val);
- });
-
TVM_REGISTER_OP("tirx.atanh")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
using tirx::make_const;
@@ -230,25 +167,6 @@ TVM_REGISTER_OP("tirx.atanh")
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
});
-TVM_REGISTER_OP("tirx.erf")
- .set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
- using tirx::make_const;
- const tirx::CallNode* call = e.as<tirx::CallNode>();
- TVM_FFI_ICHECK(call != nullptr) << "Invalid call node in erf
legalization";
- const PrimExpr& x = call->args[0];
- PrimExpr abs_x = tvm::abs(x);
- PrimExpr t = make_const(x.dtype(), 1.0) /
- (make_const(x.dtype(), 1.0) + make_const(x.dtype(),
0.3275911) * abs_x);
- PrimExpr a1 = make_const(x.dtype(), 0.254829592);
- PrimExpr a2 = make_const(x.dtype(), -0.284496736);
- PrimExpr a3 = make_const(x.dtype(), 1.421413741);
- PrimExpr a4 = make_const(x.dtype(), -1.453152027);
- PrimExpr a5 = make_const(x.dtype(), 1.061405429);
- PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);
- PrimExpr approx = make_const(x.dtype(), 1.0) - poly * exp(-abs_x *
abs_x);
- return tvm::tirx::Select(x < 0, -approx, approx);
- });
-
TVM_REGISTER_OP("tirx.clz")
.set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
const tirx::CallNode* call = e.as<tirx::CallNode>();
diff --git a/tests/python/relax/test_frontend_onnx.py
b/tests/python/relax/test_frontend_onnx.py
index ca05a6492f..b658a2aaba 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -726,7 +726,7 @@ def test_bitwise_shift(direction: str):
"Tanh",
"Asin",
"Acos",
- # "Atan", // TODO: fix x²+1 overflow in llvm legalize for huge inputs
(issue #19560)
+ "Atan",
"Asinh",
"Acosh",
"Atanh",