This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8e8799d709 [Unity][Dlight] Enhance matmul tensorizer with Int8 support
(#16084)
8e8799d709 is described below
commit 8e8799d709046676c267828786fa6c5fa3601fb1
Author: Ivan Sidorenko <[email protected]>
AuthorDate: Wed Nov 8 00:23:37 2023 +0300
[Unity][Dlight] Enhance matmul tensorizer with Int8 support (#16084)
Main goal of this commit was not to provide highly optimized schedule,
but to enable tensor core tensorization for i8i8i32 matmul. Now it
supports fp16 matmul only.
Nevertheless, a speedup is observed compared to fp16 (fp16fp16->fp16 vs i8i8->i32):
N | M, N, K* | Shape type | fp16, us | int8, us | speedup |
--|----------------------|----------------|----------|----------|---------|
1 | 256, 256, 256 | static shapes | 15.03 | 12.55 | 1.198x |
2 | 512, 512, 4096 | static shapes | 173.25 | 129.06 | 1.342x |
3 | N (=260), 4096, 4096 | dynamic shapes | 350.44 | 298.29 | 1.175x |
--------------------------------------------------------------------------|
where K* - is a reduction axis.
Implementation of Matmul int8 tensorizer was derived from Matmul
tensorizer with minimal changes. It is possible to join these
implementations (fp16 and int8), but I am not sure this is a good idea,
since it could slow down development of the fp16 part.
---
python/tvm/dlight/gpu/matmul.py | 220 +++++++++++++++++-
tests/python/dlight/test_gpu_matmul_tensorize.py | 273 +++++++++++++++++++++++
2 files changed, 492 insertions(+), 1 deletion(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 3549b4bc44..703f9c151f 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -296,6 +296,16 @@ def get_reduction_blocks(sch, blocks) -> bool:
return reduction_blocks
+def get_in_out_dtypes(block: tir.Block) -> Tuple[str]:
+ """
+ Detect In/Out data types for the given block based on the analysis of
read/write buffers.
+ """
+ assert len(block.reads) > 0 and len(block.writes) > 0
+ in_dtype = block.reads[0].buffer.dtype
+ out_dtype = block.writes[0].buffer.dtype
+ return (in_dtype, out_dtype)
+
+
def check_sm_version(arch: str) -> int:
sm_version = arch.replace("sm_", "")
return int(sm_version) if sm_version.isdigit() else -1
@@ -520,6 +530,209 @@ class MatmulTensorization(ScheduleRule):
return sch if tensorize_success else None
+class MatmulInt8Tensorization(ScheduleRule):
+ """
+ The schedule rule for int8 tensor core matmul computation.
+ func with attr 'dlight.do_not_tensorize' will not be tensorized.
+ """
+
+ def apply( # pylint: disable=too-many-locals,missing-docstring
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> Optional[tir.Schedule]:
+ from tvm.tir.tensor_intrin.cuda import ( # pylint:
disable=import-outside-toplevel
+ get_wmma_intrin_group,
+ )
+
+ sch = tir.Schedule(func)
+ root_block = analysis.get_root_block(sch)
+ blocks = sch.get_child_blocks(root_block)
+
+ if func.attrs is not None and "dlight.do_not_tensorize" in
func.attrs.keys():
+ return None
+
+ reduction_blocks = get_reduction_blocks(sch, blocks)
+ if reduction_blocks is None:
+ return None
+
+ main_block = reduction_blocks[0]
+ block_stmt = sch.get(main_block)
+ index_maps = get_index_map(block_stmt)
+ if index_maps is None:
+ return None
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ # Start Schedule
+ # Step 0. Get schedule config.
+ # NOTE: we can analyze the config by the hardware spec in the future
+
+ # tensor core intrinsic size
+ micro_size_x = 16
+ micro_size_y = 16
+ micro_size_k = 16
+
+ warp_size = 32
+ vector_size = 4
+
+ i_factors, j_factors, k_factors = (
+ [None, 1, 4, 2],
+ [1, None, 4, 2],
+ [None, 1],
+ )
+
+ num_ty = i_factors[2] * j_factors[2]
+ x_pad_factor = i_factors[2] * i_factors[3]
+ y_pad_factor = j_factors[2] * j_factors[3]
+ k_pad_factor = k_factors[1]
+
+ # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
+ block = sch.reindex(main_block, ("read", 0))
+ sch.transform_layout(block, ("write", 0), a_index_map)
+ block = sch.reindex(main_block, ("read", 1))
+ sch.transform_layout(block, ("write", 0), b_index_map)
+ block = sch.reindex(main_block, ("write", 0))
+ sch.transform_layout(block, ("read", 0), c_index_map)
+ sch.transform_block_layout(main_block, matmul_index_map)
+
+ # Step 2. Padding for dynamic shape kernels
+ sch.pad_einsum(
+ main_block,
+ [
+ 1,
+ micro_size_x * x_pad_factor,
+ micro_size_y * y_pad_factor,
+ micro_size_k * k_pad_factor,
+ ],
+ )
+
+ # Step 3. Schedule matmul to use tensor core
+ block = main_block
+
+ batch, i, j, k = sch.get_loops(block)
+
+ # inner loops for tensor core computation
+ i, i_inner = sch.split(i, factors=[None, micro_size_x])
+ j, j_inner = sch.split(j, factors=[None, micro_size_y])
+ k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+ sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+ block_inner = block
+ block_outer = sch.blockize(i_inner)
+
+ i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+ k0, k1 = sch.split(k, k_factors)
+ sch.annotate(k0, "software_pipeline_order", [0, 3, 1, 4, 5, 2, 6])
+ sch.annotate(k0, "software_pipeline_stage", [0, 0, 0, 0, 0, 1, 1])
+ sch.annotate(k1, "software_pipeline_order", [0, 1, 2])
+ sch.annotate(k1, "software_pipeline_stage", [0, 0, 1])
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(batch, "blockIdx.z")
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, "shared.dyn")
+ sch.compute_at(block_read, k0)
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+ _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty,
warp_size, vector_size])
+
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+
+ sch.storage_align(block_read, 0, axis=-2, factor=32, offset=16)
+ sch.annotate(block_read, "tir.manifest_shared_memory_local_stage",
1)
+ sch.annotate(block_read, "double_buffer_scope", 0)
+ return block_read
+
+ a_g2s = fetch_to_shared(block_outer, 0, 2)
+ b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+ auto_inline_producers(sch, a_g2s)
+ auto_inline_producers(sch, b_g2s)
+
+ # create read cache to load matrix from shared memory to wmma fragments
+ A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+ B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+ sch.compute_at(A_mat, k1)
+ sch.compute_at(B_mat, k1)
+
+ # create write cache to store matrix from wmma fragments to shared
memory and global memory
+ accumulator_shared_to_global = sch.cache_write(block_outer, 0,
"shared.dyn")
+ sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+ store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+ sch.reverse_compute_at(store, thread_idy)
+ sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+ # split the store loop to match hardware intrinsic pattern
+ i, j = sch.get_loops(store)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+
+ block_init_c = sch.decompose_reduction(block_outer, k0)
+ block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+ # Tensorization by hardware intrinsics
+ intrin_group = get_wmma_intrin_group(
+ load_scope="shared.dyn",
+ store_scope="shared.dyn",
+ in_dtype="int8",
+ out_dtype="int32",
+ trans_b=True,
+ )
+
+ try:
+ i, j = sch.get_loops(A_mat)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+ sch.unroll(i0)
+ sch.unroll(j0)
+ sch.tensorize(i1, intrin_group["load_a"])
+
+ i, j = sch.get_loops(B_mat)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+ sch.unroll(i0)
+ sch.unroll(j0)
+ sch.tensorize(i1, intrin_group["load_b"])
+ except: # pylint: disable=bare-except
+ return None
+
+ def tensorize_init_store_compute():
+ sch.tensorize(sch.get_loops(block_init_c_inner)[-2],
intrin_group["init"])
+ sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
+ sch.tensorize(sch.get_loops(block_inner)[-3],
intrin_group["compute"])
+
+ try:
+ tensorize_init_store_compute()
+ except: # pylint: disable=bare-except
+ return None
+
+ auto_inline_consumer_chain(sch, accumulator_shared_to_global)
+
+ fused = sch.fuse(*sch.get_loops(accumulator_shared_to_global)[-2:])
+ _, f1, f2 = sch.split(fused, factors=[None, warp_size, vector_size])
+ sch.bind(f1, "threadIdx.x")
+ sch.vectorize(f2)
+
+ return sch
+
+
class Matmul(ScheduleRule):
"""The schedule rule for matmul-like computation"""
@@ -618,7 +831,12 @@ class Matmul(ScheduleRule):
if extent.value <= minimal_tensorize_threshold:
apply_tensorization = False
if apply_tensorization:
- tensorize_sch = MatmulTensorization().apply(func, target, _)
+ # Analyze read/write buffers and choose correct tensorizer:
int8 or fp16.
+ in_dtype, out_dtype = get_in_out_dtypes(block_stmt)
+ if in_dtype == "int8" and out_dtype == "int32":
+ tensorize_sch = MatmulInt8Tensorization().apply(func,
target, _)
+ else:
+ tensorize_sch = MatmulTensorization().apply(func, target,
_)
if tensorize_sch is not None:
return tensorize_sch
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py
b/tests/python/dlight/test_gpu_matmul_tensorize.py
index c682c879e0..72ffb30719 100644
--- a/tests/python/dlight/test_gpu_matmul_tensorize.py
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -425,5 +425,278 @@ class TestMatmulTensorizeEpilogue(BaseBeforeAfter):
# fmt: on
+class TestMatmulInt8Tensorize(BaseBeforeAfter):
+ # fmt: off
+ @T.prim_func
+ def before(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256),
"int8"), compute: T.Buffer((256, 256), "int32")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i, j, r in T.grid(256, 256, 256):
+ with T.block("compute"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, r])
+ T.reads(X[v_i, v_k], W[v_j, v_k])
+ T.writes(compute[v_i, v_j])
+ with T.init():
+ compute[v_i, v_j] = 0
+ compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("int32", X[v_i,
v_k]) * T.Cast("int32", W[v_j, v_k])
+
+ @T.prim_func
+ def expected(X: T.Buffer((256, 256), "int8"), W: T.Buffer((256, 256),
"int8"), compute: T.Buffer((256, 256), "int32")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8",
scope="shared.dyn")
+ W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int8",
scope="shared.dyn")
+ X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256),
"int8", scope="wmma.matrix_a")
+ W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256),
"int8", scope="wmma.matrix_b")
+ compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "int32",
scope="shared.dyn")
+ compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256,
256), "int32", scope="wmma.accumulator")
+ for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+ for ax1_0_0_ax2_0_0_fused in T.thread_binding(2,
thread="blockIdx.x"):
+ for ax1_0_1_ax2_0_1_fused in T.thread_binding(2,
thread="blockIdx.y"):
+ for ax2_0_2_ax1_0_2_fused in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+ with T.block("compute_o_init"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+ T.reads()
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("compute_init_o"):
+ v1_i_init_o = T.axis.spatial(1, 0)
+ v2_i_init_o = T.axis.spatial(1, 0)
+ T.reads()
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ C =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0",
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_fill_fragment(C.data, 16, 16, 16,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, T.float32(0))
+ for ax3_0_0 in T.serial(16,
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6],
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("X_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256,
ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+ v2 = T.axis.spatial(256,
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+ T.reads(X[v1, v2])
+
T.writes(X_reindex_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0,
"tir.manifest_shared_memory_local_stage": 1})
+ X_reindex_shared_dyn[v0, v1,
v2] = X[v1, v2]
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("W_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+ v2 = T.axis.spatial(256,
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+ T.reads(W[v1, v2])
+
T.writes(W_reindex_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0,
"tir.manifest_shared_memory_local_stage": 1})
+ W_reindex_shared_dyn[v0, v1,
v2] = W[v1, v2]
+ for ax3_0_1 in T.serial(1,
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage":
[0, 0, 1]}):
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("X_reindex_shared.dyn_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16, ax3_0_0
+ ax1_0)
+ T.reads(X_reindex_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+
T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"),
A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "row_major")
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("W_reindex_shared.dyn_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16, ax3_0_0
+ ax1_0)
+ T.reads(W_reindex_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+
T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"),
A.data, A.elem_offset, A.strides[0] * 16, 1), A.strides[0], "col_major")
+ for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+ with T.block("compute_o_update"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+ v3_o = T.axis.reduce(16, ax3_0_0 +
ax3_0_1)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o *
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("compute_o"):
+ v1_i_o = T.axis.spatial(1, 0)
+ v2_i_o = T.axis.spatial(1, 0)
+ v3_i_o = T.axis.reduce(1, 0)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o *
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ B =
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ C =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0",
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_mma_sync(C.data,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0]
// 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset //
B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16,
C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16)
+ for ax0_0, ax1_0 in T.grid(2, 2):
+ with
T.block("compute_reindex_shared.dyn_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(compute_reindex_shared_dyn[v0_o, v1_o
* 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o
* 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("A_s0",
"A_s1"), scope="wmma.accumulator", offset_factor=16)
+ C =
T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o
* 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"),
scope="shared.dyn", offset_factor=16)
+ T.tvm_store_matrix_sync(A.data, 16, 16, 16,
A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset %
A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int32"), C.data,
C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+ for ax0_ax1_fused_0 in range(8):
+ for ax0_ax1_fused_1 in T.thread_binding(32,
thread="threadIdx.x"):
+ for ax0_ax1_fused_2 in T.vectorized(4):
+ with T.block("compute_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256,
ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 + (ax0_ax1_fused_0
* 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+ v2 = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+ T.reads(compute_reindex_shared_dyn[v0,
v1, v2])
+ T.writes(compute[v1, v2])
+ T.block_attr({"buffer_dim_align": [[0,
1, 16, 4]]})
+ compute[v1, v2] =
compute_reindex_shared_dyn[v0, v1, v2]
+ # fmt: on
+
+
+class TestMatmulInt8Tensorize3d2dDyn(BaseBeforeAfter):
+ # fmt: off
+ @T.prim_func
+ def before(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"),
var_matmul: T.handle):
+ T.func_attr({"op_pattern": 4, "tir.noalias": T.bool(True)})
+ m = T.int32()
+ A = T.match_buffer(var_A, (1, m, 22016), "int8")
+ matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32")
+ # with T.block("root"):
+ for i0, i1, i2, k in T.grid(1, m, 4096, 22016):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(A[v_i0, v_i1, v_k], B[v_i2, v_k])
+ T.writes(matmul_1[v_i0, v_i1, v_i2])
+ with T.init():
+ matmul_1[v_i0, v_i1, v_i2] = 0
+ matmul_1[v_i0, v_i1, v_i2] = matmul_1[v_i0, v_i1, v_i2] +
T.Cast("int32", A[v_i0, v_i1, v_k]) * T.Cast("int32", B[v_i2, v_k])
+
+ @T.prim_func
+ def expected(var_A: T.handle, B: T.Buffer((4096, 22016), "int8"),
var_matmul: T.handle):
+ T.func_attr({"op_pattern": 4, "tir.is_scheduled": 1, "tir.noalias":
T.bool(True)})
+ m = T.int32()
+ A = T.match_buffer(var_A, (1, m, 22016), "int8")
+ matmul_1 = T.match_buffer(var_matmul, (1, m, 4096), "int32")
+ # with T.block("root"):
+ A_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128 * 128,
22016), "int8", scope="shared.dyn")
+ B_reindex_shared_dyn = T.alloc_buffer((1, 4096, 22016), "int8",
scope="shared.dyn")
+ A_reindex_pad_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, (m + 127)
// 128 * 128, 22016), "int8", scope="wmma.matrix_a")
+ B_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 4096, 22016),
"int8", scope="wmma.matrix_b")
+ matmul_1_reindex_pad_shared_dyn = T.alloc_buffer((1, (m + 127) // 128
* 128, 4096), "int32", scope="shared.dyn")
+ matmul_1_reindex_pad_shared_dyn_wmma_accumulator = T.alloc_buffer((1,
(m + 127) // 128 * 128, 4096), "int32", scope="wmma.accumulator")
+ for ax0 in T.thread_binding(1, thread="blockIdx.z"):
+ for ax1_0_0_ax2_0_0_fused in T.thread_binding((m + 127) // 128,
thread="blockIdx.x"):
+ for ax1_0_1_ax2_0_1_fused in T.thread_binding(32,
thread="blockIdx.y"):
+ for ax2_0_2_ax1_0_2_fused in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+ with T.block("matmul_o_init"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial((m + 127) // 128 * 8,
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3_init)
+ v2_o = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3_init)
+ T.reads()
+
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("matmul_init_o"):
+ v1_i_init_o = T.axis.spatial(1, 0)
+ v2_i_init_o = T.axis.spatial(1, 0)
+ T.reads()
+
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16])
+ C =
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o *
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32",
strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_fill_fragment(C.data, 16, 16, 16,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, T.float32(0))
+ for ax3_0_0 in T.serial(1376,
annotations={"software_pipeline_order": [0, 3, 1, 4, 5, 2, 6],
"software_pipeline_stage": [0, 0, 0, 0, 0, 1, 1]}):
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("A_reindex_pad_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 127)
// 128 * 128, ax1_0_0_ax2_0_0_fused * 128 + (ax0_ax1_fused_0 * 2048 +
ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+ v2 = T.axis.spatial(22016,
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+ T.reads(A[v0, v1, v2])
+
T.writes(A_reindex_pad_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0,
"tir.manifest_shared_memory_local_stage": 1})
+ A_reindex_pad_shared_dyn[v0,
v1, v2] = T.if_then_else(v1 < m, A[v0, v1, v2], T.int8(0))
+ for ax0_ax1_fused_0 in range(1):
+ for ax0_ax1_fused_1 in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("B_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(4096,
ax1_0_1_ax2_0_1_fused * 128 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 16)
+ v2 = T.axis.spatial(22016,
ax3_0_0 * 16 + (ax0_ax1_fused_0 * 2048 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) % 16)
+ T.reads(B[v1, v2])
+
T.writes(B_reindex_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 32, 16]], "double_buffer_scope": 0,
"tir.manifest_shared_memory_local_stage": 1})
+ B_reindex_shared_dyn[v0, v1,
v2] = B[v1, v2]
+ for ax3_0_1 in T.serial(1,
annotations={"software_pipeline_order": [0, 1, 2], "software_pipeline_stage":
[0, 0, 1]}):
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("A_reindex_pad_shared.dyn_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(8 * ((m +
127) // 128), ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(1376,
ax3_0_0 + ax1_0)
+
T.reads(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o
* 16 + 16])
+
T.writes(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A_1 =
T.match_buffer(A_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16
+ 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"),
A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "row_major")
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("B_reindex_shared.dyn_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(1376,
ax3_0_0 + ax1_0)
+ T.reads(B_reindex_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+
T.writes(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A_1 =
T.match_buffer(B_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int8", strides=("C_s0", "C_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("int8"),
A_1.data, A_1.elem_offset, A_1.strides[0] * 16, 1), A_1.strides[0], "col_major")
+ for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+ with T.block("matmul_o_update"):
+ v0_o = T.axis.spatial(1, ax0)
+ v1_o = T.axis.spatial((m + 127) // 128
* 8, ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax1_0_3)
+ v2_o = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax2_0_3)
+ v3_o = T.axis.reduce(1376, ax3_0_0 +
ax3_0_1)
+
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16
+ 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o
* 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("matmul_o"):
+ v1_i_o = T.axis.spatial(1, 0)
+ v2_i_o = T.axis.spatial(1, 0)
+ v3_i_o = T.axis.reduce(1, 0)
+
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16
+ 16, v2_o * 16:v2_o * 16 + 16], A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o
* 16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A_1 =
T.match_buffer(A_reindex_pad_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 +
16, v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("A_s0", "A_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ B_1 =
T.match_buffer(B_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16], (16, 16), "int8", strides=("B_s0", "B_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ C =
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[0, v1_o *
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32",
strides=("C_s0", "C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_mma_sync(C.data,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, A_1.data, A_1.elem_offset // A_1.strides[0] // 16 *
(A_1.strides[0] // 16) + A_1.elem_offset % A_1.strides[0] // 16, B_1.data,
B_1.elem_offset // B_1.strides[0] // 16 * (B_1.strides[0] // 16) +
B_1.elem_offset % B_1.strides[0] // 16, C.data, C.elem_offset // C.strides[0]
// 16 * (C.strides[0] // 16) + C.elem_offset % C.strides [...]
+ for ax0_0, ax1_0 in T.grid(2, 2):
+ with
T.block("matmul_1_reindex_pad_shared.dyn_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(8 * ((m + 127) // 128),
ax1_0_0_ax2_0_0_fused * 8 + ax2_0_2_ax1_0_2_fused % 4 * 2 + ax0_0)
+ v2_o = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 8 + ax2_0_2_ax1_0_2_fused // 4 * 2 + ax1_0)
+
T.reads(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(matmul_1_reindex_pad_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A_1 =
T.match_buffer(matmul_1_reindex_pad_shared_dyn_wmma_accumulator[v0_o, v1_o *
16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "int32",
strides=("A_s0", "A_s1"), scope="wmma.accumulator", offset_factor=16)
+ C =
T.match_buffer(matmul_1_reindex_pad_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16], (16, 16), "int32", strides=("C_s0", "C_s1"),
scope="shared.dyn", offset_factor=16)
+ T.tvm_store_matrix_sync(A_1.data, 16, 16, 16,
A_1.elem_offset // A_1.strides[0] // 16 * (A_1.strides[0] // 16) +
A_1.elem_offset % A_1.strides[0] // 16,
T.tvm_access_ptr(T.type_annotation("int32"), C.data, C.elem_offset,
C.strides[0] * 16, 2), C.strides[0], "row_major")
+ for ax0_ax1_fused_0 in range(8):
+ for ax0_ax1_fused_1 in T.thread_binding(32,
thread="threadIdx.x"):
+ for ax0_ax1_fused_2 in T.vectorized(4):
+ with
T.block("matmul_1_reindex_pad_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial((m + 127) // 128 *
128, ax1_0_0_ax2_0_0_fused * 128 + ax2_0_2_ax1_0_2_fused % 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) // 32)
+ v2 = T.axis.spatial(4096,
ax1_0_1_ax2_0_1_fused * 128 + ax2_0_2_ax1_0_2_fused // 4 * 32 +
(ax0_ax1_fused_0 * 128 + ax0_ax1_fused_1 * 4 + ax0_ax1_fused_2) % 32)
+
T.reads(matmul_1_reindex_pad_shared_dyn[v0, v1, v2])
+ T.writes(matmul_1[0, v1, v2])
+ T.block_attr({"buffer_dim_align": [[0,
1, 16, 4]]})
+ if v1 < m:
+ matmul_1[0, v1, v2] =
matmul_1_reindex_pad_shared_dyn[v0, v1, v2]
+ # fmt: on
+
+
if __name__ == "__main__":
tvm.testing.main()