This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 486513ae16 [TIR] Fix offset_factor in cuda tensor core intrins (#15913)
486513ae16 is described below

commit 486513ae167ffdb56fae414a107413937ab4032c
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Wed Oct 11 23:25:24 2023 -0700

    [TIR] Fix offset_factor in cuda tensor core intrins (#15913)
---
 python/tvm/tir/tensor_intrin/cuda.py | 46 +++++++++++++++---------------------
 1 file changed, 19 insertions(+), 27 deletions(-)

diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 44de418dad..6ee00ee634 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,6 @@
 # under the License.
 # pylint: disable=invalid-name,missing-function-docstring
 """Intrinsics for tensorization on NVIDIA GPU."""
-import re
 from typing import Dict, Tuple
 
 from typing_extensions import Literal
@@ -44,16 +43,6 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
     return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
 
 
-def get_tensor_core_load_offset_factor(dtype):
-    """get offset factor for tensor core load intrin"""
-    bits = re.search(r"(\d+)", dtype).group(0)
-    bits = int(bits)
-    if bits <= 4:
-        # sub-byte oeprations have different offset factor
-        return 128 // bits
-    return 256 // bits
-
-
 @register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
 def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
     i, j = ind[0], ind[1]
@@ -127,7 +116,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
         col_dim = k_dim
 
     shmem_shape = (row_dim, col_dim)
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = col_dim
 
     @T.prim_func
     def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -244,8 +233,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             return j, i
         return i, j
 
-    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
-    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+    A_offset_factor = k_dim
+    B_offset_factor = maybe_swap(k_dim, N_DIM)[-1]
+    out_offset_factor = N_DIM
 
     @T.prim_func
     def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -254,7 +244,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="warp",
         )
         B = T.match_buffer(
@@ -262,7 +252,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="warp",
         )
         C = T.match_buffer(
@@ -309,7 +299,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="warp",
         )
         B = T.match_buffer(
@@ -317,7 +307,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
             (WARP_SIZE, local_size),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="warp",
         )
         C = T.match_buffer(
@@ -568,11 +558,11 @@
     """Generator of wmma_load intrins"""
     wmma_fragment_scope = f"wmma.matrix_{'b' if is_b else 'a'}"
     layout = "col_major" if is_col_major else "row_major"
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
 
     frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
     if is_col_major:
         frag_m, frag_n = frag_n, frag_m
+    offset_factor = frag_n
 
     @T.prim_func
     def wmma_load_desc(a: T.handle, c: T.handle) -> None:
@@ -644,7 +634,7 @@ def get_wmma_fill_intrin(
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_fill intrins"""
     zero = IntImm("int32", 0).astype(dtype)
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = n_dim
 
     @T.prim_func
     def wmma_fill_desc(c: T.handle) -> None:
@@ -699,7 +689,7 @@ def get_wmma_store_intrin(
     m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_store intrins"""
-    offset_factor = get_tensor_core_load_offset_factor(dtype)
+    offset_factor = n_dim
 
     @T.prim_func
     def wmma_store_desc(a: T.handle, c: T.handle) -> None:
@@ -770,8 +760,6 @@ def get_wmma_sync_intrin(
     m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
 ) -> Tuple[PrimFunc, PrimFunc]:
     """Generator of wmma_sync intrins"""
-    in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
-    out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
 
     def maybe_cast(v):
         if in_dtype != out_dtype:
@@ -785,6 +773,10 @@ def get_wmma_sync_intrin(
 
     b_shape_0, b_shape_1 = maybe_swap(k_dim, n_dim)
 
+    A_offset_factor = k_dim
+    B_offset_factor = b_shape_1
+    out_offset_factor = n_dim
+
     @T.prim_func
     def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
         A = T.match_buffer(
@@ -792,7 +784,7 @@ def get_wmma_sync_intrin(
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="wmma.matrix_a",
         )
         B = T.match_buffer(
@@ -800,7 +792,7 @@ def get_wmma_sync_intrin(
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="wmma.matrix_b",
         )
         C = T.match_buffer(
@@ -837,7 +829,7 @@ def get_wmma_sync_intrin(
             (m_dim, k_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=A_offset_factor,
             scope="wmma.matrix_a",
             strides=[a1, a0],
        )
@@ -846,7 +838,7 @@ def get_wmma_sync_intrin(
             maybe_swap(k_dim, n_dim),
             in_dtype,
             align=64,
-            offset_factor=in_offset_factor,
+            offset_factor=B_offset_factor,
             scope="wmma.matrix_b",
             strides=[b1, b0],
         )
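
For readers skimming the archive, the sketch below is a minimal, self-contained restatement of the change; it is not part of the commit. The removed helper derived offset_factor from the element type's bit width alone, while the new code ties each buffer's offset_factor to its innermost logical dimension (col_dim, k_dim, n_dim, frag_n, b_shape_1 in the diff). The new_mma_offset_factors helper name and the 16x16x16 example values are illustrative, not from the patch.

    import re

    def old_offset_factor(dtype: str) -> int:
        # Removed helper: the factor depended only on the element bit width.
        bits = int(re.search(r"(\d+)", dtype).group(0))
        if bits <= 4:
            # sub-byte operations used a different factor
            return 128 // bits
        return 256 // bits

    def new_mma_offset_factors(k_dim: int, n_dim: int, b_transposed: bool):
        # New scheme, restated from get_mma_intrin in the diff: each factor
        # follows the innermost dimension of the buffer it constrains.
        def maybe_swap(i, j):
            return (j, i) if b_transposed else (i, j)
        a_factor = k_dim                          # A fragment is m x k
        b_factor = maybe_swap(k_dim, n_dim)[-1]   # B is k x n (n x k if transposed)
        out_factor = n_dim                        # accumulator is m x n
        return a_factor, b_factor, out_factor

    # For a 16x16x16 fp16 mma with an fp32 accumulator, the old helper gave
    # 256 // 32 = 8 for the output buffer, while the new code uses n_dim = 16.
    print(old_offset_factor("float16"), old_offset_factor("float32"))  # 16 8
    print(new_mma_offset_factors(16, 16, b_transposed=False))          # (16, 16, 16)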