yzh119 commented on code in PR #15389:
URL: https://github.com/apache/tvm/pull/15389#discussion_r1272925626


##########
tests/python/dlight/test_gpu_matmul.py:
##########
@@ -28,7 +28,7 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
     @pytest.fixture
     def transform(self):
         def transform(mod):
-            with Target("nvidia/geforce-rtx-3090-ti"):
+            with Target("nvidia/geforce-gtx-1080-ti"):

Review Comment:
   The target tag here should not influence codegen, so I suggest keeping 
its original value.



##########
tests/python/dlight/test_gpu_matmul_tensorize.py:
##########
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("nvidia/geforce-rtx-2080-ti"):
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+        return transform
+
+
+class TestMatmulTensorize(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), 
"float16"), compute: T.Buffer((256, 256), "float16")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i, j, k in T.grid(256, 256, 256):
+            with T.block("compute"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(X[v_i, v_k], W[v_j, v_k])
+                T.writes(compute[v_i, v_j])
+                with T.init():
+                    compute[v_i, v_j] = T.float16(0)
+                compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j, 
v_k]
+
+    @T.prim_func
+    def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), 
"float16"), compute: T.Buffer((256, 256), "float16")):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared.dyn")
+        W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared.dyn")
+        X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), 
"float16", scope="wmma.matrix_a")
+        W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), 
"float16", scope="wmma.matrix_b")
+        compute_reindex_shared = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared")
+        compute_reindex_shared_wmma_accumulator = T.alloc_buffer((1, 256, 
256), "float16", scope="wmma.accumulator")
+        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+            for ax1_0_0_ax2_0_0_fused in T.thread_binding(4, 
thread="blockIdx.x"):
+                for ax1_0_1_ax2_0_1_fused in T.thread_binding(4, 
thread="blockIdx.y"):
+                    for ax2_0_2_ax1_0_2_fused in T.thread_binding(4, 
thread="threadIdx.y"):
+                        for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+                            with T.block("compute_o_init"):
+                                v0_o = T.axis.spatial(T.int64(1), ax0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3_init)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3_init)
+                                T.reads()
+                                
T.writes(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                with T.block("compute_init_o"):
+                                    v1_i_init_o = T.axis.spatial(1, 0)
+                                    v2_i_init_o = T.axis.spatial(1, 0)
+                                    T.reads()
+                                    
T.writes(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                    C = 
T.match_buffer(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_fill_fragment(C.data, 16, 16, 16, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, T.float32(0))
+                        for ax3_0_0 in range(8):
+                            for ax0_ax1_fused_0 in range(2):
+                                for ax0_ax1_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            with 
T.block("X_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_0_ax2_0_0_fused * 64 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 256 + 
ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 32)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 256 + 
ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 32)
+                                                T.reads(X[v1, v2])
+                                                
T.writes(X_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                X_reindex_shared_dyn[v0, v1, 
v2] = X[v1, v2]
+                            for ax0_ax1_fused_0 in range(2):
+                                for ax0_ax1_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(8):
+                                            with 
T.block("W_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 64 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 256 + 
ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) // 32)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 1024 + ax0_ax1_fused_1 * 256 + 
ax0_ax1_fused_2 * 8 + ax0_ax1_fused_3) % 32)
+                                                T.reads(W[v1, v2])
+                                                
T.writes(W_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                W_reindex_shared_dyn[v0, v1, 
v2] = W[v1, v2]
+                            for ax3_0_1 in range(2):
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("X_reindex_shared.dyn_wmma.matrix_a_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
* 2 + ax3_0_1 + ax1_0)
+                                            T.reads(X_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, 
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, 
A.strides[0] * 16, 1), A.strides[0], "row_major")
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("W_reindex_shared.dyn_wmma.matrix_b_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
* 2 + ax3_0_1 + ax1_0)
+                                            T.reads(W_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, 
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, 
A.strides[0] * 16, 1), A.strides[0], "col_major")
+                                for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+                                    with T.block("compute_o_update"):
+                                        v0_o = T.axis.spatial(T.int64(1), ax0)
+                                        v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3)
+                                        v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3)
+                                        v3_o = T.axis.reduce(16, ax3_0_0 * 2 + 
ax3_0_1)
+                                        
T.reads(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o 
* 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, 
v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                        
T.writes(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                        with T.block("compute_o"):
+                                            v1_i_o = T.axis.spatial(1, 0)
+                                            v2_i_o = T.axis.spatial(1, 0)
+                                            v3_i_o = T.axis.reduce(1, 0)
+                                            
T.reads(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o 
* 16 + 16, v3_o * 16:v3_o * 16 + 16], W_reindex_shared_dyn_wmma_matrix_b[0, 
v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o * 16 + 16])
+                                            
T.writes(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            B = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            C = 
T.match_buffer(compute_reindex_shared_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.accumulator", offset_factor=16)
+                                            T.tvm_mma_sync(C.data, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] 
// 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // 
B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, 
C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16)
+                        for ax0_0, ax1_0 in T.grid(2, 2):
+                            with 
T.block("compute_reindex_shared_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(1, 0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax1_0)
+                                
T.reads(compute_reindex_shared_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                T.writes(compute_reindex_shared[v0_o, v1_o * 
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                A = 
T.match_buffer(compute_reindex_shared_wmma_accumulator[v0_o, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", 
"A_s1"), scope="wmma.accumulator", offset_factor=16)
+                                C = 
T.match_buffer(compute_reindex_shared[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="shared", offset_factor=16)
+                                T.tvm_store_matrix_sync(A.data, 16, 16, 16, 
A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % 
A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, 
C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+                        for ax0_1, ax1 in T.grid(32, 32):
+                            with T.block("compute_reindex_shared"):
+                                v0 = T.axis.spatial(1, 0)
+                                v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused 
* 64 + ax2_0_2_ax1_0_2_fused % 2 * 32 + ax0_1)
+                                v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused 
* 64 + ax2_0_2_ax1_0_2_fused // 2 * 32 + ax1)
+                                T.reads(compute_reindex_shared[v0, v1, v2])
+                                T.writes(compute[v1, v2])
+                                T.block_attr({"buffer_dim_align": [[0, 1, 16, 
4]]})
+                                compute[v1, v2] = compute_reindex_shared[v0, 
v1, v2]
+
+    # fmt: on
+
+class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), 
var_compute: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        m = T.int32()
+        X = T.match_buffer(var_X, (m, 256), "float16")
+        compute = T.match_buffer(var_compute, (m, 15))
+        # with T.block("root"):
+        for i, j, k in T.grid(m, 15, 256):
+            with T.block("compute"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(X[v_i, v_k], W[v_j, v_k])
+                T.writes(compute[v_i, v_j])
+                with T.init():
+                    compute[v_i, v_j] = T.float32(0)
+                compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32", 
X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k])
+
+    @T.prim_func
+    def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), 
var_compute: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        m = T.int32()
+        X = T.match_buffer(var_X, (m, 256), "float16")
+        compute = T.match_buffer(var_compute, (m, 15))
+        # with T.block("root"):
+        compute_reindex_pad_local = T.alloc_buffer((1, (T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, 64), scope="local")
+        X_reindex_pad_shared = T.alloc_buffer((1, (T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, 256), "float16", scope="shared")
+        W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", 
scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((T.Cast("int32", T.Cast("int64", m)) 
+ 31) // 32, thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(16, 
thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(8, 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
+                                    with T.block("compute_init"):
+                                        v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                        v1 = T.axis.spatial((T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + 
ax1_3_init)
+                                        v2 = T.axis.spatial(64, ax2_1 * 64 + 
ax2_2 * 4 + ax2_3_init)
+                                        T.reads()
+                                        T.writes(compute_reindex_pad_local[0, 
v1, v2])
+                                        compute_reindex_pad_local[0, v1, v2] = 
T.float32(0)
+                                for ax3_0 in range(16):
+                                    for ax0_ax1_ax2_fused_0 in 
T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in 
T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in 
range(2):
+                                                for ax0_ax1_ax2_fused_3 in 
T.vectorized(2):
+                                                    with 
T.block("X_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 
0)
+                                                        v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 
* 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = 
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + 
ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(X[v1, v2])
+                                                        
T.writes(X_reindex_pad_shared[v0, v1, v2])
+                                                        
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        
X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], 
T.float16(0))
+                                    for ax0_ax1_ax2_fused_0 in 
T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in 
T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in 
range(4):
+                                                for ax0_ax1_ax2_fused_3 in 
T.vectorized(2):
+                                                    with 
T.block("W_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 
0)
+                                                        v1 = 
T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + 
ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = 
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + 
ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(W[v1, v2])
+                                                        
T.writes(W_reindex_pad_shared[v0, v1, v2])
+                                                        
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        
W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], 
T.float16(0))
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 
4):
+                                        with T.block("compute_update"):
+                                            v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                            v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
+                                            v2 = T.axis.spatial(64, ax2_1 * 64 
+ ax2_2 * 4 + ax2_3)
+                                            v3 = T.axis.reduce(256, ax3_0 * 16 
+ ax3_1)
+                                            
T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], 
W_reindex_pad_shared[0, v2, v3])
+                                            
T.writes(compute_reindex_pad_local[0, v1, v2])
+                                            compute_reindex_pad_local[0, v1, 
v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", 
X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, 
v2, v3])
+                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+                                    for ax2_1_1 in T.vectorized(2):
+                                        with 
T.block("compute_reindex_pad_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + ax1_2 * 4 + ax1)
+                                            v2 = T.axis.spatial(64, ax2_2 * 4 
+ ax2_0 * 2 + ax2_1_1)
+                                            
T.reads(compute_reindex_pad_local[v0, v1, v2])
+                                            T.writes(compute[v1, v2])
+                                            if v1 < m and v2 < 15:
+                                                compute[v1, v2] = 
compute_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Review Comment:
   Add an empty line at the end of the file to pass pylint.



##########
python/tvm/dlight/gpu/matmul.py:
##########
@@ -290,6 +507,16 @@ def is_spatial(block: BlockRV) -> bool:
             return None
         matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
 
+        if target.kind.name == "cuda" and check_sm_version(target.arch) > 70:
+            apply_tensorization: bool = True
+            for item_var in block_stmt.iter_vars:
+                extent = item_var.dom.extent
+                if isinstance(extent, tir.expr.IntImm):
+                    if extent.value > 1 and extent.value <= 128:

Review Comment:
   Would you mind explaining this condition (`1 < extent <= 128`)? 



##########
python/tvm/dlight/gpu/matmul.py:
##########
@@ -248,39 +248,256 @@ def get_index_map(block: tir.Block) -> 
Optional[Tuple[tir.IndexMap, ...]]:
         C_index_map,
     )
 
+def get_reduction_blocks(sch, blocks) -> bool:
+    # Get the main computation block
+    def is_reduction(block: BlockRV) -> bool:
+        block_stmt = sch.get(block)
+        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+        return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+    def is_spatial(block: BlockRV) -> bool:
+        block_stmt = sch.get(block)
+        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+        return iter_types == {IterVar.DataPar}
+
+    # NOTE: We assume there is only one reduction block in the function
+    # all blocks are required to be spatial or reduction
+    if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+        return None
 
-class Matmul(ScheduleRule):
-    """The schedule rule for matmul-like computation"""
+    # There is only one reduction block
+    reduction_blocks = [block for block in blocks if is_reduction(block)]
+    if len(reduction_blocks) != 1:
+        return None
+
+    return reduction_blocks
+
+def check_sm_version(arch: str) -> int:
+    sm_version = arch.replace("sm_", "")
+    return int(sm_version) if sm_version.isdigit() else -1
+
+class MatmulTensorization(ScheduleRule):
+    """
+    The schedule rule for float16 tensor core matmul computation.
+    func with attr 'dlight.do_not_tensorize' will not be tensorized.
+    """
 
     def apply(  # pylint: disable=too-many-locals,missing-docstring
         self,
         func: tir.PrimFunc,
         target: Target,
         _: bool,
     ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group  # 
pylint: disable=import-outside-toplevel
+
         sch = tir.Schedule(func)
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)
 
-        # Get the main computation block
-        def is_reduction(block: BlockRV) -> bool:
-            block_stmt = sch.get(block)
-            iter_types = {iter_var.iter_type for iter_var in 
block_stmt.iter_vars}
-            return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+        if func.attrs is not None and "dlight.do_not_tensorize" in 
func.attrs.keys():
+            return None
 
-        def is_spatial(block: BlockRV) -> bool:
-            block_stmt = sch.get(block)
-            iter_types = {iter_var.iter_type for iter_var in 
block_stmt.iter_vars}
-            return iter_types == {IterVar.DataPar}
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
 
-        # NOTE: We assume there is only one reduction block in the function
-        # all blocks are required to be spatial or reduction
-        if not all([is_reduction(block) or is_spatial(block) for block in 
blocks]):
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
             return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Start Schedule
+        # Step 0. Get schedule config.
+        # NOTE: we can analyze the config by the hardware spec in the future
+
+        # tensor core intrinsic size
+        micro_size_x = 16
+        micro_size_y = 16
+        micro_size_k = 16
+
+        i_factors, j_factors, k_factors = (
+            [None, 1, 2, 2],
+            [1, None, 2, 2],
+            [None, 2],
+        )
+
+        num_ty = i_factors[2] * j_factors[2]
+        x_pad_factor = i_factors[2] * i_factors[3]
+        y_pad_factor = j_factors[2] * j_factors[3]
+        k_pad_factor = k_factors[1]
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, 
J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                micro_size_x * x_pad_factor,
+                micro_size_y * y_pad_factor,
+                micro_size_k * k_pad_factor,
+            ],
+        )
 
