This is an automated email from the ASF dual-hosted git repository.
cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6620fe2b67 Add LLVM Legalization for tir.erf (#18104)
6620fe2b67 is described below
commit 6620fe2b67af1e3e1581e544167db04d37c2046f
Author: Gokulnath Srinivasan <[email protected]>
AuthorDate: Fri Jul 4 01:48:03 2025 +0530
Add LLVM Legalization for tir.erf (#18104)
This PR adds LLVM legalization support for tir.erf using the Abramowitz and
Stegun approximation (formula 7.1.26), which avoids the precision issues found
in the tanh-approximation-based implementation.
---
src/target/llvm/intrin_rule_llvm.cc | 18 ++++++++++++++++++
tests/python/tir-base/test_tir_intrin.py | 2 ++
2 files changed, 20 insertions(+)
diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index e519c9eef3..15cc445090 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -242,6 +242,24 @@ TVM_REGISTER_OP("tir.atanh")
return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
});
+// Legalize tir.erf for LLVM targets by expanding it into elementary TIR ops.
+//
+// For |x| >= 0.25 this uses Abramowitz & Stegun formula 7.1.26:
+//   erf(|x|) ~= 1 - (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) * exp(-x^2),
+//   t = 1 / (1 + p |x|),  with absolute error <= 1.5e-7.
+// That form evaluates erf as "1 - erfc", so for small |x| the subtraction
+// cancels almost completely and the relative error grows without bound as
+// x -> 0 (tiny float32 inputs collapse to 0 or to multiples of the ulp of 1,
+// and erf(-0.0) would come out as a small positive number).
+// For |x| < 0.25 the odd Maclaurin series is used instead:
+//   erf(x) = 2/sqrt(pi) * (x - x^3/3 + x^5/10 - x^7/42 + ...),
+// truncated after the x^7 term (alternating series, truncation error < 2e-8
+// on that interval), which keeps full relative precision near zero.
+// erf is odd: the A&S result, computed on |x|, is negated for x < 0, while the
+// series is already odd in x and needs no sign fix-up.
+// NOTE(review): ~1.5e-7 absolute accuracy suits float16/float32 but is far
+// below float64 precision -- confirm whether float64 should use libm erf.
+TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+  using tir::make_const;
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr) << "Invalid call node in erf legalization";
+  ICHECK_EQ(call->args.size(), 1U) << "tir.erf expects exactly one argument";
+  const PrimExpr& x = call->args[0];
+  const DataType dtype = x.dtype();
+  const PrimExpr one = make_const(dtype, 1.0);
+  const PrimExpr abs_x = tvm::abs(x);
+
+  // Large-argument branch: A&S 7.1.26 on |x|, polynomial in t via Horner's rule.
+  const PrimExpr p = make_const(dtype, 0.3275911);
+  const PrimExpr a1 = make_const(dtype, 0.254829592);
+  const PrimExpr a2 = make_const(dtype, -0.284496736);
+  const PrimExpr a3 = make_const(dtype, 1.421413741);
+  const PrimExpr a4 = make_const(dtype, -1.453152027);
+  const PrimExpr a5 = make_const(dtype, 1.061405429);
+  const PrimExpr t = one / (one + p * abs_x);
+  const PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);
+  const PrimExpr approx = one - poly * exp(-abs_x * abs_x);
+  const PrimExpr large = tvm::tir::Select(x < 0, -approx, approx);
+
+  // Small-argument branch: odd Maclaurin series, Horner's rule in x^2.
+  const PrimExpr x2 = x * x;
+  const PrimExpr c1 = make_const(dtype, 1.1283791670955126);    // 2/sqrt(pi)
+  const PrimExpr c3 = make_const(dtype, -0.37612638903183753);  // -(2/sqrt(pi))/3
+  const PrimExpr c5 = make_const(dtype, 0.11283791670955126);   // (2/sqrt(pi))/10
+  const PrimExpr c7 = make_const(dtype, -0.02686617064513125);  // -(2/sqrt(pi))/42
+  const PrimExpr small = x * (((c7 * x2 + c5) * x2 + c3) * x2 + c1);
+
+  // Both branches are plain arithmetic; Select just picks the accurate one.
+  return tvm::tir::Select(abs_x < make_const(dtype, 0.25), small, large);
+});
+
TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const
PrimExpr& e) -> PrimExpr {
const tir::CallNode* call = e.as<tir::CallNode>();
ICHECK(call != nullptr);
diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py
index 3e731f55fb..55f8dbed6c 100644
--- a/tests/python/tir-base/test_tir_intrin.py
+++ b/tests/python/tir-base/test_tir_intrin.py
@@ -23,6 +23,7 @@ from tvm.script import tir as T
import numpy as np
import ctypes
import math
+import scipy
def test_nearbyint():
@@ -77,6 +78,7 @@ def test_unary_intrin():
(tvm.tir.asinh, lambda x: np.arcsinh(x)),
(tvm.tir.acosh, lambda x: np.arccosh(x)),
(tvm.tir.atanh, lambda x: np.arctanh(x)),
+ (tvm.tir.erf, lambda x: scipy.special.erf(x)),
]
def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):