This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new db4290b608 [TIR] Support more mma intrinsics and `get_mma_intrin_group` utility (#16073)
db4290b608 is described below

commit db4290b608a66d812ec49cf147440fe930ab2b6f
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Nov 7 13:22:38 2023 -0800

    [TIR] Support more mma intrinsics and `get_mma_intrin_group` utility (#16073)
    
    * 1104
    
    * 1104
    
    * 1105
    
    * fix ci
    
    * fix ci
---
 python/tvm/tir/tensor_intrin/cuda.py               | 571 ++++++++++++++-------
 ...tir_schedule_tensorize_ldmatrix_mma_numeric.py} |  56 +-
 ...=> test_tir_schedule_tensorize_mfma_numeric.py} |   0
 .../test_tir_transform_inject_software_pipeline.py |   8 +-
 4 files changed, 408 insertions(+), 227 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 6ee00ee634..409a1ff10a 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -14,18 +14,15 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=invalid-name,missing-function-docstring
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable
 """Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Dict, Tuple
-
-from typing_extensions import Literal
+from typing import Dict, Optional, Tuple, Literal
 
+from tvm._ffi import register_func
+from tvm.runtime import convert
 from tvm.script import tir as T
 from tvm.tir.function import PrimFunc
-
-from ..._ffi import register_func
-from ...runtime import convert
-from .. import Cast, IntImm, TensorIntrin
+from tvm.tir import Cast, IntImm, TensorIntrin
 
 
 def shared_16x16_to_ldmatrix_32x8_layout(i, j):
@@ -43,6 +40,12 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
+def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
+    row = 8 * (local_id % 4 // 2) + (thread_id // 4)
+    col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
+    return row, col
+
+
 @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
 def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
     i, j = ind[0], ind[1]
@@ -59,70 +62,94 @@ HALF_WARP = WARP_SIZE // 2
 HALF_WARP_expr = lift(HALF_WARP)
 
 
-def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
+def get_ldmatrix_intrin(
+    k_dim: int,
+    dtype: str,
+    matrix_name: Literal["A", "B"],
+    transposed: bool,
+    shared_scope: str = "shared",
+):
     local_size = (M_DIM * k_dim) // WARP_SIZE
-    shared_offset = None
+    smem_offset = None
     index_map = None
 
-    if transposed:
-        assert is_b, "Transposed A matrix not supported"
-
-    ldmatrix_col_major = is_b and not transposed
+    if matrix_name == "A":
+        transpose_in_ldmatrix = transposed
+        # transpose_layout_for_ldmatrix_input: Every thread loads 8 bytes 
data. This determines
+        # which 8 bytes every thread loads.
+        # If transpose_layout_for_ldmatrix_input is False, the load pattern is
+        # T0  T0  T0  T0  T16 T16 T16 T16
+        # T1  T1  T1  T1  T17 T17 T17 T17
+        # ...
+        # T8  T8  T8  T8  T24 T24 T24 T24
+        # T9  T9  T9  T9  T25 T25 T25 T25
+        # ...
+        # T15 T15 T15 T15 T31 T31 T31 T31
+        # Otherwise, the load pattern is
+        # T0  T0  T0  T0  T8  T8  T8  T8
+        # T1  T1  T1  T1  T9  T9  T9  T9
+        # ...
+        # T7  T7  T7  T7  T15 T15 T15 T15
+        # T16 T16 T16 T16 T24 T24 T24 T24
+        # T17 T17 T17 T17 T25 T25 T25 T25
+        # ...
+        # T23 T23 T23 T23 T31 T31 T31 T31
+        transpose_layout_for_ldmatrix_input = transposed
+        smem_tile_row, smem_tile_col = (M_DIM, k_dim) if not transposed else 
(k_dim, M_DIM)
+    else:
+        assert matrix_name == "B"
+        transpose_in_ldmatrix = not transposed
+        transpose_layout_for_ldmatrix_input = transposed
+        smem_tile_row, smem_tile_col = (k_dim, N_DIM) if not transposed else 
(N_DIM, k_dim)
 
     if k_dim == 16:
         assert dtype == "float16"
 
         index_map = shared_16x16_to_ldmatrix_32x8_layout
 
-        if transposed:
-            shared_offset = (
+        if transpose_layout_for_ldmatrix_input:
+            smem_offset = (
                 lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
                 + stride * (tx % 8)
                 + 8 * ((tx % HALF_WARP_expr) // 8)
             )
         else:
-            shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) 
+ 8 * (
+            smem_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 
8 * (
                 tx // HALF_WARP_expr
             )
     else:
+        # TODO(yixin): Support TN and TT matmul for int8
+        assert (
+            matrix_name == "B" or not transposed
+        ), "Now only B matrix can be transposed for int8 matmul"
         assert (
             k_dim == 32 and dtype == "int8"
         ), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
 
-        if ldmatrix_col_major:
+        if matrix_name == "B" and not transposed:
             index_map = shared_32x16_to_ldmatrix_32x16_layout
             # A dummy offset, ldmatrix cannot be used for int8 + trans case.
             # We still use the ldmatrix intrinsic, but lower it to a manual 
loop in the codegen.
             # Only the stride information is required.
-            shared_offset = lambda _, stride: stride
-        elif is_b and transposed:
+            smem_offset = lambda _, stride: stride
+        elif matrix_name == "B" and transposed:
             index_map = shared_16x32_to_ldmatrix_32x16_layout
-            shared_offset = (
+            smem_offset = (
                 lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
                 + (tx % 8) * stride
                 + 16 * ((tx % HALF_WARP_expr) // 8)
             )
-        else:
+        else:  # A, not transposed
             index_map = shared_16x32_to_ldmatrix_32x16_layout
-            shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx 
// 16)
-
-    assert index_map and shared_offset
-
-    if is_b and not transposed:
-        row_dim = k_dim
-        col_dim = M_DIM
-    else:
-        row_dim = M_DIM
-        col_dim = k_dim
+            smem_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 
16)
 
-    shmem_shape = (row_dim, col_dim)
-    offset_factor = col_dim
+    offset_factor = smem_tile_col
 
     @T.prim_func
     def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
         shared = T.match_buffer(
             shared_handle,
-            shmem_shape,
+            (smem_tile_row, smem_tile_col),
             dtype,
             align=64,
             offset_factor=offset_factor,
@@ -138,10 +165,10 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         )
 
         with T.block("root"):
-            T.reads(shared[0:row_dim, 0:col_dim])
+            T.reads(shared[0:smem_tile_row, 0:smem_tile_col])
             T.writes(warp[0:WARP_SIZE, 0:local_size])
 
-            for ax0, ax1 in T.grid(row_dim, col_dim):
+            for ax0, ax1 in T.grid(smem_tile_row, smem_tile_col):
                 with T.block("shared_warp"):
                     v0, v1 = T.axis.remap("SS", [ax0, ax1])
                     T.reads(shared[v0, v1])
@@ -156,7 +183,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         s1 = T.int32()
         shared = T.match_buffer(
             shared_handle,
-            shmem_shape,
+            (smem_tile_row, smem_tile_col),
             dtype,
             align=64,
             offset_factor=offset_factor,
@@ -173,28 +200,68 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         )
 
         with T.block("root"):
-            T.reads(shared[0:row_dim, 0:col_dim])
+            T.reads(shared[0:smem_tile_row, 0:smem_tile_col])
             T.writes(warp[0:WARP_SIZE, 0:local_size])
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, WARP_SIZE)
-
-            T.evaluate(
-                T.ptx_ldmatrix(
-                    ldmatrix_col_major,
-                    4,  # Always load 4 matrices
-                    ".b16",
-                    warp.data,
-                    warp.elem_offset + lift(local_size) * tx,
-                    shared.access_ptr("r"),
-                    shared_offset(tx, s0),
-                    dtype=dtype,
+            for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                T.evaluate(
+                    T.ptx_ldmatrix(
+                        transpose_in_ldmatrix,
+                        4,  # Always load 4 matrices
+                        ".b16",
+                        warp.data,
+                        warp.elem_offset + lift(local_size) * tx,
+                        shared.access_ptr("r"),
+                        smem_offset(tx, s0),
+                        dtype=dtype,
+                    )
                 )
-            )
 
     return ldmatrix_desc, ldmatrix_impl
 
 
-def get_mma_intrin(k_dim, out_dtype, b_transposed):
+LDMATRIX_f16_A_INTRIN = "mma_ldmatrix_f16_a"
+TensorIntrin.register(LDMATRIX_f16_A_INTRIN, *get_ldmatrix_intrin(16, 
"float16", "A", False))
+
+LDMATRIX_f16_B_INTRIN = "mma_ldmatrix_f16_b"
+TensorIntrin.register(LDMATRIX_f16_B_INTRIN, *get_ldmatrix_intrin(16, 
"float16", "B", False))
+
+LDMATRIX_f16_A_TRANS_INTRIN = "mma_ldmatrix_f16_a_trans"
+TensorIntrin.register(LDMATRIX_f16_A_TRANS_INTRIN, *get_ldmatrix_intrin(16, 
"float16", "A", True))
+
+LDMATRIX_f16_B_TRANS_INTRIN = "mma_ldmatrix_f16_b_trans"
+TensorIntrin.register(LDMATRIX_f16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, 
"float16", "B", True))
+
+LDMATRIX_f16_A_DYN_INTRIN = "mma_ldmatrix_f16_a_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, 
"shared.dyn")
+)
+
+LDMATRIX_f16_B_DYN_INTRIN = "mma_ldmatrix_f16_b_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", False, 
"shared.dyn")
+)
+
+LDMATRIX_f16_A_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_a_trans_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_A_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", 
True, "shared.dyn")
+)
+
+LDMATRIX_f16_B_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_b_trans_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_B_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", 
True, "shared.dyn")
+)
+
+LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_i8_a"
+TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "int8", 
"A", False))
+
+LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_i8_b"
+TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "int8", 
"B", False))
+
+LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
+TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, 
"int8", "B", True))
+
+
+def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
     local_size = (M_DIM * k_dim) // WARP_SIZE
     local_size_out = (M_DIM * N_DIM) // 32
 
@@ -223,18 +290,16 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
         in_dtype = "int8"
         in_dtype_abbrv = "int8"
 
-    def maybe_cast(v):
+    def cast_to_out_dtype(v):
         if out_dtype in ["float32", "int32"]:
             return Cast(out_dtype, v)
         return v
 
-    def maybe_swap(i, j):
-        if b_transposed:
-            return j, i
-        return i, j
+    def swap_if_flag(i, j, flag):
+        return (j, i) if flag else (i, j)
 
-    A_offset_factor = k_dim
-    B_offset_factor = maybe_swap(k_dim, N_DIM)[-1]
+    A_offset_factor = M_DIM if a_transposed else k_dim
+    B_offset_factor = k_dim if b_transposed else N_DIM
     out_offset_factor = N_DIM
 
     @T.prim_func
@@ -275,10 +340,11 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
                 with T.block("C"):
                     i, j, k = T.axis.remap("SSR", [i, j, k])
-                    b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))
+                    a_row_ind, a_col_ind = T.meta_var(swap_if_flag(i, k, 
a_transposed))
+                    b_row_ind, b_col_ind = T.meta_var(swap_if_flag(k, j, 
b_transposed))
 
                     thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
-                    thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
+                    thread_id_A, local_id_A = 
T.meta_var(index_map_A(a_row_ind, a_col_ind))
                     thread_id_B, local_id_B = 
T.meta_var(index_map_B(b_row_ind, b_col_ind))
 
                     T.reads(
@@ -288,9 +354,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
                     )
                     T.writes(C[thread_id_C, local_id_C])
 
-                    C[thread_id_C, local_id_C] += maybe_cast(
+                    C[thread_id_C, local_id_C] += cast_to_out_dtype(
                         A[thread_id_A, local_id_A]
-                    ) * maybe_cast(B[thread_id_B, local_id_B])
+                    ) * cast_to_out_dtype(B[thread_id_B, local_id_B])
 
     @T.prim_func
     def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -326,50 +392,84 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
                 B[0:WARP_SIZE, 0:local_size],
             )
             T.writes(C[0:WARP_SIZE, 0:local_size_out])
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, WARP_SIZE)
 
-            T.evaluate(
-                T.ptx_mma(
-                    mma_prefix,
-                    "row",
-                    "col",
-                    in_dtype_abbrv,
-                    in_dtype_abbrv,
-                    out_dtype_abbrv,
-                    A.data,
-                    A.elem_offset + tx * lift(local_size),
-                    B.data,
-                    B.elem_offset + tx * lift(local_size),
-                    C.data,
-                    C.elem_offset + tx * lift(local_size_out),
-                    False,
-                    dtype=out_dtype,
+            for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                T.evaluate(
+                    T.ptx_mma(
+                        mma_prefix,
+                        "row",
+                        "col",
+                        in_dtype_abbrv,
+                        in_dtype_abbrv,
+                        out_dtype_abbrv,
+                        A.data,
+                        A.elem_offset + tx * lift(local_size),
+                        B.data,
+                        B.elem_offset + tx * lift(local_size),
+                        C.data,
+                        C.elem_offset + tx * lift(local_size_out),
+                        False,
+                        dtype=out_dtype,
+                    )
                 )
-            )
 
-            T.evaluate(
-                T.ptx_mma(
-                    mma_prefix,
-                    "row",
-                    "col",
-                    in_dtype_abbrv,
-                    in_dtype_abbrv,
-                    out_dtype_abbrv,
-                    A.data,
-                    A.elem_offset + tx * lift(local_size),
-                    B.data,
-                    B.elem_offset + tx * lift(local_size) + lift(local_size) 
// 2,
-                    C.data,
-                    C.elem_offset + tx * lift(local_size_out) + 
lift(local_size_out) // 2,
-                    False,
-                    dtype=out_dtype,
+                T.evaluate(
+                    T.ptx_mma(
+                        mma_prefix,
+                        "row",
+                        "col",
+                        in_dtype_abbrv,
+                        in_dtype_abbrv,
+                        out_dtype_abbrv,
+                        A.data,
+                        A.elem_offset + tx * lift(local_size),
+                        B.data,
+                        B.elem_offset + tx * lift(local_size) + 
lift(local_size) // 2,
+                        C.data,
+                        C.elem_offset + tx * lift(local_size_out) + 
lift(local_size_out) // 2,
+                        False,
+                        dtype=out_dtype,
+                    )
                 )
-            )
 
     return mma_sync_desc, mma_sync_impl
 
 
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", 
False, False))
+
+MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b"
+TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, 
"float32", False, True))
+
+MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a"
+TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, 
"float32", True, False))
+
+MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b"
+TensorIntrin.register(
+    MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, 
True)
+)
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", 
False, False))
+
+MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b"
+TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, 
"float16", False, True))
+
+MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
+TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, 
"float16", True, False))
+
+MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b"
+TensorIntrin.register(
+    MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, 
True)
+)
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, 
False))
+
+MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
+TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", 
False, True))
+
+
 def get_mma_fill_intrin(dtype, local_size):
     zero = IntImm("int32", 0).astype(dtype)
 
@@ -400,17 +500,27 @@ def get_mma_fill_intrin(dtype, local_size):
         with T.block("root"):
             T.reads()
             T.writes(C_warp[0:WARP_SIZE, 0:local_size])
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, WARP_SIZE)
 
-            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, 
dtype=dtype))
+            for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                T.evaluate(T.mma_fill(local_size, C_warp.data, 
C_warp.elem_offset, dtype=dtype))
 
     return mma_fill_desc, mma_fill_impl
 
 
-def get_mma_store_intrin(dtype, local_size, scope="global"):
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, 
*get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, 
*get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 
8))
+
+
+def get_mma_store_intrin(dtype, local_size, scope="global", 
use_mma_store_intrinic=True):
     # Assume M = N = 16
     index_map = shared_16x16_to_ldmatrix_32x8_layout
+    index_map_rev = ldmatrix_32x8_to_shared_16x16_layout
 
     @T.prim_func
     def mma_store_desc(a: T.handle, c: T.handle) -> None:
@@ -428,110 +538,183 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
                     T.writes(C[v0, v1])
                     C[v0, v1] = C_warp[thread_id, local_id]
 
-    @T.prim_func
-    def mma_store_impl(a: T.handle, c: T.handle) -> None:
-        s0 = T.int32()
-        s1 = T.int32()
+    if use_mma_store_intrinic:
 
-        C_warp = T.match_buffer(
-            a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
-        )
-        C = T.match_buffer(
-            c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, 
strides=[s0, s1]
-        )
+        @T.prim_func
+        def mma_store_impl(a: T.handle, c: T.handle) -> None:
+            s0 = T.int32()
+            s1 = T.int32()
 
-        with T.block("root"):
-            T.reads(C_warp[0:WARP_SIZE, 0:local_size])
-            T.writes(C[0:M_DIM, 0:N_DIM])
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, WARP_SIZE)
+            C_warp = T.match_buffer(
+                a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
+            )
+            C = T.match_buffer(
+                c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, 
strides=[s0, s1]
+            )
 
-            T.evaluate(
-                T.mma_store(
-                    M_DIM,
-                    N_DIM,
-                    C.access_ptr("w"),
-                    C_warp.data,
-                    C_warp.elem_offset,
-                    s0,
-                    dtype=dtype,
-                )
+            with T.block("root"):
+                T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+                T.writes(C[0:M_DIM, 0:N_DIM])
+
+                for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                    T.evaluate(
+                        T.mma_store(
+                            M_DIM,
+                            N_DIM,
+                            C.access_ptr("w"),
+                            C_warp.data,
+                            C_warp.elem_offset,
+                            s0,
+                            dtype=dtype,
+                        )
+                    )
+
+    else:
+
+        @T.prim_func
+        def mma_store_impl(a: T.handle, c: T.handle) -> None:
+            s0 = T.int32()
+            s1 = T.int32()
+
+            C_warp = T.match_buffer(
+                a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", 
offset_factor=1
+            )
+            C = T.match_buffer(
+                c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, 
strides=[s0, s1]
             )
 
+            with T.block("root"):
+                T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+                T.writes(C[0:M_DIM, 0:N_DIM])
+
+                for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                    for local_id in T.serial(local_size):
+                        row, col = T.meta_var(index_map_rev(tx, local_id))
+                        C[row, col] = C_warp[tx, local_id]
+
     return mma_store_desc, mma_store_impl
 
 
-LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
-TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, 
"float16", False, False))
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, 
"global", True)
+)
 
-LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
-TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, 
"float16", True, False))
+MMA_store_16x16_f32_shared_dyn_INTRIN = "mma_store_16x16_f32_shared_dyn_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_shared_dyn_INTRIN, *get_mma_store_intrin("float32", 8, 
"shared.dyn", True)
+)
 
-LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn"
+MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE = 
"mma_store_16x16_f32_shared_dyn_simple_"
 TensorIntrin.register(
-    LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, 
False, "shared.dyn")
+    MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE,
+    *get_mma_store_intrin("float32", 8, "shared.dyn", False),
 )
 
-LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn"
+MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE = 
"mma_store_16x16_f16_shared_dyn_simple_"
 TensorIntrin.register(
-    LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, 
False, "shared.dyn")
+    MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE,
+    *get_mma_store_intrin("float16", 8, "shared.dyn", False),
 )
 
-LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
 TensorIntrin.register(
-    LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, 
True)
+    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, 
"global", True)
 )
 
-LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
-TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, 
"int8", False, False))
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, 
"global", True)
+)
 
-LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
-TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, 
"int8", True, False))
 
-LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
-TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, 
"int8", True, True))
+def get_mma_intrin_group(
+    load_scope: Literal["shared", "shared.dyn"],
+    store_scope: Literal["global", "shared", "shared.dyn"],
+    in_dtype: Literal["float16", "int8"],
+    out_dtype: Literal["float16", "float32", "int32"],
+    trans_a: bool,
+    trans_b: bool,
+    not_use_mma_store_intrinic: bool = True,
+    store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = 
None,
+) -> Dict[str, str]:
+    """Get a group of intrinsics for mma tensor core with the given 
configurations
 
