This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4b9881e [CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (#8756)
4b9881e is described below
commit 4b9881ec50008bc14fc1ae7805413544cf962011
Author: Yuan-Chuan-YUE <[email protected]>
AuthorDate: Sun Aug 22 05:42:21 2021 +0800
[CODEGEN][OpenCL]: fix tir.erf codegen to opencl directly (#8756)
* register tir.erf to lower opencl directly
* add opencl codegen unit test
* change erf opencl codegen unit test for checking there is erf in the
source not erff
---
src/target/source/intrin_rule_opencl.cc | 3 +++
tests/python/unittest/test_target_codegen_opencl.py | 20 ++++++++++++++++++++
2 files changed, 23 insertions(+)
diff --git a/src/target/source/intrin_rule_opencl.cc b/src/target/source/intrin_rule_opencl.cc
index 288bb2c..64a50c3 100644
--- a/src/target/source/intrin_rule_opencl.cc
+++ b/src/target/source/intrin_rule_opencl.cc
@@ -49,6 +49,9 @@ TVM_REGISTER_OP("tir.round")
TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);
+TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
+                                                     DispatchPureExtern<Direct>);
+
TVM_REGISTER_OP("tir.exp2")
.set_attr<FLowerIntrinsic>("opencl.FLowerIntrinsic",
DispatchPureExtern<Direct>);
diff --git a/tests/python/unittest/test_target_codegen_opencl.py b/tests/python/unittest/test_target_codegen_opencl.py
index 98340f0..56392ec 100644
--- a/tests/python/unittest/test_target_codegen_opencl.py
+++ b/tests/python/unittest/test_target_codegen_opencl.py
@@ -17,6 +17,7 @@
import tvm
from tvm import te
import tvm.testing
+import re
target = "opencl"
@@ -120,6 +121,25 @@ def test_opencl_max():
check_max(dev, 1, "float64")
+def test_opencl_erf():
+ def check_erf(dev, n, dtype):
+ A = te.placeholder((n,), name="A", dtype=dtype)
+ C = te.compute(A.shape, lambda *i: te.erf(A(*i)), name="C")
+ s = te.create_schedule(C.op)
+ s[C].bind(s[C].op.axis[0], te.thread_axis("threadIdx.x"))
+ fun = tvm.build(s, [A, C], target)
+ source_str = fun.imported_modules[0].get_source()
+ matches = re.findall("erf", source_str)
+ error_matches = re.findall("erff", source_str)
+ assert len(matches) == 1 and len(error_matches) == 0
+
+ dev = tvm.device(target, 0)
+
+ check_erf(dev, 1, "float32")
+ check_erf(dev, 1, "float64")
+
+
if __name__ == "__main__":
test_opencl_ternary_expression()
test_opencl_inf_nan()
+ test_opencl_erf()