This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new db4290b608 [TIR] Support more mma intrinsics and `get_mma_intrin_group` utility (#16073)
db4290b608 is described below
commit db4290b608a66d812ec49cf147440fe930ab2b6f
Author: Yixin Dong <[email protected]>
AuthorDate: Tue Nov 7 13:22:38 2023 -0800
[TIR] Support more mma intrinsics and `get_mma_intrin_group` utility (#16073)
* 1104
* 1104
* 1105
* fix ci
* fix ci
---
python/tvm/tir/tensor_intrin/cuda.py | 571 ++++++++++++++-------
...tir_schedule_tensorize_ldmatrix_mma_numeric.py} | 56 +-
...=> test_tir_schedule_tensorize_mfma_numeric.py} | 0
.../test_tir_transform_inject_software_pipeline.py | 8 +-
4 files changed, 408 insertions(+), 227 deletions(-)
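
The headline addition is the `get_mma_intrin_group` utility near the end of
cuda.py, which assembles a matching init/load/compute/store intrinsic set by
name. A minimal usage sketch follows; it assumes a TVM build that already
contains this patch, and the expected strings are derived from the naming
scheme visible in the diff below:

    from tvm.tir.tensor_intrin.cuda import get_mma_intrin_group

    # f16 x f16 -> f32 mma, B transposed, operands staged in dynamic shared
    # memory, result written back with BufferStore instead of mma_store.
    group = get_mma_intrin_group(
        load_scope="shared.dyn",
        store_scope="shared.dyn",
        in_dtype="float16",
        out_dtype="float32",
        trans_a=False,
        trans_b=True,
        not_use_mma_store_intrinic=True,
    )
    # Expected mapping, per the naming scheme in the diff:
    # {"init": "mma_fill_16x16_f32",
    #  "load_a": "mma_ldmatrix_f16_a_dyn",
    #  "load_b": "mma_ldmatrix_f16_b_trans_dyn",
    #  "compute": "mma_f16f16f32_trans_b",
    #  "store": "mma_store_16x16_f32_shared_dyn_simple_"}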
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 6ee00ee634..409a1ff10a 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -14,18 +14,15 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=invalid-name,missing-function-docstring
+# pylint: disable=invalid-name,missing-function-docstring,unused-variable
"""Intrinsics for tensorization on NVIDIA GPU."""
-from typing import Dict, Tuple
-
-from typing_extensions import Literal
+from typing import Dict, Optional, Tuple, Literal
+from tvm._ffi import register_func
+from tvm.runtime import convert
from tvm.script import tir as T
from tvm.tir.function import PrimFunc
-
-from ..._ffi import register_func
-from ...runtime import convert
-from .. import Cast, IntImm, TensorIntrin
+from tvm.tir import Cast, IntImm, TensorIntrin
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
@@ -43,6 +40,12 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
+ row = 8 * (local_id % 4 // 2) + (thread_id // 4)
+ col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
+ return row, col
+
+
@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
i, j = ind[0], ind[1]
@@ -59,70 +62,94 @@ HALF_WARP = WARP_SIZE // 2
HALF_WARP_expr = lift(HALF_WARP)
-def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
+def get_ldmatrix_intrin(
+ k_dim: int,
+ dtype: str,
+ matrix_name: Literal["A", "B"],
+ transposed: bool,
+ shared_scope: str = "shared",
+):
local_size = (M_DIM * k_dim) // WARP_SIZE
- shared_offset = None
+ smem_offset = None
index_map = None
- if transposed:
- assert is_b, "Transposed A matrix not supported"
-
- ldmatrix_col_major = is_b and not transposed
+ if matrix_name == "A":
+ transpose_in_ldmatrix = transposed
+        # transpose_layout_for_ldmatrix_input: Every thread loads 8 bytes of data. This determines
+        # which 8 bytes every thread loads.
+ # If transpose_layout_for_ldmatrix_input is False, the load pattern is
+ # T0 T0 T0 T0 T16 T16 T16 T16
+ # T1 T1 T1 T1 T17 T17 T17 T17
+ # ...
+ # T8 T8 T8 T8 T24 T24 T24 T24
+ # T9 T9 T9 T9 T25 T25 T25 T25
+ # ...
+ # T15 T15 T15 T15 T31 T31 T31 T31
+ # Otherwise, the load pattern is
+ # T0 T0 T0 T0 T8 T8 T8 T8
+ # T1 T1 T1 T1 T9 T9 T9 T9
+ # ...
+ # T7 T7 T7 T7 T15 T15 T15 T15
+ # T16 T16 T16 T16 T24 T24 T24 T24
+ # T17 T17 T17 T17 T25 T25 T25 T25
+ # ...
+ # T23 T23 T23 T23 T31 T31 T31 T31
+ transpose_layout_for_ldmatrix_input = transposed
+        smem_tile_row, smem_tile_col = (M_DIM, k_dim) if not transposed else (k_dim, M_DIM)
+ else:
+ assert matrix_name == "B"
+ transpose_in_ldmatrix = not transposed
+ transpose_layout_for_ldmatrix_input = transposed
+        smem_tile_row, smem_tile_col = (k_dim, N_DIM) if not transposed else (N_DIM, k_dim)
if k_dim == 16:
assert dtype == "float16"
index_map = shared_16x16_to_ldmatrix_32x8_layout
- if transposed:
- shared_offset = (
+ if transpose_layout_for_ldmatrix_input:
+ smem_offset = (
lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ stride * (tx % 8)
+ 8 * ((tx % HALF_WARP_expr) // 8)
)
else:
-        shared_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
+            smem_offset = lambda tx, stride: stride * (tx % HALF_WARP_expr) + 8 * (
tx // HALF_WARP_expr
)
else:
+ # TODO(yixin): Support TN and TT matmul for int8
+ assert (
+ matrix_name == "B" or not transposed
+ ), "Now only B matrix can be transposed for int8 matmul"
assert (
k_dim == 32 and dtype == "int8"
), "Only k_dim == 16 (float16) or k_dim == 32 (int8) supported for now"
- if ldmatrix_col_major:
+ if matrix_name == "B" and not transposed:
index_map = shared_32x16_to_ldmatrix_32x16_layout
# A dummy offset, ldmatrix cannot be used for int8 + trans case.
            # We still use the ldmatrix intrinsic, but lower it to a manual loop in the codegen.
# Only the stride information is required.
- shared_offset = lambda _, stride: stride
- elif is_b and transposed:
+ smem_offset = lambda _, stride: stride
+ elif matrix_name == "B" and transposed:
index_map = shared_16x32_to_ldmatrix_32x16_layout
- shared_offset = (
+ smem_offset = (
lambda tx, stride: stride * 8 * (tx // HALF_WARP_expr)
+ (tx % 8) * stride
+ 16 * ((tx % HALF_WARP_expr) // 8)
)
- else:
+ else: # A, not transposed
index_map = shared_16x32_to_ldmatrix_32x16_layout
-        shared_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)
-
- assert index_map and shared_offset
-
- if is_b and not transposed:
- row_dim = k_dim
- col_dim = M_DIM
- else:
- row_dim = M_DIM
- col_dim = k_dim
+        smem_offset = lambda tx, stride: stride * (tx % 16) + 16 * (tx // 16)
- shmem_shape = (row_dim, col_dim)
- offset_factor = col_dim
+ offset_factor = smem_tile_col
@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
shared = T.match_buffer(
shared_handle,
- shmem_shape,
+ (smem_tile_row, smem_tile_col),
dtype,
align=64,
offset_factor=offset_factor,
@@ -138,10 +165,10 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
)
with T.block("root"):
- T.reads(shared[0:row_dim, 0:col_dim])
+ T.reads(shared[0:smem_tile_row, 0:smem_tile_col])
T.writes(warp[0:WARP_SIZE, 0:local_size])
- for ax0, ax1 in T.grid(row_dim, col_dim):
+ for ax0, ax1 in T.grid(smem_tile_row, smem_tile_col):
with T.block("shared_warp"):
v0, v1 = T.axis.remap("SS", [ax0, ax1])
T.reads(shared[v0, v1])
@@ -156,7 +183,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
s1 = T.int32()
shared = T.match_buffer(
shared_handle,
- shmem_shape,
+ (smem_tile_row, smem_tile_col),
dtype,
align=64,
offset_factor=offset_factor,
@@ -173,28 +200,68 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
)
with T.block("root"):
- T.reads(shared[0:row_dim, 0:col_dim])
+ T.reads(shared[0:smem_tile_row, 0:smem_tile_col])
T.writes(warp[0:WARP_SIZE, 0:local_size])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, WARP_SIZE)
-
- T.evaluate(
- T.ptx_ldmatrix(
- ldmatrix_col_major,
- 4, # Always load 4 matrices
- ".b16",
- warp.data,
- warp.elem_offset + lift(local_size) * tx,
- shared.access_ptr("r"),
- shared_offset(tx, s0),
- dtype=dtype,
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ T.evaluate(
+ T.ptx_ldmatrix(
+ transpose_in_ldmatrix,
+ 4, # Always load 4 matrices
+ ".b16",
+ warp.data,
+ warp.elem_offset + lift(local_size) * tx,
+ shared.access_ptr("r"),
+ smem_offset(tx, s0),
+ dtype=dtype,
+ )
)
- )
return ldmatrix_desc, ldmatrix_impl
-def get_mma_intrin(k_dim, out_dtype, b_transposed):
+LDMATRIX_f16_A_INTRIN = "mma_ldmatrix_f16_a"
+TensorIntrin.register(LDMATRIX_f16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False))
+
+LDMATRIX_f16_B_INTRIN = "mma_ldmatrix_f16_b"
+TensorIntrin.register(LDMATRIX_f16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", False))
+
+LDMATRIX_f16_A_TRANS_INTRIN = "mma_ldmatrix_f16_a_trans"
+TensorIntrin.register(LDMATRIX_f16_A_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", True))
+
+LDMATRIX_f16_B_TRANS_INTRIN = "mma_ldmatrix_f16_b_trans"
+TensorIntrin.register(LDMATRIX_f16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True))
+
+LDMATRIX_f16_A_DYN_INTRIN = "mma_ldmatrix_f16_a_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", False, "shared.dyn")
+)
+
+LDMATRIX_f16_B_DYN_INTRIN = "mma_ldmatrix_f16_b_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", False, "shared.dyn")
+)
+
+LDMATRIX_f16_A_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_a_trans_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_A_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "A", True, "shared.dyn")
+)
+
+LDMATRIX_f16_B_TRANS_DYN_INTRIN = "mma_ldmatrix_f16_b_trans_dyn"
+TensorIntrin.register(
+    LDMATRIX_f16_B_TRANS_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", "B", True, "shared.dyn")
+)
+
+LDMATRIX_i8_A_INTRIN = "mma_ldmatrix_i8_a"
+TensorIntrin.register(LDMATRIX_i8_A_INTRIN, *get_ldmatrix_intrin(32, "int8", "A", False))
+
+LDMATRIX_i8_B_INTRIN = "mma_ldmatrix_i8_b"
+TensorIntrin.register(LDMATRIX_i8_B_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", False))
+
+LDMATRIX_i8_B_TRANS_INTRIN = "mma_ldmatrix_i8_b_trans"
+TensorIntrin.register(LDMATRIX_i8_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", "B", True))
+
+
+def get_mma_intrin(k_dim, out_dtype, a_transposed, b_transposed):
local_size = (M_DIM * k_dim) // WARP_SIZE
local_size_out = (M_DIM * N_DIM) // 32
@@ -223,18 +290,16 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
in_dtype = "int8"
in_dtype_abbrv = "int8"
- def maybe_cast(v):
+ def cast_to_out_dtype(v):
if out_dtype in ["float32", "int32"]:
return Cast(out_dtype, v)
return v
- def maybe_swap(i, j):
- if b_transposed:
- return j, i
- return i, j
+ def swap_if_flag(i, j, flag):
+ return (j, i) if flag else (i, j)
- A_offset_factor = k_dim
- B_offset_factor = maybe_swap(k_dim, N_DIM)[-1]
+ A_offset_factor = M_DIM if a_transposed else k_dim
+ B_offset_factor = k_dim if b_transposed else N_DIM
out_offset_factor = N_DIM
@T.prim_func
@@ -275,10 +340,11 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
for i, j, k in T.grid(M_DIM, N_DIM, k_dim):
with T.block("C"):
i, j, k = T.axis.remap("SSR", [i, j, k])
- b_row_ind, b_col_ind = T.meta_var(maybe_swap(k, j))
+                    a_row_ind, a_col_ind = T.meta_var(swap_if_flag(i, k, a_transposed))
+                    b_row_ind, b_col_ind = T.meta_var(swap_if_flag(k, j, b_transposed))
                    thread_id_C, local_id_C = T.meta_var(index_map_C(i, j))
-                    thread_id_A, local_id_A = T.meta_var(index_map_A(i, k))
+                    thread_id_A, local_id_A = T.meta_var(index_map_A(a_row_ind, a_col_ind))
                    thread_id_B, local_id_B = T.meta_var(index_map_B(b_row_ind, b_col_ind))
T.reads(
@@ -288,9 +354,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
)
T.writes(C[thread_id_C, local_id_C])
- C[thread_id_C, local_id_C] += maybe_cast(
+ C[thread_id_C, local_id_C] += cast_to_out_dtype(
A[thread_id_A, local_id_A]
- ) * maybe_cast(B[thread_id_B, local_id_B])
+ ) * cast_to_out_dtype(B[thread_id_B, local_id_B])
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -326,50 +392,84 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
B[0:WARP_SIZE, 0:local_size],
)
T.writes(C[0:WARP_SIZE, 0:local_size_out])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, WARP_SIZE)
- T.evaluate(
- T.ptx_mma(
- mma_prefix,
- "row",
- "col",
- in_dtype_abbrv,
- in_dtype_abbrv,
- out_dtype_abbrv,
- A.data,
- A.elem_offset + tx * lift(local_size),
- B.data,
- B.elem_offset + tx * lift(local_size),
- C.data,
- C.elem_offset + tx * lift(local_size_out),
- False,
- dtype=out_dtype,
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ T.evaluate(
+ T.ptx_mma(
+ mma_prefix,
+ "row",
+ "col",
+ in_dtype_abbrv,
+ in_dtype_abbrv,
+ out_dtype_abbrv,
+ A.data,
+ A.elem_offset + tx * lift(local_size),
+ B.data,
+ B.elem_offset + tx * lift(local_size),
+ C.data,
+ C.elem_offset + tx * lift(local_size_out),
+ False,
+ dtype=out_dtype,
+ )
)
- )
- T.evaluate(
- T.ptx_mma(
- mma_prefix,
- "row",
- "col",
- in_dtype_abbrv,
- in_dtype_abbrv,
- out_dtype_abbrv,
- A.data,
- A.elem_offset + tx * lift(local_size),
- B.data,
-            B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
- C.data,
-            C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
- False,
- dtype=out_dtype,
+ T.evaluate(
+ T.ptx_mma(
+ mma_prefix,
+ "row",
+ "col",
+ in_dtype_abbrv,
+ in_dtype_abbrv,
+ out_dtype_abbrv,
+ A.data,
+ A.elem_offset + tx * lift(local_size),
+ B.data,
+                        B.elem_offset + tx * lift(local_size) + lift(local_size) // 2,
+ C.data,
+                        C.elem_offset + tx * lift(local_size_out) + lift(local_size_out) // 2,
+ False,
+ dtype=out_dtype,
+ )
)
- )
return mma_sync_desc, mma_sync_impl
+MMA_f16f16f32_INTRIN = "mma_f16f16f32"
+TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False, False))
+
+MMA_f16f16f32_TRANS_B_INTRIN = "mma_f16f16f32_trans_b"
+TensorIntrin.register(MMA_f16f16f32_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", False, True))
+
+MMA_f16f16f32_TRANS_A_INTRIN = "mma_f16f16f32_trans_a"
+TensorIntrin.register(MMA_f16f16f32_TRANS_A_INTRIN, *get_mma_intrin(16, "float32", True, False))
+
+MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f32_trans_a_trans_b"
+TensorIntrin.register(
+    MMA_f16f16f32_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float32", True, True)
+)
+
+MMA_f16f16f16_INTRIN = "mma_f16f16f16"
+TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False, False))
+
+MMA_f16f16f16_TRANS_B_INTRIN = "mma_f16f16f16_trans_b"
+TensorIntrin.register(MMA_f16f16f16_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", False, True))
+
+MMA_f16f16f16_TRANS_A_INTRIN = "mma_f16f16f16_trans_a"
+TensorIntrin.register(MMA_f16f16f16_TRANS_A_INTRIN, *get_mma_intrin(16, "float16", True, False))
+
+MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN = "mma_f16f16f16_trans_a_trans_b"
+TensorIntrin.register(
+    MMA_f16f16f16_TRANS_A_TRANS_B_INTRIN, *get_mma_intrin(16, "float16", True, True)
+)
+
+MMA_i8i8i32_INTRIN = "mma_i8i8i32"
+TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False, False))
+
+MMA_i8i8i32_TRANS_B_INTRIN = "mma_i8i8i32_trans_b"
+TensorIntrin.register(MMA_i8i8i32_TRANS_B_INTRIN, *get_mma_intrin(32, "int32", False, True))
+
+
def get_mma_fill_intrin(dtype, local_size):
zero = IntImm("int32", 0).astype(dtype)
@@ -400,17 +500,27 @@ def get_mma_fill_intrin(dtype, local_size):
with T.block("root"):
T.reads()
T.writes(C_warp[0:WARP_SIZE, 0:local_size])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, WARP_SIZE)
-        T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
+        for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+            T.evaluate(T.mma_fill(local_size, C_warp.data, C_warp.elem_offset, dtype=dtype))
return mma_fill_desc, mma_fill_impl
-def get_mma_store_intrin(dtype, local_size, scope="global"):
+MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
+TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))
+
+MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
+TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))
+
+MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
+TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
+
+
+def get_mma_store_intrin(dtype, local_size, scope="global", use_mma_store_intrinic=True):
# Assume M = N = 16
index_map = shared_16x16_to_ldmatrix_32x8_layout
+ index_map_rev = ldmatrix_32x8_to_shared_16x16_layout
@T.prim_func
def mma_store_desc(a: T.handle, c: T.handle) -> None:
@@ -428,110 +538,183 @@ def get_mma_store_intrin(dtype, local_size, scope="global"):
T.writes(C[v0, v1])
C[v0, v1] = C_warp[thread_id, local_id]
- @T.prim_func
- def mma_store_impl(a: T.handle, c: T.handle) -> None:
- s0 = T.int32()
- s1 = T.int32()
+ if use_mma_store_intrinic:
- C_warp = T.match_buffer(
-        a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
- )
- C = T.match_buffer(
-        c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]
- )
+ @T.prim_func
+ def mma_store_impl(a: T.handle, c: T.handle) -> None:
+ s0 = T.int32()
+ s1 = T.int32()
- with T.block("root"):
- T.reads(C_warp[0:WARP_SIZE, 0:local_size])
- T.writes(C[0:M_DIM, 0:N_DIM])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, WARP_SIZE)
+ C_warp = T.match_buffer(
+                a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+ )
+ C = T.match_buffer(
+                c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]
+ )
- T.evaluate(
- T.mma_store(
- M_DIM,
- N_DIM,
- C.access_ptr("w"),
- C_warp.data,
- C_warp.elem_offset,
- s0,
- dtype=dtype,
- )
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ T.evaluate(
+ T.mma_store(
+ M_DIM,
+ N_DIM,
+ C.access_ptr("w"),
+ C_warp.data,
+ C_warp.elem_offset,
+ s0,
+ dtype=dtype,
+ )
+ )
+
+ else:
+
+ @T.prim_func
+ def mma_store_impl(a: T.handle, c: T.handle) -> None:
+ s0 = T.int32()
+ s1 = T.int32()
+
+ C_warp = T.match_buffer(
+                a, [WARP_SIZE, local_size], dtype=dtype, scope="warp", offset_factor=1
+ )
+ C = T.match_buffer(
+                c, [M_DIM, N_DIM], dtype=dtype, scope=scope, offset_factor=1, strides=[s0, s1]
)
+ with T.block("root"):
+ T.reads(C_warp[0:WARP_SIZE, 0:local_size])
+ T.writes(C[0:M_DIM, 0:N_DIM])
+
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ for local_id in T.serial(local_size):
+ row, col = T.meta_var(index_map_rev(tx, local_id))
+ C[row, col] = C_warp[tx, local_id]
+
return mma_store_desc, mma_store_impl
-LDMATRIX_16x16_A_INTRIN = "mma.ldmatrix_16x16_a"
-TensorIntrin.register(LDMATRIX_16x16_A_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False))
+MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global", True)
+)
-LDMATRIX_16x16_B_INTRIN = "mma.ldmatrix_16x16_b"
-TensorIntrin.register(LDMATRIX_16x16_B_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False))
+MMA_store_16x16_f32_shared_dyn_INTRIN = "mma_store_16x16_f32_shared_dyn_"
+TensorIntrin.register(
+    MMA_store_16x16_f32_shared_dyn_INTRIN, *get_mma_store_intrin("float32", 8, "shared.dyn", True)
+)
-LDMATRIX_16x16_A_DYN_INTRIN = "mma.ldmatrix_16x16_a_dyn"
+MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE = "mma_store_16x16_f32_shared_dyn_simple_"
TensorIntrin.register(
-    LDMATRIX_16x16_A_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", False, False, "shared.dyn")
+    MMA_store_16x16_f32_shared_dyn_INTRIN_SIMPLE,
+    *get_mma_store_intrin("float32", 8, "shared.dyn", False),
)
-LDMATRIX_16x16_B_DYN_INTRIN = "mma.ldmatrix_16x16_b_dyn"
+MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE = "mma_store_16x16_f16_shared_dyn_simple_"
TensorIntrin.register(
-    LDMATRIX_16x16_B_DYN_INTRIN, *get_ldmatrix_intrin(16, "float16", True, False, "shared.dyn")
+    MMA_store_16x16_f16_shared_dyn_INTRIN_SIMPLE,
+    *get_mma_store_intrin("float16", 8, "shared.dyn", False),
)
-LDMATRIX_16x16_B_TRANS_INTRIN = "mma.ldmatrix_16x16_b_trans"
+MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
TensorIntrin.register(
-    LDMATRIX_16x16_B_TRANS_INTRIN, *get_ldmatrix_intrin(16, "float16", True, True)
+    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global", True)
)
-LDMATRIX_16x32_A_INTRIN = "mma.ldmatrix_16x32_a"
-TensorIntrin.register(LDMATRIX_16x32_A_INTRIN, *get_ldmatrix_intrin(32, "int8", False, False))
+MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
+TensorIntrin.register(
+    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global", True)
+)
-LDMATRIX_32x16_B_INTRIN = "mma.ldmatrix_32x16_b"
-TensorIntrin.register(LDMATRIX_32x16_B_INTRIN, *get_ldmatrix_intrin(32, "int8", True, False))
-LDMATRIX_16x32_B_TRANS_INTRIN = "mma.ldmatrix_16x32_b_trans"
-TensorIntrin.register(LDMATRIX_16x32_B_TRANS_INTRIN, *get_ldmatrix_intrin(32, "int8", True, True))
+def get_mma_intrin_group(
+ load_scope: Literal["shared", "shared.dyn"],
+ store_scope: Literal["global", "shared", "shared.dyn"],
+ in_dtype: Literal["float16", "int8"],
+ out_dtype: Literal["float16", "float32", "int32"],
+ trans_a: bool,
+ trans_b: bool,
+ not_use_mma_store_intrinic: bool = True,
+    store_to_smem_dtype: Optional[Literal["float16", "float32", "int32"]] = None,
+) -> Dict[str, str]:
+    """Get a group of intrinsics for mma tensor core with the given configurations
-MMA_f16f16f32_INTRIN = "mma_f16f16f32"
-TensorIntrin.register(MMA_f16f16f32_INTRIN, *get_mma_intrin(16, "float32", False))
+    Parameters
+    ----------
+    load_scope : Literal["shared", "shared.dyn"]
+        The memory scope of the input buffer.
-MMA_f16f16f32_TRANS_INTRIN = "mma_f16f16f32_trans"
-TensorIntrin.register(MMA_f16f16f32_TRANS_INTRIN, *get_mma_intrin(16, "float32", True))
+    store_scope : Literal["global", "shared", "shared.dyn"]
+        The memory scope of the result buffer.
-MMA_f16f16f16_INTRIN = "mma_f16f16f16"
-TensorIntrin.register(MMA_f16f16f16_INTRIN, *get_mma_intrin(16, "float16", False))
+    in_dtype : str
+        The input data type.
+
+    out_dtype : str
+        The output data type.
-MMA_f16f16f16_TRANS_INTRIN = "mma_f16f16f16_trans"
-TensorIntrin.register(MMA_f16f16f16_TRANS_INTRIN, *get_mma_intrin(16, "float16", True))
+    trans_a : bool
+        Whether the input matrix A is transposed.
-MMA_i8i8i32_INTRIN = "mma_i8i8i32"
-TensorIntrin.register(MMA_i8i8i32_INTRIN, *get_mma_intrin(32, "int32", False))
+    trans_b : bool
+        Whether the input matrix B is transposed.
-MMA_i8i8i32_TRANS_INTRIN = "mma_i8i8i32_trans"
-TensorIntrin.register(MMA_i8i8i32_TRANS_INTRIN, *get_mma_intrin(32, "int32", True))
+    not_use_mma_store_intrinic : bool
+        Whether to not use the mma_store intrinsic. If True, use BufferStore stmts to store the
+        result of mma. Otherwise, use mma_store intrinsic.
-MMA_fill_16x16_f32_INTRIN = "mma_fill_16x16_f32"
-TensorIntrin.register(MMA_fill_16x16_f32_INTRIN, *get_mma_fill_intrin("float32", 8))
+        This is because if we use mma_store intrinsic, during swizzling shared memory visits, our
+        rearrangement scheme will involve areas accessed by different mma_store calls. This makes
+        swizzling quite complex. But BufferStore will not face this problem.
-MMA_fill_16x16_f16_INTRIN = "mma_fill_16x16_f16"
-TensorIntrin.register(MMA_fill_16x16_f16_INTRIN, *get_mma_fill_intrin("float16", 8))
+    store_to_smem_dtype : Optional[Literal["float16", "float32", "int32"]]
+        The dtype that we use to store from register to shared memory. By default it is out_dtype.
-MMA_fill_16x16_i32_INTRIN = "mma_fill_16x16_i32"
-TensorIntrin.register(MMA_fill_16x16_i32_INTRIN, *get_mma_fill_intrin("int32", 8))
+ Returns
+ -------
+ ret : Dict[str, str]
+ A group of tensor intrinsics.
+ """
+ assert load_scope in ["shared", "shared.dyn"]
+ assert store_scope in ["global", "shared", "shared.dyn"]
+ assert in_dtype in ["float16", "int8"]
+ assert out_dtype in ["float16", "float32", "int32"]
-MMA_store_16x16_f32_global_INTRIN = "mma_store_16x16_f32_global_"
-TensorIntrin.register(
-    MMA_store_16x16_f32_global_INTRIN, *get_mma_store_intrin("float32", 8, "global")
-)
+    shape = "16x16"
-MMA_store_16x16_f16_global_INTRIN = "mma_store_16x16_f16_global_"
-TensorIntrin.register(
-    MMA_store_16x16_f16_global_INTRIN, *get_mma_store_intrin("float16", 8, "global")
-)
+    dtype_mapping = {"float16": "f16", "float32": "f32", "int8": "i8", "int32": "i32"}
+    in_dtype = dtype_mapping[in_dtype]
+    out_dtype = dtype_mapping[out_dtype]
-MMA_store_16x16_i32_global_INTRIN = "mma_store_16x16_i32_global_"
-TensorIntrin.register(
-    MMA_store_16x16_i32_global_INTRIN, *get_mma_store_intrin("int32", 8, "global")
-)
+ # e.g. mma_fill_16x16_f32
+ init_intrin = f"mma_fill_{shape}_{out_dtype}"
+
+ # e.g. mma_ldmatrix_f16_a_trans_dyn, mma_ldmatrix_f16_b_trans_dyn
+ trans_a = "_trans" if trans_a else ""
+ trans_b = "_trans" if trans_b else ""
+ load_scope = "_dyn" if load_scope == "shared.dyn" else ""
+ load_a_intrin = f"mma_ldmatrix_{in_dtype}_a{trans_a}{load_scope}"
+ load_b_intrin = f"mma_ldmatrix_{in_dtype}_b{trans_b}{load_scope}"
+
+ # e.g. mma_f16f16f32_trans_a_trans_b
+ trans_a_str = trans_a + "_a" if trans_a != "" else ""
+ trans_b_str = trans_b + "_b" if trans_b != "" else ""
+    compute_intrin = f"mma_{in_dtype}{in_dtype}{out_dtype}{trans_a_str}{trans_b_str}"
+
+    # e.g. mma_store_16x16_f32_shared_dyn_simple_
+    store_scope = store_scope.replace(".", "_")
+    store_to_smem_dtype = dtype_mapping[store_to_smem_dtype] if store_to_smem_dtype else out_dtype
+    suffix = "simple_" if not_use_mma_store_intrinic else ""
+    store_intrin = f"mma_store_{shape}_{store_to_smem_dtype}_{store_scope}_{suffix}"
+
+ return {
+ "init": init_intrin,
+ "load_a": load_a_intrin,
+ "load_b": load_b_intrin,
+ "compute": compute_intrin,
+ "store": store_intrin,
+ }
######## WMMA intrinsics ########
@@ -1235,11 +1418,11 @@ def get_mma_init_intrin(
with T.block("root"):
T.reads()
T.writes(dst[0:m_dim, 0:n_dim])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, 32)
- for b in range(m_dim // 8):
- for v in T.vectorized(n_dim // 4):
- dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = zero
+
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ for b in range(m_dim // 8):
+ for v in T.vectorized(n_dim // 4):
+                    dst[b * 8 + tx // 4, (tx % 4) * (n_dim // 4) + v] = zero
return mma_init_desc, mma_init_impl
@@ -1310,21 +1493,19 @@ def get_mma_load_intrin(
T.reads(src[0:frag_m, 0:frag_n])
T.writes(dst[0:frag_m, 0:frag_n])
- tx = T.env_thread("threadIdx.x")
- T.launch_thread(tx, 32)
-
- T.evaluate(
- T.ptx_ldmatrix(
- trans,
- 4, # Always load 4 matrices
- ".b16",
- dst.data,
- get_index(dst.elem_offset, d0),
- src.access_ptr("r"),
- get_tx_index(tx, s0),
- dtype=dtype,
+ for tx in T.thread_binding(0, WARP_SIZE, "threadIdx.x"):
+ T.evaluate(
+ T.ptx_ldmatrix(
+ trans,
+ 4, # Always load 4 matrices
+ ".b16",
+ dst.data,
+ get_index(dst.elem_offset, d0),
+ src.access_ptr("r"),
+ get_tx_index(tx, s0),
+ dtype=dtype,
+ )
)
- )
return mma_load_desc, mma_load_impl
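
A quick sanity check on the new reverse mapping that backs the BufferStore-based
store path above: the sketch below pairs the newly added
ldmatrix_32x8_to_shared_16x16_layout with the forward mapping it is meant to
invert (the forward function predates this patch; its body is reproduced here
from cuda.py since only its signature shows in the hunk context) and asserts
the round trip over the whole 16x16 tile:

    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
        # Forward mapping, reproduced from cuda.py: tile (i, j) -> (thread, lane).
        thread_id = 4 * (i % 8) + (j % 8) // 2
        return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2

    def ldmatrix_32x8_to_shared_16x16_layout(thread_id, local_id):
        # Reverse mapping, copied verbatim from the hunk above.
        row = 8 * (local_id % 4 // 2) + (thread_id // 4)
        col = 8 * (local_id // 4) + (thread_id % 4) * 2 + (local_id % 2)
        return row, col

    # Every tile element must map to a valid (thread, lane) pair and back.
    for i in range(16):
        for j in range(16):
            tid, lid = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert 0 <= tid < 32 and 0 <= lid < 8
            assert ldmatrix_32x8_to_shared_16x16_layout(tid, lid) == (i, j)
    print("ldmatrix 32x8 layout round-trips over the 16x16 tile")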
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
similarity index 88%
rename from tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
rename to tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
index 2a853a2431..d704dc2438 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma_numeric.py
@@ -22,21 +22,21 @@ import tvm.testing
from tvm import te
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
- LDMATRIX_16x16_A_INTRIN,
- LDMATRIX_16x16_B_INTRIN,
- LDMATRIX_16x16_B_TRANS_INTRIN,
- LDMATRIX_16x32_A_INTRIN,
- LDMATRIX_16x32_B_TRANS_INTRIN,
- LDMATRIX_32x16_B_INTRIN,
+ LDMATRIX_f16_A_INTRIN,
+ LDMATRIX_f16_B_INTRIN,
+ LDMATRIX_f16_B_TRANS_INTRIN,
+ LDMATRIX_i8_A_INTRIN,
+ LDMATRIX_i8_B_TRANS_INTRIN,
+ LDMATRIX_i8_B_INTRIN,
MMA_f16f16f16_INTRIN,
- MMA_f16f16f16_TRANS_INTRIN,
+ MMA_f16f16f16_TRANS_B_INTRIN,
MMA_f16f16f32_INTRIN,
- MMA_f16f16f32_TRANS_INTRIN,
+ MMA_f16f16f32_TRANS_B_INTRIN,
MMA_fill_16x16_f16_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_fill_16x16_i32_INTRIN,
MMA_i8i8i32_INTRIN,
- MMA_i8i8i32_TRANS_INTRIN,
+ MMA_i8i8i32_TRANS_B_INTRIN,
MMA_store_16x16_f16_global_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
MMA_store_16x16_i32_global_INTRIN,
@@ -116,15 +116,15 @@ def run_test(
dev = tvm.device("cuda", 0)
if in_dtype == "float16":
- a_np = np.random.uniform(size=(M, K)).astype("float16")
+ a_np = np.random.normal(size=(M, K)).astype("float16")
if b_transposed:
- b_np = np.random.uniform(size=(N, K)).astype("float16")
+ b_np = np.random.normal(size=(N, K)).astype("float16")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32").transpose()).astype(
out_dtype
)
else:
- b_np = np.random.uniform(size=(K, N)).astype("float16")
+ b_np = np.random.normal(size=(K, N)).astype("float16")
            c_np = np.dot(a_np.astype("float32"), b_np.astype("float32")).astype(out_dtype)
else:
a_np = np.random.randint(-128, 128, (M, K)).astype("int8")
@@ -147,7 +147,7 @@ def run_test(
if out_dtype != "float16":
        # The numpy reference is computed with fp32 precision (otherwise too slow).
        # So there is non-trivial accuracy difference if TVM result is computed with fp16 accumulation.
- tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+ tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-2, atol=1e-2)
return lambda: f.time_evaluator(f.entry_name, dev, number=500)(a, b, c)
@@ -177,8 +177,8 @@ def test_f16f16f32_m16n16k16():
index_map,
index_map,
index_map,
- LDMATRIX_16x16_A_INTRIN,
- LDMATRIX_16x16_B_INTRIN,
+ LDMATRIX_f16_A_INTRIN,
+ LDMATRIX_f16_B_INTRIN,
MMA_f16f16f32_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
@@ -198,9 +198,9 @@ def test_f16f16f32_m16n16k16():
index_map,
index_map,
index_map,
- LDMATRIX_16x16_A_INTRIN,
- LDMATRIX_16x16_B_TRANS_INTRIN,
- MMA_f16f16f32_TRANS_INTRIN,
+ LDMATRIX_f16_A_INTRIN,
+ LDMATRIX_f16_B_TRANS_INTRIN,
+ MMA_f16f16f32_TRANS_B_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
)
@@ -234,8 +234,8 @@ def test_f16f16f16_m16n16k16():
index_map,
index_map,
index_map,
- LDMATRIX_16x16_A_INTRIN,
- LDMATRIX_16x16_B_INTRIN,
+ LDMATRIX_f16_A_INTRIN,
+ LDMATRIX_f16_B_INTRIN,
MMA_f16f16f16_INTRIN,
MMA_fill_16x16_f16_INTRIN,
MMA_store_16x16_f16_global_INTRIN,
@@ -255,9 +255,9 @@ def test_f16f16f16_m16n16k16():
index_map,
index_map,
index_map,
- LDMATRIX_16x16_A_INTRIN,
- LDMATRIX_16x16_B_TRANS_INTRIN,
- MMA_f16f16f16_TRANS_INTRIN,
+ LDMATRIX_f16_A_INTRIN,
+ LDMATRIX_f16_B_TRANS_INTRIN,
+ MMA_f16f16f16_TRANS_B_INTRIN,
MMA_fill_16x16_f16_INTRIN,
MMA_store_16x16_f16_global_INTRIN,
)
@@ -305,8 +305,8 @@ def test_i8i8i32_m16n16k32():
index_map_A,
index_map_B,
index_map_C,
- LDMATRIX_16x32_A_INTRIN,
- LDMATRIX_32x16_B_INTRIN,
+ LDMATRIX_i8_A_INTRIN,
+ LDMATRIX_i8_B_INTRIN,
MMA_i8i8i32_INTRIN,
MMA_fill_16x16_i32_INTRIN,
MMA_store_16x16_i32_global_INTRIN,
@@ -326,9 +326,9 @@ def test_i8i8i32_m16n16k32():
index_map_A,
index_map_A,
index_map_C,
- LDMATRIX_16x32_A_INTRIN,
- LDMATRIX_16x32_B_TRANS_INTRIN,
- MMA_i8i8i32_TRANS_INTRIN,
+ LDMATRIX_i8_A_INTRIN,
+ LDMATRIX_i8_B_TRANS_INTRIN,
+ MMA_i8i8i32_TRANS_B_INTRIN,
MMA_fill_16x16_i32_INTRIN,
MMA_store_16x16_i32_global_INTRIN,
)
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_mfma.py b/tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py
similarity index 100%
rename from tests/python/unittest/test_tir_schedule_tensorize_mfma.py
rename to tests/python/unittest/test_tir_schedule_tensorize_mfma_numeric.py
diff --git a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
index a013cf0f65..bc3e979f94 100644
--- a/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
+++ b/tests/python/unittest/test_tir_transform_inject_software_pipeline.py
@@ -26,8 +26,8 @@ from tvm.meta_schedule.testing import te_workload
from tvm.script import tir as T
from tvm.testing.tir import mma_schedule
from tvm.tir.tensor_intrin.cuda import (
- LDMATRIX_16x16_A_DYN_INTRIN,
- LDMATRIX_16x16_B_DYN_INTRIN,
+ LDMATRIX_f16_A_DYN_INTRIN,
+ LDMATRIX_f16_B_DYN_INTRIN,
MMA_f16f16f32_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
@@ -1520,8 +1520,8 @@ def get_mma_schedule():
index_map,
index_map,
index_map,
- LDMATRIX_16x16_A_DYN_INTRIN,
- LDMATRIX_16x16_B_DYN_INTRIN,
+ LDMATRIX_f16_A_DYN_INTRIN,
+ LDMATRIX_f16_B_DYN_INTRIN,
MMA_f16f16f32_INTRIN,
MMA_fill_16x16_f32_INTRIN,
MMA_store_16x16_f32_global_INTRIN,
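
The import renames in both test diffs map one-to-one onto the registry names
introduced in cuda.py. As a final sketch (again assuming a build that contains
this patch, and that TensorIntrin.get accepts allow_missing in this TVM
version), a sample of the new names can be checked against the public
TensorIntrin registry:

    import tvm.tir.tensor_intrin.cuda  # noqa: F401  (importing runs the registrations)
    from tvm.tir import TensorIntrin

    for name in [
        "mma_ldmatrix_f16_a",
        "mma_ldmatrix_f16_b_dyn",
        "mma_f16f16f32_trans_b",
        "mma_fill_16x16_f32",
        "mma_store_16x16_f32_global_",
    ]:
        assert TensorIntrin.get(name, allow_missing=True) is not None, name
    print("all sampled intrinsic names resolve")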