-MMA_f16f16f32_INTRIN = "mma_f16f16f32"
-TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", 
False))
+    Parameters
+    ----------
+    load_scope : Literal["shared", "shared.dyn"]
+        The memory scope of the input buffer.
 
-MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
-TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, 
"float32", True))
+    store_scope : Literal["global", "shared", "shared.dyn"]
+        The memory scope of the result buffer.
 
-MMA_f16f16f16_INTRIN = "mma_f16f16f16"
-TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", 
False))
+    in_dtype : str
+        The input data type.
+
+    out_dtype : str
+        The output data dtype.
 
-MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
-TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, 
"float16", True))
+    trans_a : bool
+        Whether the input matrix A is transposed.
 
-MMA_i8i8i32_INTRIN = "mma_i8i8i32"
-TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+    trans_b : bool
+        Whether the input matrix B is transposed.
 
-MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
-TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", 
True))
+    not_use_mma_store_intrinic : bool
+        Whether to not use the mma_store intrinsic. If True, use BufferStore 
stmts to store the
+        result of mma. Otherwise, use mma_store intrinsic.
 
-MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
-TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, 
*get_mma_fill_intrin("float32", 8))
+        This is because if we use mma_store intrinsic, during swizzling shared 
memory visits, our
+        rearrangement scheme will involve areas accessed by different 
mma_store calls. This makes
+        swizzling quite complex. But BufferStore will not face this problem.
 
-MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
-TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, 
*get_mma_fill_intrin("float16", 8))
+    store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]]
+        The dtype that we use to store from register to shared memory. By 
default it is out_dtype.
 
-MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
-TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 
8))
+    Returns
+    -------
+    ret : Dict[str, str]
+        A group of tensor intrinsics.
+    """
+    assert load_scope in ["shared", "shared.dyn"]
+    assert store_scope in ["global", "shared", "shared.dyn"]
+    assert in_dtype in ["float16", "int8"]
+    assert out_dtype in ["float16", "float32", "int32"]
 
-MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
-TensorIntrin.register(
-    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, 
"global")
-)
+    shape = "16x16"
 
-MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
-TensorIntrin.register(
-    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, 
"global")
-)
+    dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", 
"int32": "i32"}
+    in_dtype = dtype_mapping[in_dtype]
+    out_dtype = dtype_mapping[out_dtype]
 
-MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
-TensorIntrin.register(
-    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, 
"global")
-)
+    # e.g. mma_fill_16x16_f32
+    init_intrin = f"mma_fill_{shape}_{out_dtype}"
+
+    # e.g. mma_ldmatrix_f16_a_trans_dyn, mma_ldmatrix_f16_b_trans_dyn
+    trans_a = "_trans" if trans_a else ""
+    trans_b = "_trans" if trans_b else ""
+    load_scope = "_dyn" if load_scope == "shared.dyn" else ""
+    load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}"
+    load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}"
+
+    # e.g. mma_f16f16f32_trans_a_trans_b
+    trans_a_str = trans_a + "_a" if trans_a != "" else ""
+    trans_b_str = trans_b + "_b" if trans_b != "" else ""
+    compute_intrin = 
f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}"
+
+    # e.g. mma_store_16x16_f32_shared_dyn_simple_
+    store_scope = store_scope.replace(".", "_")
+    store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if 
store_to_smem_dtype else out_dtype
+    suffix = "simple_" if not_use_mma_store_intrinic else ""
+    store_intrin = 
f"mma_store_{shape}_{store_to_smem_dtype}_{store_scope}_{suffix}"
+
+    return {
+        "init": init_intrin,
+        "load_a": load_a_intrin,
+        "load_b": load_b_intrin,
+        "compute": compute_intrin,
+        "store": store_intrin,
+    }
 
 
 ######## WMMA intrinsics ########
@@ -1235,11 +1418,11 @@ def get_mma_init_intrin(
         with T.block("root"):
             T.reads()
             T.writes(dst[0:m_dim, 0:n_dim])
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, 32)
-            for b in range(m_dim // 8):
-                for v in T.vectorized(n_dim // 4):
-                    dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = zero
+
+            for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                for b in range(m_dim // 8):
+                    for v in T.vectorized(n_dim // 4):
+                        dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = 
zero
 
     return mma_init_desc, mma_init_impl
 
@@ -1310,21 +1493,19 @@ def get_mma_load_intrin(
             T.reads(src[0:frag_m, 0:frag_n])
             T.writes(dst[0:frag_m, 0:frag_n])
 
-            tx = T.env_thread("threadIdx.x")
-            T.launch_thread(tx, 32)
-
-            T.evaluate(
-                T.ptx_ldmatrix(
-                    trans,
-                    4,  # Always load 4 matrices
-                    ".b16",
-                    dst.data,
-                    get_index(dst.elem_offset, d0),
-                    src.access_ptr("r"),
-                    get_tx_index(tx, s0),
-                    dtype=dtype,
+            for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+                T.evaluate(
+                    T.ptx_ldmatrix(
+                        trans,
+                        4,  # Always load 4 matrices
+                        ".b16",
+                        dst.data,
+                        get_index(dst.elem_offset, d0),
+                        src.access_ptr("r"),
+                        get_tx_index(tx, s0),
+                        dtype=dtype,
+                    )
                 )
-            )
 
     return mma_load_desc, mma_load_impl
 
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
similarity index 88%
rename from tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
rename to tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index 2a853a2431..d704dc2438 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -22,21 +22,21 @@ import tvm.testing
 from tvm import te
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
-    LDMATRIX_16x16_A_INTRIN,
-    LDMATRIX_16x16_B_INTRIN,
-    LDMATRIX_16x16_B_TRANS_INTRIN,
-    LDMATRIX_16x32_A_INTRIN,
-    LDMATRIX_16x32_B_TRANS_INTRIN,
-    LDMATRIX_32x16_B_INTRIN,
+    LDMATRIX_f16_A_INTRIN,
+    LDMATRIX_f16_B_INTRIN,
+    LDMATRIX_f16_B_TRANS_INTRIN,
+    LDMATRIX_i8_A_INTRIN,
+    LDMATRIX_i8_B_TRANS_INTRIN,
+    LDMATRIX_i8_B_INTRIN,
     MMA_f16f16f16_INTRIN,
-    MMA_f16f16f16_TRANS_INTRIN,
+    MMA_f16f16f16_TRANS_B_INTRIN,
     MMA_f16f16f32_INTRIN,
-    MMA_f16f16f32_TRANS_INTRIN,
+    MMA_f16f16f32_TRANS_B_INTRIN,
     MMA_fill_16x16_f16_INTRIN,
     MMA_fill_16x16_f32_INTRIN,
     MMA_fill_16x16_i32_INTRIN,
     MMA_i8i8i32_INTRIN,
-    MMA_i8i8i32_TRANS_INTRIN,
+    MMA_i8i8i32_TRANS_B_INTRIN,
     MMA_store_16x16_f16_global_INTRIN,
     MMA_store_16x16_f32_global_INTRIN,
     MMA_store_16x16_i32_global_INTRIN,
@@ -116,15 +116,15 @@ def run_test(
     dev = tvm.device("cuda", 0)
 
     if in_dtype == "float16":
-        a_np = np.random.uniform(size=(M, K)).astype("float16")
+        a_np = np.random.normal(size=(M, K)).astype("float16")
 
         if b_transposed:
-            b_np = np.random.uniform(size=(N, K)).astype("float16")
+            b_np = np.random.normal(size=(N, K)).astype("float16")
             c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32").transpose()).astype(
                 out_dtype
             )
         else:
-            b_np = np.random.uniform(size=(K, N)).astype("float16")
+            b_np = np.random.normal(size=(K, N)).astype("float16")
             c_np = np.dot(a_np.astype("float32"), 
b_np.astype("float32")).astype(out_dtype)
     else:
         a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
@@ -147,7 +147,7 @@ def run_test(
     if out_dtype != "float16":
         # The numpy reference is computed with fp32 precision (otherwise too 
slow).
         # So there is non-trivial accuracy difference if TVM result is 
computed with fp16 accumulation.
-        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
 
     return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
 
@@ -177,8 +177,8 @@ def test_f16f16f32_m16n16k16():
         index_map,
         index_map,
         index_map,
-        LDMATRIX_16x16_A_INTRIN,
-        LDMATRIX_16x16_B_INTRIN,
+        LDMATRIX_f16_A_INTRIN,
+        LDMATRIX_f16_B_INTRIN,
         MMA_f16f16f32_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
         MMA_store_16x16_f32_global_INTRIN,
@@ -198,9 +198,9 @@ def test_f16f16f32_m16n16k16():
         index_map,
         index_map,
         index_map,
-        LDMATRIX_16x16_A_INTRIN,
-        LDMATRIX_16x16_B_TRANS_INTRIN,
-        MMA_f16f16f32_TRANS_INTRIN,
+        LDMATRIX_f16_A_INTRIN,
+        LDMATRIX_f16_B_TRANS_INTRIN,
+        MMA_f16f16f32_TRANS_B_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
         MMA_store_16x16_f32_global_INTRIN,
     )
@@ -234,8 +234,8 @@ def test_f16f16f16_m16n16k16():
         index_map,
         index_map,
         index_map,
-        LDMATRIX_16x16_A_INTRIN,
-        LDMATRIX_16x16_B_INTRIN,
+        LDMATRIX_f16_A_INTRIN,
+        LDMATRIX_f16_B_INTRIN,
         MMA_f16f16f16_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
         MMA_store_16x16_f16_global_INTRIN,
@@ -255,9 +255,9 @@ def test_f16f16f16_m16n16k16():
         index_map,
         index_map,
         index_map,
-        LDMATRIX_16x16_A_INTRIN,
-        LDMATRIX_16x16_B_TRANS_INTRIN,
-        MMA_f16f16f16_TRANS_INTRIN,
+        LDMATRIX_f16_A_INTRIN,
+        LDMATRIX_f16_B_TRANS_INTRIN,
+        MMA_f16f16f16_TRANS_B_INTRIN,
         MMA_fill_16x16_f16_INTRIN,
         MMA_store_16x16_f16_global_INTRIN,
     )
@@ -305,8 +305,8 @@ def test_i8i8i32_m16n16k32():
         index_map_A,
         index_map_B,
         index_map_C,
-        LDMATRIX_16x32_A_INTRIN,
-        LDMATRIX_32x16_B_INTRIN,
+        LDMATRIX_i8_A_INTRIN,
+        LDMATRIX_i8_B_INTRIN,
         MMA_i8i8i32_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
         MMA_store_16x16_i32_global_INTRIN,
@@ -326,9 +326,9 @@ def test_i8i8i32_m16n16k32():
         index_map_A,
         index_map_A,
         index_map_C,
-        LDMATRIX_16x32_A_INTRIN,
-        LDMATRIX_16x32_B_TRANS_INTRIN,
-        MMA_i8i8i32_TRANS_INTRIN,
+        LDMATRIX_i8_A_INTRIN,
+        LDMATRIX_i8_B_TRANS_INTRIN,
+        MMA_i8i8i32_TRANS_B_INTRIN,
         MMA_fill_16x16_i32_INTRIN,
         MMA_store_16x16_i32_global_INTRIN,
     )
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_mfma.py b/tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py
similarity index 100%
rename from tests/python/unittest/test_tir_schedule_tensorize_mfma.py
rename to tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index a013cf0f65..bc3e979f94 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -26,8 +26,8 @@ from tvm.meta_schedule.testing import te_workload
 from tvm.script import tir as T
 from tvm.testing.tir import mma_schedule
 from tvm.tir.tensor_intrin.cuda import (
-    LDMATRIX_16x16_A_DYN_INTRIN,
-    LDMATRIX_16x16_B_DYN_INTRIN,
+    LDMATRIX_f16_A_DYN_INTRIN,
+    LDMATRIX_f16_B_DYN_INTRIN,
     MMA_f16f16f32_INTRIN,
     MMA_fill_16x16_f32_INTRIN,
     MMA_store_16x16_f32_global_INTRIN,
@@ -1520,8 +1520,8 @@ def get_mma_schedule():
         index_map,
         index_map,
         index_map,
-        LDMATRIX_16x16_A_DYN_INTRIN,
-        LDMATRIX_16x16_B_DYN_INTRIN,
+        LDMATRIX_f16_A_DYN_INTRIN,
+        LDMATRIX_f16_B_DYN_INTRIN,
         MMA_f16f16f32_INTRIN,
         MMA_fill_16x16_f32_INTRIN,
         MMA_store_16x16_f32_global_INTRIN,


Reply via email to