adfwer233 commented on code in PR #15389:
URL: https://github.com/apache/tvm/pull/15389#discussion_r1274333123
##########
python/tvm/dlight/gpu/matmul.py:
##########
@@ -249,38 +249,258 @@ def get_index_map(block: tir.Block) -> Optional[Tuple[tir.IndexMap, ...]]:
)
-class Matmul(ScheduleRule):
- """The schedule rule for matmul-like computation"""
+def get_reduction_blocks(sch, blocks) -> bool:
+ # Get the main computation block
+ def is_reduction(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+ def is_spatial(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.DataPar}
+
+ # NOTE: We assume there is only one reduction block in the function
+ # all blocks are required to be spatial or reduction
+ if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+ return None
+
+ # There is only one reduction block
+ reduction_blocks = [block for block in blocks if is_reduction(block)]
+ if len(reduction_blocks) != 1:
+ return None
+
+ return reduction_blocks
+
+
+def check_sm_version(arch: str) -> int:
+ sm_version = arch.replace("sm_", "")
+ return int(sm_version) if sm_version.isdigit() else -1
+
+
+class MatmulTensorization(ScheduleRule):
+ """
+ The schedule rule for float16 tensor core matmul computation.
+ func with attr 'dlight.do_not_tensorize' will not be tensorized.
+ """
def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
+ from tvm.tir.tensor_intrin.cuda import ( # pylint:
disable=import-outside-toplevel
+ get_wmma_intrin_group,
+ )
+
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
- # Get the main computation block
- def is_reduction(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+ if func.attrs is not None and "dlight.do_not_tensorize" in
func.attrs.keys():
+ return None
- def is_spatial(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.DataPar}
+ reduction_blocks = get_reduction_blocks(sch, blocks)
+ if reduction_blocks is None:
+ return None
- # NOTE: We assume there is only one reduction block in the function
- # all blocks are required to be spatial or reduction
- if not all([is_reduction(block) or is_spatial(block) for block in
blocks]):
+ main_block = reduction_blocks[0]
+ block_stmt = sch.get(main_block)
+ index_maps = get_index_map(block_stmt)
+ if index_maps is None:
return None
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ # Start Schedule
+ # Step 0. Get schedule config.
+ # NOTE: we can analyze the config by the hardware spec in the future
+
+ # tensor core intrinsic size
+ micro_size_x = 16
+ micro_size_y = 16
+ micro_size_k = 16
+
+ i_factors, j_factors, k_factors = (
+ [None, 1, 2, 2],
+ [1, None, 2, 2],
+ [None, 2],
+ )
- # There is only one reduction block
- reduction_blocks = [block for block in blocks if is_reduction(block)]
- if len(reduction_blocks) != 1:
+ num_ty = i_factors[2] * j_factors[2]
+ x_pad_factor = i_factors[2] * i_factors[3]
+ y_pad_factor = j_factors[2] * j_factors[3]
+ k_pad_factor = k_factors[1]
+
+ # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
+ block = sch.reindex(main_block, ("read", 0))
+ sch.transform_layout(block, ("write", 0), a_index_map)
+ block = sch.reindex(main_block, ("read", 1))
+ sch.transform_layout(block, ("write", 0), b_index_map)
+ block = sch.reindex(main_block, ("write", 0))
+ sch.transform_layout(block, ("read", 0), c_index_map)
+ sch.transform_block_layout(main_block, matmul_index_map)
+
+ # Step 2. Padding for dynamic shape kernels
+ sch.pad_einsum(
+ main_block,
+ [
+ 1,
+ micro_size_x * x_pad_factor,
+ micro_size_y * y_pad_factor,
+ micro_size_k * k_pad_factor,
+ ],
+ )
+
+ # Step 3. Schedule matmul to use tensor core
+ block = main_block
+
+ batch, i, j, k = sch.get_loops(block)
+
+ # inner loops for tensor core computation
+ i, i_inner = sch.split(i, factors=[None, micro_size_x])
+ j, j_inner = sch.split(j, factors=[None, micro_size_y])
+ k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+ sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+ block_inner = block
+ block_outer = sch.blockize(i_inner)
+
+ i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+ k0, k1 = sch.split(k, k_factors)
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(batch, "blockIdx.z")
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, "shared.dyn")
+ sch.compute_at(block_read, k0)
+ vector_size = 4
+ warp_size = 32
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+ _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty,
warp_size, vector_size])
+
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+
+ sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
+ return block_read
+
+ a_g2s = fetch_to_shared(block_outer, 0, 2)
+ b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+ auto_inline_producers(sch, a_g2s)
+ auto_inline_producers(sch, b_g2s)
+
+ # create read cache to load matrix from shared memory to wmma fragments
+ A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+ B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+ sch.compute_at(A_mat, k1)
+ sch.compute_at(B_mat, k1)
+
+ # create write cache to store matrix from wmma fragments to shared
memory and global memory
+ accumulator_shared_to_global = sch.cache_write(block_outer, 0,
"shared")
+ sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+ store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+ sch.reverse_compute_at(store, thread_idy)
+ sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+ # split the store loop to match hardware intrinsic pattern
+ i, j = sch.get_loops(store)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+
+ block_init_c = sch.decompose_reduction(block_outer, k0)
+ block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+ # Tensorization by hardware intrinsics
+ intrin_group = get_wmma_intrin_group(
+ load_scope="shared.dyn",
+ store_scope="shared",
Review Comment:
Thanks for your review! I have fixed this issue by using `shared.dyn` for
the store buffer.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]