This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new e8995272d1 [Unity][Dlight] Tensorization Rule in GPU Matmul (#15389)
e8995272d1 is described below

commit e8995272d112b3bc9e58a55fce4320c5539b9958
Author: Bao Anchang <[email protected]>
AuthorDate: Thu Jul 27 14:34:59 2023 +0800

    [Unity][Dlight] Tensorization Rule in GPU Matmul (#15389)
    
    * feat: dlight matmul tensorize
    
    * lint
    
    * lint and add comment
    
    * fix: use vector size 4
    
    * fix git-black lint
    
    * pylint
    
    * lint
    
    * fix: use shared.dyn in write cache
---
 python/tvm/dlight/gpu/matmul.py                  | 271 +++++++++++++++++++++--
 tests/python/dlight/test_gpu_matmul.py           |   2 +-
 tests/python/dlight/test_gpu_matmul_tensorize.py | 261 ++++++++++++++++++++++
 3 files changed, 516 insertions(+), 18 deletions(-)

diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index def13a60ac..5ea3aa4dee 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -249,8 +249,41 @@ def get_index_map(block: tir.Block) -> 
Optional[Tuple[tir.IndexMap, ...]]:
     )
 
 
-class Matmul(ScheduleRule):
-    """The schedule rule for matmul-like computation"""
+def get_reduction_blocks(sch, blocks) -> bool:
+    # Get the main computation block
+    def is_reduction(block: BlockRV) -> bool:
+        block_stmt = sch.get(block)
+        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+        return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+    def is_spatial(block: BlockRV) -> bool:
+        block_stmt = sch.get(block)
+        iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+        return iter_types == {IterVar.DataPar}
+
+    # NOTE: We assume there is only one reduction block in the function
+    # all blocks are required to be spatial or reduction
+    if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+        return None
+
+    # There is only one reduction block
+    reduction_blocks = [block for block in blocks if is_reduction(block)]
+    if len(reduction_blocks) != 1:
+        return None
+
+    return reduction_blocks
+
+
+def check_sm_version(arch: str) -> int:
+    sm_version = arch.replace("sm_", "")
+    return int(sm_version) if sm_version.isdigit() else -1
+
+
+class MatmulTensorization(ScheduleRule):
+    """
+    The schedule rule for float16 tensor core matmul computation.
+    func with attr 'dlight.do_not_tensorize' will not be tensorized.
+    """
 
     def apply(  # pylint: disable=too-many-locals,missing-docstring
         self,
@@ -258,29 +291,216 @@ class Matmul(ScheduleRule):
         target: Target,
         _: bool,
     ) -> Optional[tir.Schedule]:
+        from tvm.tir.tensor_intrin.cuda import (  # pylint: 
disable=import-outside-toplevel
+            get_wmma_intrin_group,
+        )
+
         sch = tir.Schedule(func)
         root_block = analysis.get_root_block(sch)
         blocks = sch.get_child_blocks(root_block)
 
-        # Get the main computation block
-        def is_reduction(block: BlockRV) -> bool:
-            block_stmt = sch.get(block)
-            iter_types = {iter_var.iter_type for iter_var in 
block_stmt.iter_vars}
-            return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+        if func.attrs is not None and "dlight.do_not_tensorize" in 
func.attrs.keys():
+            return None
 
-        def is_spatial(block: BlockRV) -> bool:
-            block_stmt = sch.get(block)
-            iter_types = {iter_var.iter_type for iter_var in 
block_stmt.iter_vars}
-            return iter_types == {IterVar.DataPar}
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
+            return None
 
-        # NOTE: We assume there is only one reduction block in the function
-        # all blocks are required to be spatial or reduction
-        if not all([is_reduction(block) or is_spatial(block) for block in 
blocks]):
+        main_block = reduction_blocks[0]
+        block_stmt = sch.get(main_block)
+        index_maps = get_index_map(block_stmt)
+        if index_maps is None:
             return None
+        matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+        # Start Schedule
+        # Step 0. Get schedule config.
+        # NOTE: we can analyze the config by the hardware spec in the future
+
+        # tensor core intrinsic size
+        micro_size_x = 16
+        micro_size_y = 16
+        micro_size_k = 16
+
+        i_factors, j_factors, k_factors = (
+            [None, 1, 2, 2],
+            [1, None, 2, 2],
+            [None, 2],
+        )
 
-        # There is only one reduction block
-        reduction_blocks = [block for block in blocks if is_reduction(block)]
-        if len(reduction_blocks) != 1:
+        num_ty = i_factors[2] * j_factors[2]
+        x_pad_factor = i_factors[2] * i_factors[3]
+        y_pad_factor = j_factors[2] * j_factors[3]
+        k_pad_factor = k_factors[1]
+
+        # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, 
J, K]
+        block = sch.reindex(main_block, ("read", 0))
+        sch.transform_layout(block, ("write", 0), a_index_map)
+        block = sch.reindex(main_block, ("read", 1))
+        sch.transform_layout(block, ("write", 0), b_index_map)
+        block = sch.reindex(main_block, ("write", 0))
+        sch.transform_layout(block, ("read", 0), c_index_map)
+        sch.transform_block_layout(main_block, matmul_index_map)
+
+        # Step 2. Padding for dynamic shape kernels
+        sch.pad_einsum(
+            main_block,
+            [
+                1,
+                micro_size_x * x_pad_factor,
+                micro_size_y * y_pad_factor,
+                micro_size_k * k_pad_factor,
+            ],
+        )
+
+        # Step 3. Schedule matmul to use tensor core
+        block = main_block
+
+        batch, i, j, k = sch.get_loops(block)
+
+        # inner loops for tensor core computation
+        i, i_inner = sch.split(i, factors=[None, micro_size_x])
+        j, j_inner = sch.split(j, factors=[None, micro_size_y])
+        k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+        sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+        block_inner = block
+        block_outer = sch.blockize(i_inner)
+
+        i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+        j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+        k0, k1 = sch.split(k, k_factors)
+
+        sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+        block_idx = sch.fuse(i0, j0)
+        block_idy = sch.fuse(i1, j1)
+        thread_idy = sch.fuse(j2, i2)
+        sch.bind(batch, "blockIdx.z")
+        sch.bind(block_idx, "blockIdx.x")
+        sch.bind(block_idy, "blockIdx.y")
+        sch.bind(thread_idy, "threadIdx.y")
+
+        def fetch_to_shared(block, idx, ndim):
+            block_read = sch.cache_read(block, idx, "shared.dyn")
+            sch.compute_at(block_read, k0)
+            vector_size = 4
+            warp_size = 32
+            fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+            _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty, 
warp_size, vector_size])
+
+            sch.bind(f_2, "threadIdx.x")
+            sch.bind(f_1, "threadIdx.y")
+            sch.vectorize(f_3)
+
+            sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
+            return block_read
+
+        a_g2s = fetch_to_shared(block_outer, 0, 2)
+        b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+        auto_inline_producers(sch, a_g2s)
+        auto_inline_producers(sch, b_g2s)
+
+        # create read cache to load matrix from shared memory to wmma fragments
+        A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+        B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+        sch.compute_at(A_mat, k1)
+        sch.compute_at(B_mat, k1)
+
+        # create write cache to store matrix from wmma fragments to shared 
memory and global memory
+        accumulator_shared_to_global = sch.cache_write(block_outer, 0, 
"shared.dyn")
+        sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+        store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+        sch.reverse_compute_at(store, thread_idy)
+        sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+        # split the store loop to match hardware intrinsic pattern
+        i, j = sch.get_loops(store)[-2:]
+        i0, i1 = sch.split(i, factors=[None, 16])
+        j0, j1 = sch.split(j, factors=[None, 16])
+        sch.reorder(i0, j0, i1, j1)
+
+        block_init_c = sch.decompose_reduction(block_outer, k0)
+        block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+        # Tensorization by hardware intrinsics
+        intrin_group = get_wmma_intrin_group(
+            load_scope="shared.dyn",
+            store_scope="shared.dyn",
+            in_dtype="float16",
+            out_dtype="float32",
+            trans_b=True,
+        )
+
+        try:
+            i, j = sch.get_loops(A_mat)[-2:]
+            i0, i1 = sch.split(i, factors=[None, 16])
+            j0, j1 = sch.split(j, factors=[None, 16])
+            sch.reorder(i0, j0, i1, j1)
+            sch.unroll(i0)
+            sch.unroll(j0)
+            sch.tensorize(i1, intrin_group["load_a"])
+
+            i, j = sch.get_loops(B_mat)[-2:]
+            i0, i1 = sch.split(i, factors=[None, 16])
+            j0, j1 = sch.split(j, factors=[None, 16])
+            sch.reorder(i0, j0, i1, j1)
+            sch.unroll(i0)
+            sch.unroll(j0)
+            sch.tensorize(i1, intrin_group["load_b"])
+        except:  # pylint: disable=bare-except
+            return None
+
+        # Try to tensorize the init, store and compute block with f16 or f32 
intrinsics
+        tensorize_success: bool = False
+
+        def tensorize_init_store_compute():
+            sch.tensorize(sch.get_loops(block_init_c_inner)[-2], 
intrin_group["init"])
+            sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
+            sch.tensorize(sch.get_loops(block_inner)[-3], 
intrin_group["compute"])
+
+        try:
+            tensorize_init_store_compute()
+            tensorize_success = True
+        except:  # pylint: disable=bare-except
+            intrin_group = get_wmma_intrin_group(
+                load_scope="shared.dyn",
+                store_scope="shared.dyn",
+                in_dtype="float16",
+                out_dtype="float16",
+                trans_b=True,
+            )
+
+        if not tensorize_success:
+            try:
+                tensorize_init_store_compute()
+                tensorize_success = True
+            except:  # pylint: disable=bare-except
+                return None
+
+        auto_inline_consumers(sch, accumulator_shared_to_global)
+        return sch if tensorize_success else None
+
+
+class Matmul(ScheduleRule):
+    """The schedule rule for matmul-like computation"""
+
+    def apply(  # pylint: disable=too-many-locals,missing-docstring
+        self,
+        func: tir.PrimFunc,
+        target: Target,
+        _: bool,
+    ) -> Optional[tir.Schedule]:
+        sch = tir.Schedule(func)
+        root_block = analysis.get_root_block(sch)
+        blocks = sch.get_child_blocks(root_block)
+
+        reduction_blocks = get_reduction_blocks(sch, blocks)
+        if reduction_blocks is None:
             return None
 
         main_block = reduction_blocks[0]
