This is an automated email from the ASF dual-hosted git repository.

cbalint13 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 6620fe2b67 Add LLVM Legalization for tir.erf (#18104)
6620fe2b67 is described below

commit 6620fe2b67af1e3e1581e544167db04d37c2046f
Author: Gokulnath Srinivasan <[email protected]>
AuthorDate: Fri Jul 4 01:48:03 2025 +0530

    Add LLVM Legalization for tir.erf (#18104)
    
    This PR adds LLVM legalization support for tir.erf using the Abramowitz and
    Stegun approximation (formula 7.1.26), which avoids the precision issues of
    the tanh-approximation-based implementation.
---
 src/target/llvm/intrin_rule_llvm.cc      | 18 ++++++++++++++++++
 tests/python/tir-base/test_tir_intrin.py |  2 ++
 2 files changed, 20 insertions(+)

diff --git a/src/target/llvm/intrin_rule_llvm.cc b/src/target/llvm/intrin_rule_llvm.cc
index e519c9eef3..15cc445090 100644
--- a/src/target/llvm/intrin_rule_llvm.cc
+++ b/src/target/llvm/intrin_rule_llvm.cc
@@ -242,6 +242,24 @@ TVM_REGISTER_OP("tir.atanh")
       return (log(one + x) - log(one - x)) * make_const(x.dtype(), 0.5);
     });
 
+TVM_REGISTER_OP("tir.erf").set_attr<FLegalize>("llvm.FLegalize", [](const PrimExpr& e) -> PrimExpr {
+  using tir::make_const;
+  const tir::CallNode* call = e.as<tir::CallNode>();
+  ICHECK(call != nullptr) << "Invalid call node in erf legalization";
+  const PrimExpr& x = call->args[0];
+  PrimExpr one = make_const(x.dtype(), 1.0);
+  PrimExpr abs_x = tvm::abs(x);  // A&S 7.1.26 approximates erf for x >= 0
+  PrimExpr t = one / (one + make_const(x.dtype(), 0.3275911) * abs_x);  // t = 1 / (1 + p|x|)
+  PrimExpr a1 = make_const(x.dtype(), 0.254829592);  // A&S 7.1.26; |abs error| <= 1.5e-7
+  PrimExpr a2 = make_const(x.dtype(), -0.284496736);
+  PrimExpr a3 = make_const(x.dtype(), 1.421413741);
+  PrimExpr a4 = make_const(x.dtype(), -1.453152027);
+  PrimExpr a5 = make_const(x.dtype(), 1.061405429);
+  PrimExpr poly = (((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t);  // a1*t + ... + a5*t^5
+  PrimExpr approx = one - poly * exp(-abs_x * abs_x);  // erf(|x|); coarser than fp64 precision
+  return tvm::tir::Select(x < 0, -approx, approx);  // erf is odd: erf(-x) = -erf(x)
+});
+
 TVM_REGISTER_OP("tir.clz").set_attr<FLegalize>("llvm.FLegalize", [](const 
PrimExpr& e) -> PrimExpr {
   const tir::CallNode* call = e.as<tir::CallNode>();
   ICHECK(call != nullptr);
diff --git a/tests/python/tir-base/test_tir_intrin.py b/tests/python/tir-base/test_tir_intrin.py
index 3e731f55fb..55f8dbed6c 100644
--- a/tests/python/tir-base/test_tir_intrin.py
+++ b/tests/python/tir-base/test_tir_intrin.py
@@ -23,6 +23,7 @@ from tvm.script import tir as T
 import numpy as np
 import ctypes
 import math
+import scipy.special
 
 
 def test_nearbyint():
@@ -77,6 +78,7 @@ def test_unary_intrin():
         (tvm.tir.asinh, lambda x: np.arcsinh(x)),
         (tvm.tir.acosh, lambda x: np.arccosh(x)),
         (tvm.tir.atanh, lambda x: np.arctanh(x)),
+        (tvm.tir.erf, lambda x: scipy.special.erf(x)),
     ]
 
     def run_test(tvm_intrin, np_func, atol=1e-5, rtol=1e-5):

Reply via email to