This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c1d1e9ffb8 [TIR] Add CUDA int4 tensor core intrinsics (#14598)
c1d1e9ffb8 is described below

commit c1d1e9ffb84d5eea21c99b936e4ff260ec0da63a
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Apr 12 06:50:57 2023 -0700

    [TIR] Add CUDA int4 tensor core intrinsics (#14598)
    
    This PR adds int4 tensor-core intrinsics for CUDA (WMMA m8n8k32).
---
 python/tvm/tir/tensor_intrin/cuda.py | 216 +++++++++++++++++++++++++++++------
 1 file changed, 182 insertions(+), 34 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 3bc16f234f..8d12a39ca7 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,6 +16,7 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
+import re
 from typing import Dict, Tuple
 
 from typing_extensions import Literal
@@ -43,6 +44,16 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
+def get_tensor_core_load_offset_factor(dtype):
+    """get offset factor for tensor core load intrin"""
+    bits = re.search(r"(\d+)", dtype).group(0)
+    bits = int(bits)
+    if bits <= 4:
+        # sub-byte operations use a different offset factor
+        return 128 // bits
+    return 256 // bits
+
+
 @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
 def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
     i, j = ind[0], ind[1]
@@ -116,6 +127,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         col_dim = k_dim
 
     shmem_shape = (row_dim, col_dim)
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -124,11 +136,16 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
             shmem_shape,
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
         )
         warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+            warp_handle,
+            (WARP_SIZE, local_size),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -153,12 +170,17 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
             shmem_shape,
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
             strides=[s0, s1],
         )
         warp = T.match_buffer(
-            warp_handle, (WARP_SIZE, local_size), dtype, align=64, offset_factor=16, scope="warp"
+            warp_handle,
+            (WARP_SIZE, local_size),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -222,16 +244,34 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             return j, i
         return i, j
 
+    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+
     @T.prim_func
     def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            a,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            b,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+            c,
+            (WARP_SIZE, local_size_out),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -265,13 +305,28 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
     @T.prim_func
     def mma_sync_impl(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            a,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         B = T.match_buffer(
-            b, (WARP_SIZE, local_size), in_dtype, align=64, offset_factor=16, scope="warp"
+            b,
+            (WARP_SIZE, local_size),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="warp",
         )
         C = T.match_buffer(
-            c, (WARP_SIZE, local_size_out), out_dtype, align=64, offset_factor=16, scope="warp"
+            c,
+            (WARP_SIZE, local_size_out),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="warp",
         )
 
         with T.block("root"):
@@ -513,17 +568,29 @@ def get_wmma_load_intrin(
     """Generator of wmma_load intrins"""
     wmma_fragment_scope = "wmma.matrix_{}".format("b" if is_b else "a")
     layout = "col_major" if is_col_major else "row_major"
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
+
+    frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
+    if is_col_major:
+        frag_m, frag_n = frag_n, frag_m
 
     @T.prim_func
     def wmma_load_desc(a: T.handle, c: T.handle) -> None:
-        A = T.match_buffer(a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=shared_scope)
+        A = T.match_buffer(
+            a, (frag_m, frag_n), dtype, align=64, offset_factor=offset_factor, scope=shared_scope
+        )
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=wmma_fragment_scope
+            c,
+            (frag_m, frag_n),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope=wmma_fragment_scope,
         )
         with T.block("root"):
-            T.reads(A[0:m_dim, 0:n_dim])
-            T.writes(C[0:m_dim, 0:n_dim])
-            for i, j in T.grid(m_dim, n_dim):
+            T.reads(A[0:frag_m, 0:frag_n])
+            T.writes(C[0:frag_m, 0:frag_n])
+            for i, j in T.grid(frag_m, frag_n):
                 with T.block("load"):
                     vii, vjj = T.axis.remap("SS", [i, j])
                     C[vii, vjj] = A[vii, vjj]
@@ -536,32 +603,32 @@ def get_wmma_load_intrin(
         d0 = T.int32()
         A = T.match_buffer(
             a,
-            (m_dim, n_dim),
+            (frag_m, frag_n),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=shared_scope,
             strides=[s1, s0],
         )
         C = T.match_buffer(
             c,
-            (m_dim, n_dim),
+            (frag_m, frag_n),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope=wmma_fragment_scope,
             strides=[d1, d0],
         )
         with T.block("root"):
-            T.reads(A[0:m_dim, 0:n_dim])
-            T.writes(C[0:m_dim, 0:n_dim])
+            T.reads(A[0:frag_m, 0:frag_n])
+            T.writes(C[0:frag_m, 0:frag_n])
             T.evaluate(
                 T.tvm_load_matrix_sync(
                     C.data,
                     m_dim,
                     n_dim,
                     k_dim,
-                    get_wmma_fragment_index(C, d1, m_dim, n_dim),
+                    get_wmma_fragment_index(C, d1, frag_m, frag_n),
                     A.access_ptr("r"),
                     s1,
                     layout,
@@ -577,11 +644,17 @@ def get_wmma_fill_intrin(
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_fill intrins"""
     zero = IntImm("int32", 0).astype(dtype)
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def wmma_fill_desc(c: T.handle) -> None:
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            c,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="wmma.accumulator",
         )
         with T.block("root"):
             T.reads()
@@ -600,7 +673,7 @@ def get_wmma_fill_intrin(
             (m_dim, n_dim),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope="wmma.accumulator",
             strides=[d1, d0],
         )
@@ -626,13 +699,21 @@ def get_wmma_store_intrin(
     m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_store intrins"""
+    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     @T.prim_func
     def wmma_store_desc(a: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            a,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope="wmma.accumulator",
+        )
+        C = T.match_buffer(
+            c, (m_dim, n_dim), dtype, align=64, offset_factor=offset_factor, scope=scope
         )
-        C = T.match_buffer(c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope)
         with T.block("root"):
             T.reads(A[0:m_dim, 0:n_dim])
             T.writes(C[0:m_dim, 0:n_dim])
@@ -652,12 +733,18 @@ def get_wmma_store_intrin(
             (m_dim, n_dim),
             dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=offset_factor,
             scope="wmma.accumulator",
             strides=[d1, d0],
         )
         C = T.match_buffer(
-            c, (m_dim, n_dim), dtype, align=64, offset_factor=16, scope=scope, strides=[s1, s0]
+            c,
+            (m_dim, n_dim),
+            dtype,
+            align=64,
+            offset_factor=offset_factor,
+            scope=scope,
+            strides=[s1, s0],
         )
         with T.block("root"):
             T.reads(A[0:m_dim, 0:n_dim])
@@ -683,6 +770,8 @@ def get_wmma_sync_intrin(
     m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_sync intrins"""
+    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
+    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
 
     def maybe_cast(v):
         if in_dtype != out_dtype:
@@ -699,18 +788,28 @@ def get_wmma_sync_intrin(
     @T.prim_func
     def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
-            a, (m_dim, k_dim), in_dtype, align=64, offset_factor=16, scope="wmma.matrix_a"
+            a,
+            (m_dim, k_dim),
+            in_dtype,
+            align=64,
+            offset_factor=in_offset_factor,
+            scope="wmma.matrix_a",
         )
         B = T.match_buffer(
             b,
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_b",
         )
         C = T.match_buffer(
-            c, (m_dim, n_dim), out_dtype, align=64, offset_factor=16, scope="wmma.accumulator"
+            c,
+            (m_dim, n_dim),
+            out_dtype,
+            align=64,
+            offset_factor=out_offset_factor,
+            scope="wmma.accumulator",
         )
 
         with T.block("root"):
@@ -738,7 +837,7 @@ def get_wmma_sync_intrin(
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_a",
             strides=[a1, a0],
         )
@@ -747,7 +846,7 @@ def get_wmma_sync_intrin(
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=in_offset_factor,
             scope="wmma.matrix_b",
             strides=[b1, b0],
         )
@@ -756,7 +855,7 @@ def get_wmma_sync_intrin(
             (m_dim, n_dim),
             out_dtype,
             align=64,
-            offset_factor=16,
+            offset_factor=out_offset_factor,
             scope="wmma.accumulator",
             strides=[c1, c0],
         )
@@ -817,6 +916,12 @@ TensorIntrin.register(
     *get_wmma_sync_intrin(16, 16, 16, "int8", "int32", True),
 )
 
+WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN = "wmma_sync_8x8x32_s4s4s32_trans"
+TensorIntrin.register(
+    WMMA_SYNC_8x8x32_s4s4s32_TRANS_INTRIN,
+    *get_wmma_sync_intrin(8, 8, 32, "int4", "int32", True),
+)
+
 WMMA_LOAD_16x16x16_F16_A_INTRIN = "wmma_load_16x16x16_f16_a_shared"
 TensorIntrin.register(
     WMMA_LOAD_16x16x16_F16_A_INTRIN,
@@ -913,6 +1018,30 @@ TensorIntrin.register(
     *get_wmma_load_intrin(16, 16, 16, "int8", "shared.dyn", True, True),
 )
 
+WMMA_LOAD_8x8x32_S4_A_INTRIN = "wmma_load_8x8x32_s4_a_shared"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_A_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN = "wmma_load_8x8x32_s4_a_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_A_DYN_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", False, False),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_B_TRANS_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared", True, True),
+)
+
+WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN = "wmma_load_8x8x32_s4_b_trans_shared_dyn"
+TensorIntrin.register(
+    WMMA_LOAD_8x8x32_S4_B_TRANS_DYN_INTRIN,
+    *get_wmma_load_intrin(8, 8, 32, "int4", "shared.dyn", True, True),
+)
+
 WMMA_FILL_16x16x16_F32_INTRIN = "wmma_fill_16x16x16_f32"
 TensorIntrin.register(WMMA_FILL_16x16x16_F32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "float32"))
 
@@ -922,6 +1051,9 @@ TensorIntrin.register(WMMA_FILL_16x16x16_F16_INTRIN, *get_wmma_fill_intrin(16, 1
 WMMA_FILL_16x16x16_S32_INTRIN = "wmma_fill_16x16x16_s32"
 TensorIntrin.register(WMMA_FILL_16x16x16_S32_INTRIN, *get_wmma_fill_intrin(16, 16, 16, "int32"))
 
+WMMA_FILL_8x8x32_S32_INTRIN = "wmma_fill_8x8x32_s32"
+TensorIntrin.register(WMMA_FILL_8x8x32_S32_INTRIN, *get_wmma_fill_intrin(8, 8, 32, "int32"))
+
 WMMA_STORE_16x16x16_F32_SHARED_INTRIN = "wmma_store_16x16x16_f32_shared"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_SHARED_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "shared")
@@ -955,6 +1087,17 @@ TensorIntrin.register(
     *get_wmma_store_intrin(16, 16, 16, "int32", "shared.dyn"),
 )
 
+WMMA_STORE_8x8x32_S32_SHARED_INTRIN = "wmma_store_8x8x32_s32_shared"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_SHARED_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "shared")
+)
+
+WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN = "wmma_store_8x8x32_s32_shared_dyn"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_SHARED_DYN_INTRIN,
+    *get_wmma_store_intrin(8, 8, 32, "int32", "shared.dyn"),
+)
+
 WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN = "wmma_store_16x16x16_f32_global"
 TensorIntrin.register(
     WMMA_STORE_16x16x16_F32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "float32", "global")
@@ -970,6 +1113,11 @@ TensorIntrin.register(
     WMMA_STORE_16x16x16_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(16, 16, 16, "int32", "global")
 )
 
+WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN = "wmma_store_8x8x32_s32_global"
+TensorIntrin.register(
+    WMMA_STORE_8x8x32_S32_GLOBAL_INTRIN, *get_wmma_store_intrin(8, 8, 32, "int32", "global")
+)
+
 
 def get_wmma_intrin_group(
     load_scope: Literal["shared", "shared.dyn"],

Reply via email to