@@ -302,6 +522,11 @@ class Matmul(ScheduleRule):
         micro_size_k = 16
         vector_size = 2
 
+        # Tensorization config:
+        # If any value of I, J, K is fixed and less than this threshold,
+        # tensorization rule will not be applied.
+        minimal_tensorize_threshold = 128
+
         # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S, 
J, K]
         block = sch.reindex(main_block, ("read", 0))
         sch.transform_layout(block, ("write", 0), a_index_map)
@@ -311,6 +536,18 @@ class Matmul(ScheduleRule):
         sch.transform_layout(block, ("read", 0), c_index_map)
         sch.transform_block_layout(main_block, matmul_index_map)
 
+        block_stmt = sch.get(main_block)
+        if target.kind.name == "cuda" and check_sm_version(target.arch) > 70:
+            apply_tensorization: bool = True
+            # the batch dimension is not taken into consideration.
+            for item_var in block_stmt.iter_vars[1:]:
+                extent = item_var.dom.extent
+                if isinstance(extent, tir.expr.IntImm):
+                    if extent.value <= minimal_tensorize_threshold:
+                        apply_tensorization = False
+            if apply_tensorization:
+                return MatmulTensorization().apply(func, target, _)
+
         # Step 2. Padding for dynamic shape kernels
         sch.pad_einsum(
             main_block,
diff --git a/tests/python/dlight/test_gpu_matmul.py 
b/tests/python/dlight/test_gpu_matmul.py
index b9ee95b76b..bc6419160b 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -28,7 +28,7 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
     @pytest.fixture
     def transform(self):
         def transform(mod):
-            with Target("nvidia/geforce-rtx-3090-ti"):
+            with Target("nvidia/geforce-gtx-1080-ti"):
                 return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
 
         return transform
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py 
b/tests/python/dlight/test_gpu_matmul_tensorize.py
new file mode 100644
index 0000000000..349b4bf256
--- /dev/null
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+    @pytest.fixture
+    def transform(self):
+        def transform(mod):
+            with Target("nvidia/geforce-rtx-2080-ti"):
+                return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+        return transform
+
+
+class TestMatmulTensorize(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), 
"float16"), compute: T.Buffer((256, 256), "float16")):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        for i, j, k in T.grid(256, 256, 256):
+            with T.block("compute"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(X[v_i, v_k], W[v_j, v_k])
+                T.writes(compute[v_i, v_j])
+                with T.init():
+                    compute[v_i, v_j] = T.float16(0)
+                compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j, 
v_k]
+
+    @T.prim_func
+    def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256), 
"float16"), compute: T.Buffer((256, 256), "float16")):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        # with T.block("root"):
+        X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared.dyn")
+        W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared.dyn")
+        X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256), 
"float16", scope="wmma.matrix_a")
+        W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256), 
"float16", scope="wmma.matrix_b")
+        compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16", 
scope="shared.dyn")
+        compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256, 
256), "float16", scope="wmma.accumulator")
+        for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+            for ax1_0_0_ax2_0_0_fused in T.thread_binding(4, 
thread="blockIdx.x"):
+                for ax1_0_1_ax2_0_1_fused in T.thread_binding(4, 
thread="blockIdx.y"):
+                    for ax2_0_2_ax1_0_2_fused in T.thread_binding(4, 
thread="threadIdx.y"):
+                        for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+                            with T.block("compute_o_init"):
+                                v0_o = T.axis.spatial(T.int64(1), ax0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3_init)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3_init)
+                                T.reads()
+                                
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                with T.block("compute_init_o"):
+                                    v1_i_init_o = T.axis.spatial(1, 0)
+                                    v2_i_init_o = T.axis.spatial(1, 0)
+                                    T.reads()
+                                    
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                    C = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", 
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                    T.tvm_fill_fragment(C.data, 16, 16, 16, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, T.float32(0))
+                        for ax3_0_0 in range(8):
+                            for ax0_ax1_fused_0 in range(4):
+                                for ax0_ax1_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("X_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_0_ax2_0_0_fused * 64 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 32)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 
* 4 + ax0_ax1_fused_3) % 32)
+                                                T.reads(X[v1, v2])
+                                                
T.writes(X_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                X_reindex_shared_dyn[v0, v1, 
v2] = X[v1, v2]
+                            for ax0_ax1_fused_0 in range(4):
+                                for ax0_ax1_fused_1 in T.thread_binding(4, 
thread="threadIdx.y"):
+                                    for ax0_ax1_fused_2 in 
T.thread_binding(32, thread="threadIdx.x"):
+                                        for ax0_ax1_fused_3 in T.vectorized(4):
+                                            with 
T.block("W_reindex_shared.dyn"):
+                                                v0 = T.axis.spatial(1, 0)
+                                                v1 = T.axis.spatial(256, 
ax1_0_1_ax2_0_1_fused * 64 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + 
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 32)
+                                                v2 = T.axis.spatial(256, 
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 
* 4 + ax0_ax1_fused_3) % 32)
+                                                T.reads(W[v1, v2])
+                                                
T.writes(W_reindex_shared_dyn[v0, v1, v2])
+                                                
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+                                                W_reindex_shared_dyn[v0, v1, 
v2] = W[v1, v2]
+                            for ax3_0_1 in range(2):
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("X_reindex_shared.dyn_wmma.matrix_a_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
* 2 + ax3_0_1 + ax1_0)
+                                            T.reads(X_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, 
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, 
A.strides[0] * 16, 1), A.strides[0], "row_major")
+                                for ax0_0 in T.unroll(2):
+                                    for ax1_0 in T.unroll(1):
+                                        with 
T.block("W_reindex_shared.dyn_wmma.matrix_b_o"):
+                                            v0_o = T.axis.spatial(1, 0)
+                                            v1_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax0_0)
+                                            v2_o = T.axis.spatial(16, ax3_0_0 
* 2 + ax3_0_1 + ax1_0)
+                                            T.reads(W_reindex_shared_dyn[v0_o, 
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                            
T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16, 
v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="shared.dyn", offset_factor=16)
+                                            C = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            T.tvm_load_matrix_sync(C.data, 16, 
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16, 
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset, 
A.strides[0] * 16, 1), A.strides[0], "col_major")
+                                for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+                                    with T.block("compute_o_update"):
+                                        v0_o = T.axis.spatial(T.int64(1), ax0)
+                                        v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3)
+                                        v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3)
+                                        v3_o = T.axis.reduce(16, ax3_0_0 * 2 + 
ax3_0_1)
+                                        
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                        
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                        with T.block("compute_o"):
+                                            v1_i_o = T.axis.spatial(1, 0)
+                                            v2_i_o = T.axis.spatial(1, 0)
+                                            v3_i_o = T.axis.reduce(1, 0)
+                                            
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16], 
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o 
* 16 + 16])
+                                            
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                            A = 
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"), 
scope="wmma.matrix_a", offset_factor=16)
+                                            B = 
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, 
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"), 
scope="wmma.matrix_b", offset_factor=16)
+                                            C = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", 
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+                                            T.tvm_mma_sync(C.data, 
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset % 
C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0] 
// 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset // 
B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16, 
C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + 
C.elem_offset % C.strides[0] // 16)
+                        for ax0_0, ax1_0 in T.grid(2, 2):
+                            with 
T.block("compute_reindex_shared.dyn_wmma.accumulator_o"):
+                                v0_o = T.axis.spatial(1, 0)
+                                v1_o = T.axis.spatial(16, 
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+                                v2_o = T.axis.spatial(16, 
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax1_0)
+                                
T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 + 
16, v2_o * 16:v2_o * 16 + 16])
+                                T.writes(compute_reindex_shared_dyn[v0_o, v1_o 
* 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+                                A = 
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o 
* 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", 
"A_s1"), scope="wmma.accumulator", offset_factor=16)
+                                C = 
T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o 
* 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"), 
scope="shared.dyn", offset_factor=16)
+                                T.tvm_store_matrix_sync(A.data, 16, 16, 16, 
A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset % 
A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data, 
C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+                        for ax0_1, ax1 in T.grid(32, 32):
+                            with T.block("compute_reindex_shared.dyn"):
+                                v0 = T.axis.spatial(1, 0)
+                                v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused 
* 64 + ax2_0_2_ax1_0_2_fused % 2 * 32 + ax0_1)
+                                v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused 
* 64 + ax2_0_2_ax1_0_2_fused // 2 * 32 + ax1)
+                                T.reads(compute_reindex_shared_dyn[v0, v1, v2])
+                                T.writes(compute[v1, v2])
+                                T.block_attr({"buffer_dim_align": [[0, 1, 16, 
4]]})
+                                compute[v1, v2] = 
compute_reindex_shared_dyn[v0, v1, v2]
+
+    # fmt: on
+
+
+class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
+    # fmt: off
+
+    @T.prim_func
+    def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"), 
var_compute: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+        m = T.int32()
+        X = T.match_buffer(var_X, (m, 256), "float16")
+        compute = T.match_buffer(var_compute, (m, 15))
+        # with T.block("root"):
+        for i, j, k in T.grid(m, 15, 256):
+            with T.block("compute"):
+                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+                T.reads(X[v_i, v_k], W[v_j, v_k])
+                T.writes(compute[v_i, v_j])
+                with T.init():
+                    compute[v_i, v_j] = T.float32(0)
+                compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32", 
X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k])
+
+    @T.prim_func
+    def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"), 
var_compute: T.handle):
+        T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1, 
"tir.noalias": T.bool(True)})
+        m = T.int32()
+        X = T.match_buffer(var_X, (m, 256), "float16")
+        compute = T.match_buffer(var_compute, (m, 15))
+        # with T.block("root"):
+        compute_reindex_pad_local = T.alloc_buffer((1, (T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, 64), scope="local")
+        X_reindex_pad_shared = T.alloc_buffer((1, (T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, 256), "float16", scope="shared")
+        W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16", 
scope="shared")
+        for ax0_ax2_0_fused in T.thread_binding(T.int64(1), 
thread="blockIdx.y"):
+            for ax1_0 in T.thread_binding((T.Cast("int32", T.Cast("int64", m)) 
+ 31) // 32, thread="blockIdx.x"):
+                for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+                    for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+                        for ax2_2 in T.thread_binding(16, 
thread="threadIdx.y"):
+                            for ax1_2 in T.thread_binding(8, 
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256, 
"pragma_unroll_explicit": 1}):
+                                for ax2_3_init, ax1_3_init in T.grid(4, 4):
+                                    with T.block("compute_init"):
+                                        v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                        v1 = T.axis.spatial((T.Cast("int32", 
T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 + 
ax1_3_init)
+                                        v2 = T.axis.spatial(64, ax2_1 * 64 + 
ax2_2 * 4 + ax2_3_init)
+                                        T.reads()
+                                        T.writes(compute_reindex_pad_local[0, 
v1, v2])
+                                        compute_reindex_pad_local[0, v1, v2] = 
T.float32(0)
+                                for ax3_0 in range(16):
+                                    for ax0_ax1_ax2_fused_0 in 
T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in 
T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in 
range(2):
+                                                for ax0_ax1_ax2_fused_3 in 
T.vectorized(2):
+                                                    with 
T.block("X_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 
0)
+                                                        v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 
* 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = 
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 + 
ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(X[v1, v2])
+                                                        
T.writes(X_reindex_pad_shared[v0, v1, v2])
+                                                        
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        
X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2], 
T.float16(0))
+                                    for ax0_ax1_ax2_fused_0 in 
T.thread_binding(16, thread="threadIdx.y"):
+                                        for ax0_ax1_ax2_fused_1 in 
T.thread_binding(8, thread="threadIdx.x"):
+                                            for ax0_ax1_ax2_fused_2 in 
range(4):
+                                                for ax0_ax1_ax2_fused_3 in 
T.vectorized(2):
+                                                    with 
T.block("W_reindex_pad_shared"):
+                                                        v0 = T.axis.spatial(1, 
0)
+                                                        v1 = 
T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 + 
ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+                                                        v2 = 
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 + 
ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+                                                        T.reads(W[v1, v2])
+                                                        
T.writes(W_reindex_pad_shared[v0, v1, v2])
+                                                        
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+                                                        
W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2], 
T.float16(0))
+                                    for ax3_1, ax2_3, ax1_3 in T.grid(16, 4, 
4):
+                                        with T.block("compute_update"):
+                                            v0 = T.axis.spatial(T.int64(1), 
T.int64(0))
+                                            v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
+                                            v2 = T.axis.spatial(64, ax2_1 * 64 
+ ax2_2 * 4 + ax2_3)
+                                            v3 = T.axis.reduce(256, ax3_0 * 16 
+ ax3_1)
+                                            
T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3], 
W_reindex_pad_shared[0, v2, v3])
+                                            
T.writes(compute_reindex_pad_local[0, v1, v2])
+                                            compute_reindex_pad_local[0, v1, 
v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32", 
X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0, 
v2, v3])
+                                for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+                                    for ax2_1_1 in T.vectorized(2):
+                                        with 
T.block("compute_reindex_pad_local"):
+                                            v0 = T.axis.spatial(1, ax0)
+                                            v1 = 
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 
32 + ax1_2 * 4 + ax1)
+                                            v2 = T.axis.spatial(64, ax2_2 * 4 
+ ax2_0 * 2 + ax2_1_1)
+                                            
T.reads(compute_reindex_pad_local[v0, v1, v2])
+                                            T.writes(compute[v1, v2])
+                                            if v1 < m and v2 < 15:
+                                                compute[v1, v2] = 
compute_reindex_pad_local[v0, v1, v2]
+    # fmt: on
+
+
+if __name__ == "__main__":
+    tvm.testing.main()


Reply via email to