This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 30bf013e78 [TIR][Schedule] Add unittest for read_write_at (#14395)
30bf013e78 is described below

commit 30bf013e788784a1fce031cb2b2cbf0811cf7f58
Author: Tian Xia <[email protected]>
AuthorDate: Mon Mar 27 23:05:31 2023 +0800

    [TIR][Schedule] Add unittest for read_write_at (#14395)
    
    This PR adds unittest for schedule primitive read_at and write_at.
    
    Co-authored-by: Siyuan Feng <[email protected]>
---
 .../unittest/test_tir_schedule_read_write_at.py    | 221 +++++++++++++++++++++
 1 file changed, 221 insertions(+)

diff --git a/tests/python/unittest/test_tir_schedule_read_write_at.py b/tests/python/unittest/test_tir_schedule_read_write_at.py
new file mode 100644
index 0000000000..dd61a4d62b
--- /dev/null
+++ b/tests/python/unittest/test_tir_schedule_read_write_at.py
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-function-docstring,missing-module-docstring
+import sys
+
+import pytest
+
+import tvm
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import verify_trace_roundtrip
+
+
+# fmt: off
+# pylint: disable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
+
+# Input workload: a 2048x2048 float32 matmul, C[vi, vj] += A[vi, vk] * B[vk, vj],
+# already bound to blockIdx / vthread / threadIdx loops, with the reduction axis
+# split into k0 (256, serial) x k1 (8, unrolled).
+@T.prim_func
+def cuda_matmul(a: T.handle, b: T.handle, c: T.handle) -> None:  # pylint: disable=undefined-loop-variable
+    A = T.match_buffer(a, [2048, 2048], "float32")
+    B = T.match_buffer(b, [2048, 2048], "float32")
+    C = T.match_buffer(c, [2048, 2048], "float32")
+    for by in T.thread_binding(0, 32, thread = "blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread = "blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread = "vthread.y"):
+                for vx in T.thread_binding(0, 2, thread = "vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread = "threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread = "threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                for k1 in T.unroll(0, 8):
+                                    # Each thread computes a 4x4 patch of C per vthread.
+                                    for _, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A[vi, vk], B[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = 0.0
+                                            C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+
+# Expected result of `read_at(k0, <block "C">, 1, "shared")` on cuda_matmul:
+# read buffer 1 of block "C" (A) is staged into A_shared by a new "auto_copy"
+# block under loop k0, and block "C" now reads A_shared instead of A.
+@T.prim_func
+def cuda_matmul_read_at_a(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                # Copy block: the 64x8 tile of A consumed during this k0 step.
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A_shared[vi, vk], B[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = T.float32(0)
+                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B[vk, vj]
+
+
+# Expected result of `read_at(k0, <block "C">, 2, "shared")` on
+# cuda_matmul_read_at_a: B is now also staged, via an "auto_copy" block
+# "B_shared" that copies an 8x64 tile of B under loop k0.
+@T.prim_func
+def cuda_matmul_read_at_ab(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                # Copy block: 64x8 tile of A for this k0 step.
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                # Copy block: 8x64 tile of B for this k0 step.
+                                with T.block("B_shared"):
+                                    v0 = T.axis.S(256, k0)
+                                    v1 = T.axis.S(32, bx)
+                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(8, 64):
+                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
+                                            T.writes([C[vi, vj]])
+                                            with T.init():
+                                                C[vi, vj] = T.float32(0)
+                                            C[vi, vj] = C[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
+
+# Expected result of `write_at(tx, <block "C">, 0, "shared")` on
+# cuda_matmul_read_at_ab: block "C" accumulates into C_shared, and a new
+# "auto_copy" block "C_shared", placed after the k0 loop but still inside tx,
+# writes the 64x64 tile back to C.
+@T.prim_func
+def cuda_matmul_write_at_c(a: T.handle, b: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [2048, 2048], dtype="float32")
+    B = T.match_buffer(b, [2048, 2048], dtype="float32")
+    C = T.match_buffer(c, [2048, 2048], dtype="float32")
+    A_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    B_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    C_shared = T.alloc_buffer([2048, 2048], dtype="float32", scope="shared")
+    for by in T.thread_binding(0, 32, thread="blockIdx.y"):
+        for bx in T.thread_binding(0, 32, thread="blockIdx.x"):
+            for vy in T.thread_binding(0, 2, thread="vthread.y"):
+                for vx in T.thread_binding(0, 2, thread="vthread.x"):
+                    for ty in T.thread_binding(0, 8, thread="threadIdx.y"):
+                        for tx in T.thread_binding(0, 8, thread="threadIdx.x"):
+                            for k0 in T.serial(0, 256):
+                                with T.block("A_shared"):
+                                    v0 = T.axis.S(32, by)
+                                    v1 = T.axis.S(256, k0)
+                                    T.reads([A[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.writes([A_shared[v0 * 64 : v0 * 64 + 64, v1 * 8 : v1 * 8 + 8]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(64, 8):
+                                        A_shared[v0 * 64 + ax0, v1 * 8 + ax1] = A[v0 * 64 + ax0, v1 * 8 + ax1]
+                                with T.block("B_shared"):
+                                    v0 = T.axis.S(256, k0)
+                                    v1 = T.axis.S(32, bx)
+                                    T.reads([B[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.writes([B_shared[v0 * 8 : v0 * 8 + 8, v1 * 64 : v1 * 64 + 64]])
+                                    T.block_attr({"auto_copy":1})
+                                    for ax0, ax1 in T.grid(8, 64):
+                                        B_shared[v0 * 8 + ax0, v1 * 64 + ax1] = B[v0 * 8 + ax0, v1 * 64 + ax1]
+                                for k1 in T.unroll(0, 8):
+                                    for v_, i, j in T.grid(1, 4, 4):
+                                        with T.block("C"):
+                                            vi = T.axis.S(2048, by * 64 + vy * 32 + ty * 4 + i)
+                                            vj = T.axis.S(2048, bx * 64 + vx * 32 + tx * 4 + j)
+                                            vk = T.axis.R(2048, k0 * 8 + k1)
+                                            T.reads([C_shared[vi, vj], A_shared[vi, vk], B_shared[vk, vj]])
+                                            T.writes([C_shared[vi, vj]])
+                                            with T.init():
+                                                C_shared[vi, vj] = T.float32(0)
+                                            C_shared[vi, vj] = C_shared[vi, vj] + A_shared[vi, vk] * B_shared[vk, vj]
+                            # Write-back block: copies the 64x64 C_shared tile to C once k0 finishes.
+                            with T.block("C_shared"):
+                                v0 = T.axis.S(32, by)
+                                v1 = T.axis.S(32, bx)
+                                T.reads([C_shared[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
+                                T.writes([C[v0 * 64 : v0 * 64 + 64, v1 * 64 : v1 * 64 + 64]])
+                                T.block_attr({"auto_copy":1})
+                                for ax0, ax1 in T.grid(64, 64):
+                                    C[v0 * 64 + ax0, v1 * 64 + ax1] = C_shared[v0 * 64 + ax0, v1 * 64 + ax1]
+
+
+# pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks,not-callable
+# fmt: on
+
+
+def test_read_at_global_to_shared_a():
+    """read_at on loop k0 stages A (read buffer 1 of block "C") in "shared" scope."""
+    sch = tir.Schedule(cuda_matmul, debug_mask="all")
+    c_block = sch.get_block("C")
+    # Loop nest of "C": by, bx, vy, vx, ty, tx, k0, k1, then the (1, 4, 4) grid.
+    loops = sch.get_loops(c_block)
+    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = loops  # pylint: disable=invalid-name
+    sch.read_at(k0, c_block, 1, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_a)
+    verify_trace_roundtrip(sch, cuda_matmul)
+
+
+def test_read_at_global_to_shared_ab():
+    """Starting from cuda_matmul_read_at_a, also stage B (read buffer 2) at loop k0."""
+    sch = tir.Schedule(cuda_matmul_read_at_a, debug_mask="all")
+    c_block = sch.get_block("C")
+    loops = sch.get_loops(c_block)
+    _by, _bx, _vy, _vx, _ty, _tx, k0, _k1, _, _i, _j = loops  # pylint: disable=invalid-name
+    sch.read_at(k0, c_block, 2, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_read_at_ab)
+    verify_trace_roundtrip(sch, cuda_matmul_read_at_a)
+
+
+def test_read_at_local_to_shared_c():
+    """Exercise write_at (despite the name): stage C (write buffer 0) at loop tx.
+
+    Block "C" is expected to accumulate into a "shared"-scope C_shared that is
+    copied back to C after the k0 loop, matching cuda_matmul_write_at_c.
+    """
+    sch = tir.Schedule(cuda_matmul_read_at_ab, debug_mask="all")
+    c_block = sch.get_block("C")
+    loops = sch.get_loops(c_block)
+    _by, _bx, _vy, _vx, _ty, tx, _k0, _k1, _, _i, _j = loops  # pylint: disable=invalid-name
+    sch.write_at(tx, c_block, 0, "shared")
+    tvm.ir.assert_structural_equal(sch.mod["main"], cuda_matmul_write_at_c)
+    verify_trace_roundtrip(sch, cuda_matmul_read_at_ab)
+
+
+if __name__ == "__main__":
+    # Import tvm.testing explicitly rather than relying on `import tvm` (or another
+    # imported module) having already loaded that submodule.
+    import tvm.testing  # pylint: disable=import-outside-toplevel
+    tvm.testing.main()

Reply via email to