This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 56d0e3b7af [METAL][CODEGEN] testcase for ramp codegen (#14331)
56d0e3b7af is described below
commit 56d0e3b7affce768ae8c7b9d17743d3e8332308d
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Mar 19 14:50:44 2023 -0400
[METAL][CODEGEN] testcase for ramp codegen (#14331)
This PR adds a test case, runnable locally, that covers Metal ramp
codegen.
---
tests/python/unittest/test_target_codegen_metal.py | 29 ++++++++++++++++++++--
1 file changed, 27 insertions(+), 2 deletions(-)
diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py
index 002cf3c696..45588c69cf 100644
--- a/tests/python/unittest/test_target_codegen_metal.py
+++ b/tests/python/unittest/test_target_codegen_metal.py
@@ -17,11 +17,12 @@
import tvm
from tvm import te
import numpy as np
-from tvm import topi
-import unittest
+
from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
from tvm.contrib import nvcc
import tvm.testing
+import tvm.script
+from tvm.script import tir as T
tx = te.thread_axis("threadIdx.x")
bx = te.thread_axis("blockIdx.x")
@@ -76,6 +77,30 @@ def test_metal_erf():
check_erf(dev, 1, "float16")
+@tvm.testing.requires_gpu
+@tvm.testing.requires_metal
+def test_ramp():
+ target = "metal"
+
+ @tvm.script.ir_module
+ class IRModule:
+ @T.prim_func
+ def main(A: T.Buffer((1, 2), "int32")):
+ T.func_attr({"global_symbol": "main"})
+ for i in T.thread_binding(1, thread="threadIdx.x"):
+ with T.block("block"):
+ tx = T.axis.spatial(1, i)
+ r = T.ramp(tx, 3, 2)
+ A[0, T.ramp(0, 1, 2)] = r
+
+ f = tvm.build(IRModule, target=target)
+ dev = tvm.metal()
+ a_nd = tvm.nd.empty((1, 2), "int32", dev)
+ f(a_nd)
+ assert tuple(a_nd.numpy()[0, :]) == (0, 3)
+
+
if __name__ == "__main__":
+ test_ramp()
test_metal_inf_nan()
test_metal_erf()