This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 486513ae16 [TIR] Fix offset_factor in cuda tensor core intrins (#15913)
486513ae16 is described below

commit 486513ae167ffdb56fae414a107413937ab4032c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Oct 11 23:25:24 2023 -0700

    [TIR] Fix offset_factor in cuda tensor core intrins (#15913)
---
 python/tvm/tir/tensor_intrin/cuda.py | 46 +++++++++++++++---------------------
 1 file changed, 19 insertions(+), 27 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 44de418dad..6ee00ee634 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
-import re
 from typing import Dict, Tuple
 
 from typing_extensions import Literal
@@ -44,16 +43,6 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
-def get_tensor_core_load_offset_factor(dtype):
-    """get offset factor for tensor core load intrin"""
-    bits = re.search(r"(\d+)", dtype).group(0)
-    bits = int(bits)
-    if bits <= 4:
-        # sub-byte operations have different offset factor
-        return 128 // bits
-    return 256 // bits
-
-
 @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
 def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
     i, j = ind[0], ind[1]
@@ -127,7 +116,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         col_dim = k_dim
 
     shmem_shape = (row_dim, col_dim)
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = col_dim
 
     @T.prim_func
     def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -244,8 +233,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             return j, i
         return i, j
 
-    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
-    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+    A_offset_factor = k_dim
+    B_offset_factor = maybe_swap(k_dim, N_DIM)[-1]
+    out_offset_factor = N_DIM
 
     @T.prim_func
     def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -254,7 +244,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="warp",
         )
         B = T.match_buffer(
@@ -262,7 +252,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="warp",
         )
         C = T.match_buffer(
@@ -309,7 +299,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="warp",
         )
         B = T.match_buffer(
@@ -317,7 +307,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="warp",
         )
         C = T.match_buffer(
@@ -568,11 +558,11 @@ def get_wmma_load_intrin(
     """Generator of wmma_load intrins"""
     wmma_fragment_scope = f"wmma.matrix_{'b' if is_b else 'a'}"
     layout = "col_major" if is_col_major else "row_major"
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
     if is_col_major:
         frag_m, frag_n = frag_n, frag_m
+    offset_factor = frag_n
 
     @T.prim_func
     def wmma_load_desc(a: T.handle, c: T.handle) -> None:
@@ -644,7 +634,7 @@ def get_wmma_fill_intrin(
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_fill intrins"""
     zero = IntImm("int32", 0).astype(dtype)
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = n_dim
 
     @T.prim_func
     def wmma_fill_desc(c: T.handle) -> None:
@@ -699,7 +689,7 @@ def get_wmma_store_intrin(
     m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_store intrins"""
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = n_dim
 
     @T.prim_func
     def wmma_store_desc(a: T.handle, c: T.handle) -> None:
@@ -770,8 +760,6 @@ def get_wmma_sync_intrin(
     m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_sync intrins"""
-    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
-    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
 
     def maybe_cast(v):
         if in_dtype != out_dtype:
@@ -785,6 +773,10 @@ def get_wmma_sync_intrin(
 
     b_shape_0, b_shape_1 = maybe_swap(k_dim, n_dim)
 
+    A_offset_factor = k_dim
+    B_offset_factor = b_shape_1
+    out_offset_factor = n_dim
+
     @T.prim_func
     def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
@@ -792,7 +784,7 @@ def get_wmma_sync_intrin(
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="wmma.matrix_a",
         )
         B = T.match_buffer(
@@ -800,7 +792,7 @@ def get_wmma_sync_intrin(
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="wmma.matrix_b",
         )
         C = T.match_buffer(
@@ -837,7 +829,7 @@ def get_wmma_sync_intrin(
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="wmma.matrix_a",
             strides=[a1, a0],
         )
@@ -846,7 +838,7 @@ def get_wmma_sync_intrin(
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="wmma.matrix_b",
             strides=[b1, b0],
         )
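
A minimal sketch (not part of the commit) of the rule the patch adopts: offset_factor
now comes from each matched buffer's innermost dimension rather than from the dtype's
bit width via the removed get_tensor_core_load_offset_factor helper. The 16x16x16 fp16
fragment shape and the old_offset_factor name below are hypothetical, for illustration
only.

# Old rule (the removed helper): derive offset_factor from the dtype bit width.
def old_offset_factor(dtype_bits: int) -> int:
    if dtype_bits <= 4:
        return 128 // dtype_bits  # sub-byte operations used a different factor
    return 256 // dtype_bits

# New rule: offset_factor equals the innermost dimension of the matched buffer.
m_dim, n_dim, k_dim = 16, 16, 16  # hypothetical wmma fragment shape
a_offset_factor = k_dim  # A is (m_dim, k_dim)
b_offset_factor = n_dim  # B is (k_dim, n_dim); k_dim when B is transposed
c_offset_factor = n_dim  # C is (m_dim, n_dim)

# For fp16 at this shape the two rules happen to agree (256 // 16 == 16), but the
# new rule follows the buffer layout rather than the element type.
assert old_offset_factor(16) == a_offset_factor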