-        # There is only one reduction block
-        reduction_blocks = [block for block in blocks if is_reduction(block)]
-        if len(reduction_blocks) != 1:
+        # Step 3. Schedule matmul to use tensor core
+        block = main_block
+
+        batch, i, j, k = sch.get_loops(block)
+
+        # inner loops for tensor core computation
+        i, i_inner = sch.split(i, factors=[None, micro_size_x])
+        j, j_inner = sch.split(j, factors=[None, micro_size_y])
+        k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+        sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+        block_inner = block
+        block_outer = sch.blockize(i_inner)
+
+        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+        k0, k1 = sch.split(k, k_factors)
+
+        sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+        block_idx = sch.fuse(i0, j0)
+        block_idy = sch.fuse(i1, j1)
+        thread_idy = sch.fuse(j2, i2)
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(block_idx, "blockIdx.x")
+        sch.bind(block_idy, "blockIdx.y")
+        sch.bind(thread_idy, "threadIdx.y")
+
+        def fetch_to_shared(block, idx, ndim):
+            block_read = sch.cache_read(block, idx, "shared.dyn")
+            sch.compute_at(block_read, k0)
+            vector_size = 8
+            warp_size = 32
+            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+            _, f_1, f_2, f_3 = sch.split(
+                fused, factors=[None, num_ty, warp_size, vector_size]
+            )
+            sch.bind(f_2, "threadIdx.x")
+            sch.bind(f_1, "threadIdx.y")
+            sch.vectorize(f_3)
+
+            sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
+            return block_read
+
+        a_g2s = fetch_to_shared(block_outer, 0, 2)
+        b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read cache to load matrix from shared memory to wmma fragments
+        A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+        B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+        sch.compute_at(A_mat, k1)
+        sch.compute_at(B_mat, k1)
+
+        # create write cache to store matrix from wmma fragments to shared 
memory and global memory
+        accumulator_shared_to_global = sch.cache_write(block_outer, 0, 
"shared")
+        sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+        store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+        sch.reverse_compute_at(store, thread_idy)
+        sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+        # split the store loop to match hardware intrinsic pattern
+        i, j = sch.get_loops(store)[-2:]
+        i0, i1 = sch.split(i, factors=[None, 16])
+        j0, j1 = sch.split(j, factors=[None, 16])
+        sch.reorder(i0, j0, i1, j1)
+
+        block_init_c = sch.decompose_reduction(block_outer, k0)
+        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+        # Tensorization by hardware intrinsics
+        intrin_group = get_wmma_intrin_group(
+            load_scope="shared.dyn",
+            store_scope="shared",
+            in_dtype="float16",
+            out_dtype="float32",
+            trans_b=True

Review Comment:
   It seems this dlight rule only targets NT matmul (trans_a = non-transposed, 
trans_b = transposed). Can we also support NN matmul? In some quantization 
modes the weight matrix is not transposed, e.g. `b4f16_0`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to