This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c1d1e9ffb8 [TIR] Add CUDA int4 tensor core intrinsics (#14598)
c1d1e9ffb8 is described below
commit c1d1e9ffb84d5eea21c99b936e4ff260ec0da63a
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 12 06:50:57 2023 -0700
[TIR] Add CUDA int4 tensor core intrinsics (#14598)
This PR adds int4 tensor intrinsics for CUDA tensor cores: 8x8x32 WMMA
load/fill/sync/store variants, plus a dtype-aware buffer offset factor shared
by the ldmatrix, MMA, and WMMA intrinsic generators.
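The generators previously hard-coded offset_factor=16 in their buffer
declarations, which under the new formula corresponds to 16-bit element
types; the new get_tensor_core_load_offset_factor helper derives the factor
from the element bit width instead, with a separate rule for sub-byte types.
A minimal sketch of its arithmetic, restated from the diff below for
illustration:

    import re

    def offset_factor(dtype: str) -> int:
        # Same arithmetic as get_tensor_core_load_offset_factor below.
        bits = int(re.search(r"(\d+)", dtype).group(1))
        if bits <= 4:
            # sub-byte types use a 128-bit granularity
            return 128 // bits
        return 256 // bits

    assert offset_factor("float16") == 16  # the previously hard-coded value
    assert offset_factor("int8") == 32
    assert offset_factor("int4") == 32  # new sub-byte path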
---
python/tvm/tir/tensor_intrin/cuda.py | 216 +++++++++++++++++++++++++++++------
1 file changed, 182 insertions(+), 34 deletions(-)
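As a hedged usage sketch (not part of this commit): once the intrinsics in
the patch below are registered, they can be applied with
tir.Schedule.tensorize. The schedule `sch` and the block names are
hypothetical placeholders, assuming the compute has already been tiled and
blockized to the 8x8x32 fragment shape:

    from tvm.tir.tensor_intrin.cuda import (
        WMMA_FILL_8x8x32_S32_INTRIN,
        WMMA_LOAD_8x8x32_S4_A_INTRIN,
        WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN,
        WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN,
        WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
    )

    # Each tensorize call rewrites one block with the matching intrinsic;
    # the block names ("A_frag", "B_frag", "init", "mma", "store") are
    # placeholders for whatever the schedule actually produced.
    sch.tensorize(sch.get_block("A_frag"), WMMA_LOAD_8x8x32_S4_A_INTRIN)
    sch.tensorize(sch.get_block("B_frag"), WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN)
    sch.tensorize(sch.get_block("init"), WMMA_FILL_8x8x32_S32_INTRIN)
    sch.tensorize(sch.get_block("mma"), WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN)
    sch.tensorize(sch.get_block("store"), WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN)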
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 3bc16f234f..8d12a39ca7 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
+import re
from typing import Dict, Tuple
from typing_extensions import Literal
@@ -43,6 +44,16 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
+def get_tensor_core_load_offset_factor(dtype):
+    """Get the offset factor for tensor core load intrinsics."""
+ bits = re.search(r"(\d+)", dtype).group(0)
+ bits = int(bits)
+ if bits <= 4:
+        # sub-byte operations use a different offset factor
+ return 128 // bits
+ return 256 // bits
+
+
@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
i, j = ind[0], ind[1]
@@ -116,6 +127,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
col_dim = k_dim
shmem_shape = (row_dim, col_dim)
+ offset_factor = get_tensor_core_load_offset_factor(dtype)
@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -124,11 +136,16 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
shmem_shape,
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope=shared_scope,
)
warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+ warp_handle,
+ (WARP_SIZE, local_size),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope="warp",
)
with T.block("root"):
@@ -153,12 +170,17 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
shmem_shape,
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope=shared_scope,
strides=[s0, s1],
)
warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+ warp_handle,
+ (WARP_SIZE, local_size),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope="warp",
)
with T.block("root"):
@@ -222,16 +244,34 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
return j, i
return i, j
+ in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+ out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+ a,
+ (WARP_SIZE, local_size),
+ in_dtype,
+ align=64,
+ offset_factor=in_offset_factor,
+ scope="warp",
)
B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+ b,
+ (WARP_SIZE, local_size),
+ in_dtype,
+ align=64,
+ offset_factor=in_offset_factor,
+ scope="warp",
)
C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+ c,
+ (WARP_SIZE, local_size_out),
+ out_dtype,
+ align=64,
+ offset_factor=out_offset_factor,
+ scope="warp",
)
with T.block("root"):
@@ -265,13 +305,28 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
@T.prim_func
def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+ a,
+ (WARP_SIZE, local_size),
+ in_dtype,
+ align=64,
+ offset_factor=in_offset_factor,
+ scope="warp",
)
B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+ b,
+ (WARP_SIZE, local_size),
+ in_dtype,
+ align=64,
+ offset_factor=in_offset_factor,
+ scope="warp",
)
C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+ c,
+ (WARP_SIZE, local_size_out),
+ out_dtype,
+ align=64,
+ offset_factor=out_offset_factor,
+ scope="warp",
)
with T.block("root"):
@@ -513,17 +568,29 @@ def get_wmma_load_intrin(
"""Generator of wmma_load intrins"""
wmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a")
layout = "col_major" if is_col_major else "row_major"
+ offset_factor = get_tensor_core_load_offset_factor(dtype)
+
+ frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
+ if is_col_major:
+ frag_m, frag_n = frag_n, frag_m
@T.prim_func
def wmma_load_desc(a: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope)
+ A = T.match_buffer(
+            a, (frag_m, frag_n), dtype, align=64, offset_factor=offset_factor, scope=shared_scope
+ )
C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
+ c,
+ (frag_m, frag_n),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope=wmma_fragment_scope,
)
with T.block("root"):
- T.reads(A[0:m_dim, 0:n_dim])
- T.writes(C[0:m_dim, 0:n_dim])
- for i, j in T.grid(m_dim, n_dim):
+ T.reads(A[0:frag_m, 0:frag_n])
+ T.writes(C[0:frag_m, 0:frag_n])
+ for i, j in T.grid(frag_m, frag_n):
with T.block("load"):
vii, vjj = T.axis.remap("SS", [i, j])
C[vii, vjj] = A[vii, vjj]
@@ -536,32 +603,32 @@ def get_wmma_load_intrin(
d0 = T.int32()
A = T.match_buffer(
a,
- (m_dim, n_dim),
+ (frag_m, frag_n),
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope=shared_scope,
strides=[s1, s0],
)
C = T.match_buffer(
c,
- (m_dim, n_dim),
+ (frag_m, frag_n),
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope=wmma_fragment_scope,
strides=[d1, d0],
)
with T.block("root"):
- T.reads(A[0:m_dim, 0:n_dim])
- T.writes(C[0:m_dim, 0:n_dim])
+ T.reads(A[0:frag_m, 0:frag_n])
+ T.writes(C[0:frag_m, 0:frag_n])
T.evaluate(
T.tvm_load_matrix_sync(
C.data,
m_dim,
n_dim,
k_dim,
- get_wmma_fragment_index(C, d1, m_dim, n_dim),
+ get_wmma_fragment_index(C, d1, frag_m, frag_n),
A.access_ptr("r"),
s1,
layout,
@@ -577,11 +644,17 @@ def get_wmma_fill_intrin(
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_fill intrins"""
zero = IntImm("int32", 0).astype(dtype)
+ offset_factor = get_tensor_core_load_offset_factor(dtype)
@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+ c,
+ (m_dim, n_dim),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope="wmma.accumulator",
)
with T.block("root"):
T.reads()
@@ -600,7 +673,7 @@ def get_wmma_fill_intrin(
(m_dim, n_dim),
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope="wmma.accumulator",
strides=[d1, d0],
)
@@ -626,13 +699,21 @@ def get_wmma_store_intrin(
m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_store intrins"""
+ offset_factor = get_tensor_core_load_offset_factor(dtype)
@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+ a,
+ (m_dim, n_dim),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope="wmma.accumulator",
+ )
+ C = T.match_buffer(
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=offset_factor, scope=scope
)
-        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
T.writes(C[0:m_dim, 0:n_dim])
@@ -652,12 +733,18 @@ def get_wmma_store_intrin(
(m_dim, n_dim),
dtype,
align=64,
- offset_factor=16,
+ offset_factor=offset_factor,
scope="wmma.accumulator",
strides=[d1, d0],
)
C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]
+ c,
+ (m_dim, n_dim),
+ dtype,
+ align=64,
+ offset_factor=offset_factor,
+ scope=scope,
+ strides=[s1, s0],
)
with T.block("root"):
T.reads(A[0:m_dim, 0:n_dim])
@@ -683,6 +770,8 @@ def get_wmma_sync_intrin(
    m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_sync intrins"""
+ in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+ out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
def maybe_cast(v):
if in_dtype != out_dtype:
@@ -699,18 +788,28 @@ def get_wmma_sync_intrin(
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
-            a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
+ a,
+ (m_dim, k_dim),
+ in_dtype,
+ align=64,
+ offset_factor=in_offset_factor,
+ scope="wmma.matrix_a",
)
B = T.match_buffer(
b,
maybe_swap(k_dim, n_dim),
in_dtype,
align=64,
- offset_factor=16,
+ offset_factor=in_offset_factor,
scope="wmma.matrix_b",
)
C = T.match_buffer(
-            c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+ c,
+ (m_dim, n_dim),
+ out_dtype,
+ align=64,
+ offset_factor=out_offset_factor,
+ scope="wmma.accumulator",
)
with T.block("root"):
@@ -738,7 +837,7 @@ def get_wmma_sync_intrin(
(m_dim, k_dim),
in_dtype,
align=64,
- offset_factor=16,
+ offset_factor=in_offset_factor,
scope="wmma.matrix_a",
strides=[a1, a0],
)
@@ -747,7 +846,7 @@ def get_wmma_sync_intrin(
maybe_swap(k_dim, n_dim),
in_dtype,
align=64,
- offset_factor=16,
+ offset_factor=in_offset_factor,
scope="wmma.matrix_b",
strides=[b1, b0],
)
@@ -756,7 +855,7 @@ def get_wmma_sync_intrin(
(m_dim, n_dim),
out_dtype,
align=64,
- offset_factor=16,
+ offset_factor=out_offset_factor,
scope="wmma.accumulator",
strides=[c1, c0],
)
@@ -817,6 +916,12 @@ TensorIntrin.register(
*get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
)
+WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN = "wmma_sync_8x8x32_s4s4s32_trans"
+TensorIntrin.register(
+ WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
+ *get_wmma_sync_intrin(8, 8, 32, "int4", "int32", True),
+)
+
WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
TensorIntrin.register(
WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -913,6 +1018,30 @@ TensorIntrin.register(
*get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
)
+WMMA_LOAD_8x8x32_S4_A_INTRIN = "wmma_load_8x8x32_s4_a_shared"
+TensorIntrin.register(
+ WMMA_LOAD_8x8x32_S4_A_INTRIN,
+ *get_wmma_load_intrin(8, 8, 32, "int4", "shared", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN = "wmma_load_8x8x32_s4_a_shared_dyn"
+TensorIntrin.register(
+ WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN,
+ *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared"
+TensorIntrin.register(
+ WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN,
+ *get_wmma_load_intrin(8, 8, 32, "int4", "shared", True, True),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared_dyn"
+TensorIntrin.register(
+ WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN,
+ *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", True, True),
+)
+
WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
@@ -922,6 +1051,9 @@ TensorIntrin.register(WMMA_FILL_16x16x16_F16_INTRIN, *get_wmma_fill_intrin(16, 1
WMMA_FILL_16x16x16_S32_INTRIN = "wmma_fill_16x16x16_s32"
TensorIntrin.register(WMMA_FILL_16x16x16_S32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "int32"))
+WMMA_FILL_8x8x32_S32_INTRIN = "wmma_fill_8x8x32_s32"
+TensorIntrin.register(WMMA_FILL_8x8x32_S32_INTRIN, *get_wmma_fill_intrin(8, 8, 32, "int32"))
+
WMMA_STORE_16x16x16_F32_SHARED_INTRIN = "wmma_store_16x16x16_f32_shared"
TensorIntrin.register(
    WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
@@ -955,6 +1087,17 @@ TensorIntrin.register(
*get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
)
+WMMA_STORE_8x8x32_S32_SHARED_INTRIN = "wmma_store_8x8x32_s32_shared"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_SHARED_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "shared")
+)
+
+WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN = "wmma_store_8x8x32_s32_shared_dyn"
+TensorIntrin.register(
+ WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN,
+ *get_wmma_store_intrin(8, 8, 32, "int32", "shared.dyn"),
+)
+
WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
TensorIntrin.register(
    WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -970,6 +1113,11 @@ TensorIntrin.register(
    WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "global")
)
+WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN = "wmma_store_8x8x32_s32_global"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "global")
+)
+
def get_wmma_intrin_group(
load_scope: Literal["shared", "shared.dyn"],