adfwer233 commented on code in PR #15389:
URL: https://github.com/apache/tvm/pull/15389#discussion_r1272931997
##########
python/tvm/dlight/gpu/matmul.py:
##########
@@ -248,39 +248,256 @@ def get_index_map(block: tir.Block) ->
Optional[Tuple[tir.IndexMap, ...]]:
C_index_map,
)
+def get_reduction_blocks(sch, blocks) -> bool:
+ # Get the main computation block
+ def is_reduction(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+ def is_spatial(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.DataPar}
+
+ # NOTE: We assume there is only one reduction block in the function
+ # all blocks are required to be spatial or reduction
+ if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+ return None
-class Matmul(ScheduleRule):
- """The schedule rule for matmul-like computation"""
+ # There is only one reduction block
+ reduction_blocks = [block for block in blocks if is_reduction(block)]
+ if len(reduction_blocks) != 1:
+ return None
+
+ return reduction_blocks
+
+def check_sm_version(arch: str) -> int:
+ sm_version = arch.replace("sm_", "")
+ return int(sm_version) if sm_version.isdigit() else -1
+
+class MatmulTensorization(ScheduleRule):
+ """
+ The schedule rule for float16 tensor core matmul computation.
+ func with attr 'dlight.do_not_tensorize' will not be tensorized.
+ """
def apply( # pylint: disable=too-many-locals,missing-docstring
self,
func: tir.PrimFunc,
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
+ from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group #
pylint: disable=import-outside-toplevel
+
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
- # Get the main computation block
- def is_reduction(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+ if func.attrs is not None and "dlight.do_not_tensorize" in
func.attrs.keys():
+ return None
- def is_spatial(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.DataPar}
+ reduction_blocks = get_reduction_blocks(sch, blocks)
+ if reduction_blocks is None:
+ return None
- # NOTE: We assume there is only one reduction block in the function
- # all blocks are required to be spatial or reduction
- if not all([is_reduction(block) or is_spatial(block) for block in
blocks]):
+ main_block = reduction_blocks[0]
+ block_stmt = sch.get(main_block)
+ index_maps = get_index_map(block_stmt)
+ if index_maps is None:
return None
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ # Start Schedule
+ # Step 0. Get schedule config.
+ # NOTE: we can analyze the config by the hardware spec in the future
+
+ # tensor core intrinsic size
+ micro_size_x = 16
+ micro_size_y = 16
+ micro_size_k = 16
+
+ i_factors, j_factors, k_factors = (
+ [None, 1, 2, 2],
+ [1, None, 2, 2],
+ [None, 2],
+ )
+
+ num_ty = i_factors[2] * j_factors[2]
+ x_pad_factor = i_factors[2] * i_factors[3]
+ y_pad_factor = j_factors[2] * j_factors[3]
+ k_pad_factor = k_factors[1]
+
+ # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
+ block = sch.reindex(main_block, ("read", 0))
+ sch.transform_layout(block, ("write", 0), a_index_map)
+ block = sch.reindex(main_block, ("read", 1))
+ sch.transform_layout(block, ("write", 0), b_index_map)
+ block = sch.reindex(main_block, ("write", 0))
+ sch.transform_layout(block, ("read", 0), c_index_map)
+ sch.transform_block_layout(main_block, matmul_index_map)
+
+ # Step 2. Padding for dynamic shape kernels
+ sch.pad_einsum(
+ main_block,
+ [
+ 1,
+ micro_size_x * x_pad_factor,
+ micro_size_y * y_pad_factor,
+ micro_size_k * k_pad_factor,
+ ],
+ )
- # There is only one reduction block
- reduction_blocks = [block for block in blocks if is_reduction(block)]
- if len(reduction_blocks) != 1:
+ # Step 3. Schedule matmul to use tensor core
+ block = main_block
+
+ batch, i, j, k = sch.get_loops(block)
+
+ # inner loops for tensor core computation
+ i, i_inner = sch.split(i, factors=[None, micro_size_x])
+ j, j_inner = sch.split(j, factors=[None, micro_size_y])
+ k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+ sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+ block_inner = block
+ block_outer = sch.blockize(i_inner)
+
+ i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+ k0, k1 = sch.split(k, k_factors)
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(batch, "blockIdx.z")
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, "shared.dyn")
+ sch.compute_at(block_read, k0)
+ vector_size = 8
+ warp_size = 32
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+ _, f_1, f_2, f_3 = sch.split(
+ fused, factors=[None, num_ty, warp_size, vector_size]
+ )
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+
+ sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
+ return block_read
+
+ a_g2s = fetch_to_shared(block_outer, 0, 2)
+ b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+ auto_inline_producers(sch, a_g2s)
+ auto_inline_producers(sch, b_g2s)
+
+ # create read cache to load matrix from shared memory to wmma fragments
+ A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+ B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+ sch.compute_at(A_mat, k1)
+ sch.compute_at(B_mat, k1)
+
+ # create write cache to store matrix from wmma fragments to shared
memory and global memory
+ accumulator_shared_to_global = sch.cache_write(block_outer, 0,
"shared")
+ sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+ store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+ sch.reverse_compute_at(store, thread_idy)
+ sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+ # split the store loop to match hardware intrinsic pattern
+ i, j = sch.get_loops(store)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+
+ block_init_c = sch.decompose_reduction(block_outer, k0)
+ block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+ # Tensorization by hardware intrinsics
+ intrin_group = get_wmma_intrin_group(
+ load_scope="shared.dyn",
+ store_scope="shared",
+ in_dtype="float16",
+ out_dtype="float32",
+ trans_b=True
Review Comment:
Thanks for the quick review!
We can also support NN matmul; I'm going to add this feature and test it.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]