This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2ab519282f [Vulkan] Add TIR unary trigonometric/hyperbolic intrinsic definitions (#18005)
2ab519282f is described below

commit 2ab519282f693f65e2c9bd91a6f345be41145f17
Author: Darren Wihandi <[email protected]>
AuthorDate: Wed May 21 14:59:10 2025 -0400

    [Vulkan] Add TIR unary trigonometric/hyperbolic intrinsic definitions (#18005)
---
 src/target/spirv/intrin_rule_spirv.cc              | 36 ++++++++++++--
 tests/python/codegen/test_target_codegen_vulkan.py | 55 ++++++++++++++++++++++
 2 files changed, 88 insertions(+), 3 deletions(-)

diff --git a/src/target/spirv/intrin_rule_spirv.cc b/src/target/spirv/intrin_rule_spirv.cc
index e5f869de17..ccb8d131c9 100644
--- a/src/target/spirv/intrin_rule_spirv.cc
+++ b/src/target/spirv/intrin_rule_spirv.cc
@@ -91,6 +91,39 @@ TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
 TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
                                                      DispatchGLSLPureIntrin<GLSLstd450Cos>);
 
+TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
+                                                     DispatchGLSLPureIntrin<GLSLstd450Tan>);
+
+TVM_REGISTER_OP("tir.asin")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Asin>);
+
+TVM_REGISTER_OP("tir.acos")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Acos>);
+
+TVM_REGISTER_OP("tir.atan")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atan>);
+
+TVM_REGISTER_OP("tir.sinh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Sinh>);
+
+TVM_REGISTER_OP("tir.cosh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Cosh>);
+
+TVM_REGISTER_OP("tir.tanh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
+
+TVM_REGISTER_OP("tir.asinh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Asinh>);
+
+TVM_REGISTER_OP("tir.acosh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Acosh>);
+
+TVM_REGISTER_OP("tir.atanh")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atanh>);
+
+TVM_REGISTER_OP("tir.atan2")
+    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Atan2>);
+
 TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
                                                      DispatchGLSLPureIntrin<GLSLstd450Log>);
 
@@ -103,9 +136,6 @@ TVM_REGISTER_OP("tir.sqrt")
 TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
                                                      DispatchGLSLPureIntrin<GLSLstd450Pow>);
 
-TVM_REGISTER_OP("tir.tanh")
-    .set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic", DispatchGLSLPureIntrin<GLSLstd450Tanh>);
-
 TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("vulkan.FLowerIntrinsic",
                                                      codegen::intrin::DispatchFastErf);
 }  // namespace intrin
diff --git a/tests/python/codegen/test_target_codegen_vulkan.py b/tests/python/codegen/test_target_codegen_vulkan.py
index b661ce4869..89acf598d6 100644
--- a/tests/python/codegen/test_target_codegen_vulkan.py
+++ b/tests/python/codegen/test_target_codegen_vulkan.py
@@ -568,5 +568,60 @@ def test_codegen_decl_buffer():
     vulkan_codegen(mod, target)
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_vulkan
+def test_unary():
+    test_funcs = [
+        (tvm.tir.sin, lambda x: np.sin(x)),
+        (tvm.tir.cos, lambda x: np.cos(x)),
+        (tvm.tir.tan, lambda x: np.tan(x)),
+        (tvm.tir.sinh, lambda x: np.sinh(x)),
+        (tvm.tir.cosh, lambda x: np.cosh(x)),
+        (tvm.tir.tanh, lambda x: np.tanh(x)),
+        (tvm.tir.asin, lambda x: np.arcsin(x)),
+        (tvm.tir.acos, lambda x: np.arccos(x)),
+        (tvm.tir.atan, lambda x: np.arctan(x)),
+        (tvm.tir.asinh, lambda x: np.arcsinh(x)),
+        (tvm.tir.acosh, lambda x: np.arccosh(x)),
+        (tvm.tir.atanh, lambda x: np.arctanh(x)),
+    ]
+
+    def run_test(tvm_intrin, np_func):
+        m = te.var("m")
+        A = te.placeholder((m,), name="A", dtype="float32")
+        B = te.compute((m,), lambda *i: tvm_intrin(A(*i)), name="B")
+
+        mod = te.create_prim_func([A, B])
+        sch = tir.Schedule(mod)
+
+        block = sch.get_block("B")
+        loop = sch.get_loops(block)[0]
+        bx, tx = sch.split(loop, factors=[None, 64])
+        sch.bind(bx, "blockIdx.x")
+        sch.bind(tx, "threadIdx.x")
+
+        target = tvm.target.Target("vulkan")
+        dev = tvm.device(target.kind.name, 0)
+        func = tvm.compile(sch.mod, target=target)
+
+        n = 16
+        if tvm_intrin in [tvm.tir.asin, tvm.tir.acos]:
+            data = np.random.uniform(-1.0, 1.0, size=n)
+        elif tvm_intrin == tvm.tir.atanh:
+            data = np.random.uniform(-0.999, 0.999, size=n)
+        elif tvm_intrin == tvm.tir.acosh:
+            data = np.random.uniform(1.0, 5.0, size=n)
+        else:
+            data = np.random.uniform(0.1, 0.9, size=n)
+
+        a = tvm.nd.array(data.astype(A.dtype), dev)
+        b = tvm.nd.array(np.zeros(n, dtype=A.dtype), dev)
+        func(a, b)
+        tvm.testing.assert_allclose(b.numpy(), np_func(a.numpy()), atol=1e-3, rtol=1e-3)
+
+    for func in test_funcs:
+        run_test(*func)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to