This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 486513ae16 [TIR] Fix offset_factor in cuda tensor core intrins (#15913)
486513ae16 is described below
commit 486513ae167ffdb56fae414a107413937ab4032c
Author: Wuwei Lin <[email protected]>
AuthorDate: Wed Oct 11 23:25:24 2023 -0700
[TIR] Fix offset_factor in cuda tensor core intrins (#15913)
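
    Note for reviewers: with this change, offset_factor is taken from the
    innermost dimension of each matched tile instead of being derived from
    the dtype bit width. A minimal standalone sketch (not part of this
    commit) of the pattern, assuming a 16x16 fp16 shared -> wmma.matrix_a
    load in the style of wmma_load_desc:

        from tvm.script import tir as T

        @T.prim_func
        def wmma_load_sketch(a: T.handle, c: T.handle) -> None:
            # innermost tile dimension is 16, so offset_factor = 16 for both buffers
            A = T.match_buffer(a, (16, 16), "float16", align=64, offset_factor=16, scope="shared")
            C = T.match_buffer(c, (16, 16), "float16", align=64, offset_factor=16, scope="wmma.matrix_a")
            with T.block("root"):
                T.reads(A[0:16, 0:16])
                T.writes(C[0:16, 0:16])
                for i, j in T.grid(16, 16):
                    with T.block("load"):
                        vii, vjj = T.axis.remap("SS", [i, j])
                        C[vii, vjj] = A[vii, vjj]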
---
python/tvm/tir/tensor_intrin/cuda.py | 46 +++++++++++++++---------------------
1 file changed, 19 insertions(+), 27 deletions(-)
diff --git a/python/tvm/tir/tensor_intrin/cuda.py b/python/tvm/tir/tensor_intrin/cuda.py
index 44de418dad..6ee00ee634 100644
--- a/python/tvm/tir/tensor_intrin/cuda.py
+++ b/python/tvm/tir/tensor_intrin/cuda.py
@@ -16,7 +16,6 @@
# under the License.
# pylint: disable=invalid-name,missing-function-docstring
"""Intrinsics for tensorization on NVIDIA GPU."""
-import re
from typing import Dict, Tuple
from typing_extensions import Literal
@@ -44,16 +43,6 @@ def shared_32x16_to_ldmatrix_32x16_layout(i, j):
return thread_id, 8 * (j // 8) + (i // 16) * 4 + i % 4
-def get_tensor_core_load_offset_factor(dtype):
- """get offset factor for tensor core load intrin"""
- bits = re.search(r"(\d+)", dtype).group(0)
- bits = int(bits)
- if bits <= 4:
- # sub-byte operations have different offset factor
- return 128 // bits
- return 256 // bits
-
-
@register_func("tir.index_map.shared_16x16_to_ldmatrix_32x8_layout")
def index_map_shared_16x16_to_ldmatrix_32x8_layout(ind):
i, j = ind[0], ind[1]
@@ -127,7 +116,7 @@ def get_ldmatrix_intrin(k_dim, dtype, is_b, transposed, shared_scope="shared"):
col_dim = k_dim
shmem_shape = (row_dim, col_dim)
- offset_factor = get_tensor_core_load_offset_factor(dtype)
+ offset_factor = col_dim
@T.prim_func
def ldmatrix_desc(warp_handle: T.handle, shared_handle: T.handle) -> None:
@@ -244,8 +233,9 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
return j, i
return i, j
- in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
- out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
+ A_offset_factor = k_dim
+ B_offset_factor = maybe_swap(k_dim, N_DIM)[-1]
+ out_offset_factor = N_DIM
@T.prim_func
def mma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
@@ -254,7 +244,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
(WARP_SIZE, local_size),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=A_offset_factor,
scope="warp",
)
B = T.match_buffer(
@@ -262,7 +252,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
(WARP_SIZE, local_size),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=B_offset_factor,
scope="warp",
)
C = T.match_buffer(
@@ -309,7 +299,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
(WARP_SIZE, local_size),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=A_offset_factor,
scope="warp",
)
B = T.match_buffer(
@@ -317,7 +307,7 @@ def get_mma_intrin(k_dim, out_dtype, b_transposed):
(WARP_SIZE, local_size),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=B_offset_factor,
scope="warp",
)
C = T.match_buffer(
@@ -568,11 +558,11 @@ def get_wmma_load_intrin(
"""Generator of wmma_load intrins"""
wmma_fragment_scope = f"wmma.matrix_{'b' if is_b else 'a'}"
layout = "col_major" if is_col_major else "row_major"
- offset_factor = get_tensor_core_load_offset_factor(dtype)
frag_m, frag_n = (k_dim, n_dim) if is_b else (m_dim, k_dim)
if is_col_major:
frag_m, frag_n = frag_n, frag_m
+ offset_factor = frag_n
@T.prim_func
def wmma_load_desc(a: T.handle, c: T.handle) -> None:
@@ -644,7 +634,7 @@ def get_wmma_fill_intrin(
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_fill intrins"""
zero = IntImm("int32", 0).astype(dtype)
- offset_factor = get_tensor_core_load_offset_factor(dtype)
+ offset_factor = n_dim
@T.prim_func
def wmma_fill_desc(c: T.handle) -> None:
@@ -699,7 +689,7 @@ def get_wmma_store_intrin(
m_dim: int, n_dim: int, k_dim: int, dtype: str, scope: str
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_store intrins"""
- offset_factor = get_tensor_core_load_offset_factor(dtype)
+ offset_factor = n_dim
@T.prim_func
def wmma_store_desc(a: T.handle, c: T.handle) -> None:
@@ -770,8 +760,6 @@ def get_wmma_sync_intrin(
m_dim: int, n_dim: int, k_dim: int, in_dtype: str, out_dtype: str, b_transposed: bool
) -> Tuple[PrimFunc, PrimFunc]:
"""Generator of wmma_sync intrins"""
- in_offset_factor = get_tensor_core_load_offset_factor(in_dtype)
- out_offset_factor = get_tensor_core_load_offset_factor(out_dtype)
def maybe_cast(v):
if in_dtype != out_dtype:
@@ -785,6 +773,10 @@ def get_wmma_sync_intrin(
b_shape_0, b_shape_1 = maybe_swap(k_dim, n_dim)
+ A_offset_factor = k_dim
+ B_offset_factor = b_shape_1
+ out_offset_factor = n_dim
+
@T.prim_func
def wmma_sync_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
A = T.match_buffer(
@@ -792,7 +784,7 @@ def get_wmma_sync_intrin(
(m_dim, k_dim),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=A_offset_factor,
scope="wmma.matrix_a",
)
B = T.match_buffer(
@@ -800,7 +792,7 @@ def get_wmma_sync_intrin(
maybe_swap(k_dim, n_dim),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=B_offset_factor,
scope="wmma.matrix_b",
)
C = T.match_buffer(
@@ -837,7 +829,7 @@ def get_wmma_sync_intrin(
(m_dim, k_dim),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=A_offset_factor,
scope="wmma.matrix_a",
strides=[a1, a0],
)
@@ -846,7 +838,7 @@ def get_wmma_sync_intrin(
maybe_swap(k_dim, n_dim),
in_dtype,
align=64,
- offset_factor=in_offset_factor,
+ offset_factor=B_offset_factor,
scope="wmma.matrix_b",
strides=[b1, b0],
)
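
A side-by-side check (illustrative only, not part of the commit) of the removed dtype-based helper against the new dimension-based values, assuming the common 16x16x16 WMMA shape with fp16 inputs and an fp32 accumulator:

    import re

    def old_offset_factor(dtype: str) -> int:
        # the helper removed by this commit: offset factor derived from the dtype bit width
        bits = int(re.search(r"(\d+)", dtype).group(0))
        return 128 // bits if bits <= 4 else 256 // bits

    print(old_offset_factor("float16"))  # 16, same as the new value (innermost dim = 16)
    print(old_offset_factor("float32"))  # 8, while the new accumulator value is n_dim = 16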