This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 56d0e3b7af [METAL][CODEGEN] testcase for ramp codegen (#14331)
56d0e3b7af is described below

commit 56d0e3b7affce768ae8c7b9d17743d3e8332308d
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Mar 19 14:50:44 2023 -0400

    [METAL][CODEGEN] testcase for ramp codegen (#14331)
    
    This PR adds a test case, runnable locally, that covers Metal ramp codegen.
---
 tests/python/unittest/test_target_codegen_metal.py | 29 ++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_target_codegen_metal.py b/tests/python/unittest/test_target_codegen_metal.py
index 002cf3c696..45588c69cf 100644
--- a/tests/python/unittest/test_target_codegen_metal.py
+++ b/tests/python/unittest/test_target_codegen_metal.py
@@ -17,11 +17,12 @@
 import tvm
 from tvm import te
 import numpy as np
-from tvm import topi
-import unittest
+
 from tvm.contrib.nvcc import have_fp16, have_int8, have_bf16
 from tvm.contrib import nvcc
 import tvm.testing
+import tvm.script
+from tvm.script import tir as T
 
 tx = te.thread_axis("threadIdx.x")
 bx = te.thread_axis("blockIdx.x")
@@ -76,6 +77,30 @@ def test_metal_erf():
     check_erf(dev, 1, "float16")
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.requires_metal
+def test_ramp():
+    target = "metal"
+
+    @tvm.script.ir_module
+    class IRModule:
+        @T.prim_func
+        def main(A: T.Buffer((1, 2), "int32")):
+            T.func_attr({"global_symbol": "main"})
+            for i in T.thread_binding(1, thread="threadIdx.x"):
+                with T.block("block"):
+                    tx = T.axis.spatial(1, i)
+                    r = T.ramp(tx, 3, 2)
+                    A[0, T.ramp(0, 1, 2)] = r
+
+    f = tvm.build(IRModule, target=target)
+    dev = tvm.metal()
+    a_nd = tvm.nd.empty((1, 2), "int32", dev)
+    f(a_nd)
+    assert tuple(a_nd.numpy()[0, :]) == (0, 3)
+
+
 if __name__ == "__main__":
+    test_ramp()
     test_metal_inf_nan()
     test_metal_erf()

Reply via email to