This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new e8995272d1 [Unity][Dlight] Tensorization Rule in GPU Matmul (#15389)
e8995272d1 is described below
commit e8995272d112b3bc9e58a55fce4320c5539b9958
Author: Bao Anchang <[email protected]>
AuthorDate: Thu Jul 27 14:34:59 2023 +0800
[Unity][Dlight] Tensorization Rule in GPU Matmul (#15389)
* feat: dlight matmul tensorize
* lint
* lint and add comment
* fix: use vector size 4
* fix git-black lint
* pylint
* lint
* fix: use shared.dyn in write cache
---
python/tvm/dlight/gpu/matmul.py | 271 +++++++++++++++++++++--
tests/python/dlight/test_gpu_matmul.py | 2 +-
tests/python/dlight/test_gpu_matmul_tensorize.py | 261 ++++++++++++++++++++++
3 files changed, 516 insertions(+), 18 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index def13a60ac..5ea3aa4dee 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -249,8 +249,41 @@ def get_index_map(block: tir.Block) ->
Optional[Tuple[tir.IndexMap, ...]]:
)
-class Matmul(ScheduleRule):
- """The schedule rule for matmul-like computation"""
+def get_reduction_blocks(sch, blocks) -> bool:
+ # Get the main computation block
+ def is_reduction(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+
+ def is_spatial(block: BlockRV) -> bool:
+ block_stmt = sch.get(block)
+ iter_types = {iter_var.iter_type for iter_var in block_stmt.iter_vars}
+ return iter_types == {IterVar.DataPar}
+
+ # NOTE: We assume there is only one reduction block in the function
+ # all blocks are required to be spatial or reduction
+ if not all([is_reduction(block) or is_spatial(block) for block in blocks]):
+ return None
+
+ # There is only one reduction block
+ reduction_blocks = [block for block in blocks if is_reduction(block)]
+ if len(reduction_blocks) != 1:
+ return None
+
+ return reduction_blocks
+
+
+def check_sm_version(arch: str) -> int:
+ sm_version = arch.replace("sm_", "")
+ return int(sm_version) if sm_version.isdigit() else -1
+
+
+class MatmulTensorization(ScheduleRule):
+ """
+ The schedule rule for float16 tensor core matmul computation.
+ func with attr 'dlight.do_not_tensorize' will not be tensorized.
+ """
def apply( # pylint: disable=too-many-locals,missing-docstring
self,
@@ -258,29 +291,216 @@ class Matmul(ScheduleRule):
target: Target,
_: bool,
) -> Optional[tir.Schedule]:
+ from tvm.tir.tensor_intrin.cuda import ( # pylint:
disable=import-outside-toplevel
+ get_wmma_intrin_group,
+ )
+
sch = tir.Schedule(func)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
- # Get the main computation block
- def is_reduction(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.CommReduce, IterVar.DataPar}
+ if func.attrs is not None and "dlight.do_not_tensorize" in
func.attrs.keys():
+ return None
- def is_spatial(block: BlockRV) -> bool:
- block_stmt = sch.get(block)
- iter_types = {iter_var.iter_type for iter_var in
block_stmt.iter_vars}
- return iter_types == {IterVar.DataPar}
+ reduction_blocks = get_reduction_blocks(sch, blocks)
+ if reduction_blocks is None:
+ return None
- # NOTE: We assume there is only one reduction block in the function
- # all blocks are required to be spatial or reduction
- if not all([is_reduction(block) or is_spatial(block) for block in
blocks]):
+ main_block = reduction_blocks[0]
+ block_stmt = sch.get(main_block)
+ index_maps = get_index_map(block_stmt)
+ if index_maps is None:
return None
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ # Start Schedule
+ # Step 0. Get schedule config.
+ # NOTE: we can analyze the config by the hardware spec in the future
+
+ # tensor core intrinsic size
+ micro_size_x = 16
+ micro_size_y = 16
+ micro_size_k = 16
+
+ i_factors, j_factors, k_factors = (
+ [None, 1, 2, 2],
+ [1, None, 2, 2],
+ [None, 2],
+ )
- # There is only one reduction block
- reduction_blocks = [block for block in blocks if is_reduction(block)]
- if len(reduction_blocks) != 1:
+ num_ty = i_factors[2] * j_factors[2]
+ x_pad_factor = i_factors[2] * i_factors[3]
+ y_pad_factor = j_factors[2] * j_factors[3]
+ k_pad_factor = k_factors[1]
+
+ # Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
+ block = sch.reindex(main_block, ("read", 0))
+ sch.transform_layout(block, ("write", 0), a_index_map)
+ block = sch.reindex(main_block, ("read", 1))
+ sch.transform_layout(block, ("write", 0), b_index_map)
+ block = sch.reindex(main_block, ("write", 0))
+ sch.transform_layout(block, ("read", 0), c_index_map)
+ sch.transform_block_layout(main_block, matmul_index_map)
+
+ # Step 2. Padding for dynamic shape kernels
+ sch.pad_einsum(
+ main_block,
+ [
+ 1,
+ micro_size_x * x_pad_factor,
+ micro_size_y * y_pad_factor,
+ micro_size_k * k_pad_factor,
+ ],
+ )
+
+ # Step 3. Schedule matmul to use tensor core
+ block = main_block
+
+ batch, i, j, k = sch.get_loops(block)
+
+ # inner loops for tensor core computation
+ i, i_inner = sch.split(i, factors=[None, micro_size_x])
+ j, j_inner = sch.split(j, factors=[None, micro_size_y])
+ k, k_inner = sch.split(k, factors=[None, micro_size_k])
+
+ sch.reorder(i, j, k, i_inner, j_inner, k_inner)
+
+ block_inner = block
+ block_outer = sch.blockize(i_inner)
+
+ i0, i1, i2, i3 = sch.split(i, factors=i_factors)
+ j0, j1, j2, j3 = sch.split(j, factors=j_factors)
+ k0, k1 = sch.split(k, k_factors)
+
+ sch.reorder(i0, j0, i1, j1, j2, i2, k0, k1, i3, j3)
+
+ block_idx = sch.fuse(i0, j0)
+ block_idy = sch.fuse(i1, j1)
+ thread_idy = sch.fuse(j2, i2)
+ sch.bind(batch, "blockIdx.z")
+ sch.bind(block_idx, "blockIdx.x")
+ sch.bind(block_idy, "blockIdx.y")
+ sch.bind(thread_idy, "threadIdx.y")
+
+ def fetch_to_shared(block, idx, ndim):
+ block_read = sch.cache_read(block, idx, "shared.dyn")
+ sch.compute_at(block_read, k0)
+ vector_size = 4
+ warp_size = 32
+ fused = sch.fuse(*sch.get_loops(block_read)[-ndim:])
+
+ _, f_1, f_2, f_3 = sch.split(fused, factors=[None, num_ty,
warp_size, vector_size])
+
+ sch.bind(f_2, "threadIdx.x")
+ sch.bind(f_1, "threadIdx.y")
+ sch.vectorize(f_3)
+
+ sch.storage_align(block_read, 0, axis=-2, factor=16, offset=8)
+ return block_read
+
+ a_g2s = fetch_to_shared(block_outer, 0, 2)
+ b_g2s = fetch_to_shared(block_outer, 1, 2)
+
+ auto_inline_producers(sch, a_g2s)
+ auto_inline_producers(sch, b_g2s)
+
+ # create read cache to load matrix from shared memory to wmma fragments
+ A_mat = sch.cache_read(block_outer, 0, "wmma.matrix_a")
+ B_mat = sch.cache_read(block_outer, 1, "wmma.matrix_b")
+ sch.compute_at(A_mat, k1)
+ sch.compute_at(B_mat, k1)
+
+ # create write cache to store matrix from wmma fragments to shared
memory and global memory
+ accumulator_shared_to_global = sch.cache_write(block_outer, 0,
"shared.dyn")
+ sch.storage_align(accumulator_shared_to_global, 0, -2, 16, 4)
+
+ store = sch.cache_write(block_outer, 0, "wmma.accumulator")
+ sch.reverse_compute_at(store, thread_idy)
+ sch.reverse_compute_at(accumulator_shared_to_global, thread_idy)
+
+ # split the store loop to match hardware intrinsic pattern
+ i, j = sch.get_loops(store)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+
+ block_init_c = sch.decompose_reduction(block_outer, k0)
+ block_init_c_inner = sch.get_child_blocks(block_init_c)[0]
+
+ # Tensorization by hardware intrinsics
+ intrin_group = get_wmma_intrin_group(
+ load_scope="shared.dyn",
+ store_scope="shared.dyn",
+ in_dtype="float16",
+ out_dtype="float32",
+ trans_b=True,
+ )
+
+ try:
+ i, j = sch.get_loops(A_mat)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+ sch.unroll(i0)
+ sch.unroll(j0)
+ sch.tensorize(i1, intrin_group["load_a"])
+
+ i, j = sch.get_loops(B_mat)[-2:]
+ i0, i1 = sch.split(i, factors=[None, 16])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+ sch.unroll(i0)
+ sch.unroll(j0)
+ sch.tensorize(i1, intrin_group["load_b"])
+ except: # pylint: disable=bare-except
+ return None
+
+ # Try to tensorize the init, store and compute block with f16 or f32
intrinsics
+ tensorize_success: bool = False
+
+ def tensorize_init_store_compute():
+ sch.tensorize(sch.get_loops(block_init_c_inner)[-2],
intrin_group["init"])
+ sch.tensorize(sch.get_loops(store)[-2], intrin_group["store"])
+ sch.tensorize(sch.get_loops(block_inner)[-3],
intrin_group["compute"])
+
+ try:
+ tensorize_init_store_compute()
+ tensorize_success = True
+ except: # pylint: disable=bare-except
+ intrin_group = get_wmma_intrin_group(
+ load_scope="shared.dyn",
+ store_scope="shared.dyn",
+ in_dtype="float16",
+ out_dtype="float16",
+ trans_b=True,
+ )
+
+ if not tensorize_success:
+ try:
+ tensorize_init_store_compute()
+ tensorize_success = True
+ except: # pylint: disable=bare-except
+ return None
+
+ auto_inline_consumers(sch, accumulator_shared_to_global)
+ return sch if tensorize_success else None
+
+
+class Matmul(ScheduleRule):
+ """The schedule rule for matmul-like computation"""
+
+ def apply( # pylint: disable=too-many-locals,missing-docstring
+ self,
+ func: tir.PrimFunc,
+ target: Target,
+ _: bool,
+ ) -> Optional[tir.Schedule]:
+ sch = tir.Schedule(func)
+ root_block = analysis.get_root_block(sch)
+ blocks = sch.get_child_blocks(root_block)
+
+ reduction_blocks = get_reduction_blocks(sch, blocks)
+ if reduction_blocks is None:
return None
main_block = reduction_blocks[0]
@@ -302,6 +522,11 @@ class Matmul(ScheduleRule):
micro_size_k = 16
vector_size = 2
+ # Tensorization config:
+ # If any value of I, J, K is fixed and less than or equal to this threshold,
+ # the tensorization rule will not be applied.
+ minimal_tensorize_threshold = 128
+
# Step 1. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
@@ -311,6 +536,18 @@ class Matmul(ScheduleRule):
sch.transform_layout(block, ("read", 0), c_index_map)
sch.transform_block_layout(main_block, matmul_index_map)
+ block_stmt = sch.get(main_block)
+ if target.kind.name == "cuda" and check_sm_version(target.arch) > 70:
+ apply_tensorization: bool = True
+ # the batch dimension is not taken into consideration.
+ for item_var in block_stmt.iter_vars[1:]:
+ extent = item_var.dom.extent
+ if isinstance(extent, tir.expr.IntImm):
+ if extent.value <= minimal_tensorize_threshold:
+ apply_tensorization = False
+ if apply_tensorization:
+ return MatmulTensorization().apply(func, target, _)
+
# Step 2. Padding for dynamic shape kernels
sch.pad_einsum(
main_block,
diff --git a/tests/python/dlight/test_gpu_matmul.py
b/tests/python/dlight/test_gpu_matmul.py
index b9ee95b76b..bc6419160b 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -28,7 +28,7 @@ class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
@pytest.fixture
def transform(self):
def transform(mod):
- with Target("nvidia/geforce-rtx-3090-ti"):
+ with Target("nvidia/geforce-gtx-1080-ti"):
return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
return transform
diff --git a/tests/python/dlight/test_gpu_matmul_tensorize.py
b/tests/python/dlight/test_gpu_matmul_tensorize.py
new file mode 100644
index 0000000000..349b4bf256
--- /dev/null
+++ b/tests/python/dlight/test_gpu_matmul_tensorize.py
@@ -0,0 +1,261 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-docstring
+import pytest
+
+import tvm.testing
+from tvm import dlight as dl
+from tvm.script import ir as I
+from tvm.script import tir as T
+from tvm.target import Target
+
+
+class BaseBeforeAfter(tvm.testing.CompareBeforeAfter):
+ @pytest.fixture
+ def transform(self):
+ def transform(mod):
+ with Target("nvidia/geforce-rtx-2080-ti"):
+ return dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)
+
+ return transform
+
+
+class TestMatmulTensorize(BaseBeforeAfter):
+ # fmt: off
+
+ @T.prim_func
+ def before(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256),
"float16"), compute: T.Buffer((256, 256), "float16")):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ for i, j, k in T.grid(256, 256, 256):
+ with T.block("compute"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(X[v_i, v_k], W[v_j, v_k])
+ T.writes(compute[v_i, v_j])
+ with T.init():
+ compute[v_i, v_j] = T.float16(0)
+ compute[v_i, v_j] = compute[v_i, v_j] + X[v_i, v_k] * W[v_j,
v_k]
+
+ @T.prim_func
+ def expected(X: T.Buffer((256, 256), "float16"), W: T.Buffer((256, 256),
"float16"), compute: T.Buffer((256, 256), "float16")):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ X_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16",
scope="shared.dyn")
+ W_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16",
scope="shared.dyn")
+ X_reindex_shared_dyn_wmma_matrix_a = T.alloc_buffer((1, 256, 256),
"float16", scope="wmma.matrix_a")
+ W_reindex_shared_dyn_wmma_matrix_b = T.alloc_buffer((1, 256, 256),
"float16", scope="wmma.matrix_b")
+ compute_reindex_shared_dyn = T.alloc_buffer((1, 256, 256), "float16",
scope="shared.dyn")
+ compute_reindex_shared_dyn_wmma_accumulator = T.alloc_buffer((1, 256,
256), "float16", scope="wmma.accumulator")
+ for ax0 in T.thread_binding(T.int64(1), thread="blockIdx.z"):
+ for ax1_0_0_ax2_0_0_fused in T.thread_binding(4,
thread="blockIdx.x"):
+ for ax1_0_1_ax2_0_1_fused in T.thread_binding(4,
thread="blockIdx.y"):
+ for ax2_0_2_ax1_0_2_fused in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax1_0_3_init, ax2_0_3_init in T.grid(2, 2):
+ with T.block("compute_o_init"):
+ v0_o = T.axis.spatial(T.int64(1), ax0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3_init)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3_init)
+ T.reads()
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("compute_init_o"):
+ v1_i_init_o = T.axis.spatial(1, 0)
+ v2_i_init_o = T.axis.spatial(1, 0)
+ T.reads()
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ C =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0",
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_fill_fragment(C.data, 16, 16, 16,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, T.float32(0))
+ for ax3_0_0 in range(8):
+ for ax0_ax1_fused_0 in range(4):
+ for ax0_ax1_fused_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("X_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256,
ax1_0_0_ax2_0_0_fused * 64 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 32)
+ v2 = T.axis.spatial(256,
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2
* 4 + ax0_ax1_fused_3) % 32)
+ T.reads(X[v1, v2])
+
T.writes(X_reindex_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+ X_reindex_shared_dyn[v0, v1,
v2] = X[v1, v2]
+ for ax0_ax1_fused_0 in range(4):
+ for ax0_ax1_fused_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax0_ax1_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
+ for ax0_ax1_fused_3 in T.vectorized(4):
+ with
T.block("W_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256,
ax1_0_1_ax2_0_1_fused * 64 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 +
ax0_ax1_fused_2 * 4 + ax0_ax1_fused_3) // 32)
+ v2 = T.axis.spatial(256,
ax3_0_0 * 32 + (ax0_ax1_fused_0 * 512 + ax0_ax1_fused_1 * 128 + ax0_ax1_fused_2
* 4 + ax0_ax1_fused_3) % 32)
+ T.reads(W[v1, v2])
+
T.writes(W_reindex_shared_dyn[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 16, 8]]})
+ W_reindex_shared_dyn[v0, v1,
v2] = W[v1, v2]
+ for ax3_0_1 in range(2):
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("X_reindex_shared.dyn_wmma.matrix_a_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16, ax3_0_0
* 2 + ax3_0_1 + ax1_0)
+ T.reads(X_reindex_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+
T.writes(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(X_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16,
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset,
A.strides[0] * 16, 1), A.strides[0], "row_major")
+ for ax0_0 in T.unroll(2):
+ for ax1_0 in T.unroll(1):
+ with
T.block("W_reindex_shared.dyn_wmma.matrix_b_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16, ax3_0_0
* 2 + ax3_0_1 + ax1_0)
+ T.reads(W_reindex_shared_dyn[v0_o,
v1_o * 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+
T.writes(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 + 16,
v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(W_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o *
16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"),
scope="shared.dyn", offset_factor=16)
+ C =
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ T.tvm_load_matrix_sync(C.data, 16,
16, 16, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16,
T.tvm_access_ptr(T.type_annotation("float16"), A.data, A.elem_offset,
A.strides[0] * 16, 1), A.strides[0], "col_major")
+ for ax1_0_3, ax2_0_3 in T.grid(2, 2):
+ with T.block("compute_o_update"):
+ v0_o = T.axis.spatial(T.int64(1), ax0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax1_0_3)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax2_0_3)
+ v3_o = T.axis.reduce(16, ax3_0_0 * 2 +
ax3_0_1)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o *
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ with T.block("compute_o"):
+ v1_i_o = T.axis.spatial(1, 0)
+ v2_i_o = T.axis.spatial(1, 0)
+ v3_i_o = T.axis.reduce(1, 0)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16], X_reindex_shared_dyn_wmma_matrix_a[0, v1_o *
16:v1_o * 16 + 16, v3_o * 16:v3_o * 16 + 16],
W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16, v3_o * 16:v3_o
* 16 + 16])
+
T.writes(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(X_reindex_shared_dyn_wmma_matrix_a[0, v1_o * 16:v1_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("A_s0", "A_s1"),
scope="wmma.matrix_a", offset_factor=16)
+ B =
T.match_buffer(W_reindex_shared_dyn_wmma_matrix_b[0, v2_o * 16:v2_o * 16 + 16,
v3_o * 16:v3_o * 16 + 16], (16, 16), "float16", strides=("B_s0", "B_s1"),
scope="wmma.matrix_b", offset_factor=16)
+ C =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[0, v1_o * 16:v1_o *
16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0",
"C_s1"), scope="wmma.accumulator", offset_factor=16)
+ T.tvm_mma_sync(C.data,
C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) + C.elem_offset %
C.strides[0] // 16, A.data, A.elem_offset // A.strides[0] // 16 * (A.strides[0]
// 16) + A.elem_offset % A.strides[0] // 16, B.data, B.elem_offset //
B.strides[0] // 16 * (B.strides[0] // 16) + B.elem_offset % B.strides[0] // 16,
C.data, C.elem_offset // C.strides[0] // 16 * (C.strides[0] // 16) +
C.elem_offset % C.strides[0] // 16)
+ for ax0_0, ax1_0 in T.grid(2, 2):
+ with
T.block("compute_reindex_shared.dyn_wmma.accumulator_o"):
+ v0_o = T.axis.spatial(1, 0)
+ v1_o = T.axis.spatial(16,
ax1_0_0_ax2_0_0_fused * 4 + ax2_0_2_ax1_0_2_fused % 2 * 2 + ax0_0)
+ v2_o = T.axis.spatial(16,
ax1_0_1_ax2_0_1_fused * 4 + ax2_0_2_ax1_0_2_fused // 2 * 2 + ax1_0)
+
T.reads(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o * 16 +
16, v2_o * 16:v2_o * 16 + 16])
+ T.writes(compute_reindex_shared_dyn[v0_o, v1_o
* 16:v1_o * 16 + 16, v2_o * 16:v2_o * 16 + 16])
+ A =
T.match_buffer(compute_reindex_shared_dyn_wmma_accumulator[v0_o, v1_o * 16:v1_o
* 16 + 16, v2_o * 16:v2_o * 16 + 16], (16, 16), "float16", strides=("A_s0",
"A_s1"), scope="wmma.accumulator", offset_factor=16)
+ C =
T.match_buffer(compute_reindex_shared_dyn[v0_o, v1_o * 16:v1_o * 16 + 16, v2_o
* 16:v2_o * 16 + 16], (16, 16), "float16", strides=("C_s0", "C_s1"),
scope="shared.dyn", offset_factor=16)
+ T.tvm_store_matrix_sync(A.data, 16, 16, 16,
A.elem_offset // A.strides[0] // 16 * (A.strides[0] // 16) + A.elem_offset %
A.strides[0] // 16, T.tvm_access_ptr(T.type_annotation("float16"), C.data,
C.elem_offset, C.strides[0] * 16, 2), C.strides[0], "row_major")
+ for ax0_1, ax1 in T.grid(32, 32):
+ with T.block("compute_reindex_shared.dyn"):
+ v0 = T.axis.spatial(1, 0)
+ v1 = T.axis.spatial(256, ax1_0_0_ax2_0_0_fused
* 64 + ax2_0_2_ax1_0_2_fused % 2 * 32 + ax0_1)
+ v2 = T.axis.spatial(256, ax1_0_1_ax2_0_1_fused
* 64 + ax2_0_2_ax1_0_2_fused // 2 * 32 + ax1)
+ T.reads(compute_reindex_shared_dyn[v0, v1, v2])
+ T.writes(compute[v1, v2])
+ T.block_attr({"buffer_dim_align": [[0, 1, 16,
4]]})
+ compute[v1, v2] =
compute_reindex_shared_dyn[v0, v1, v2]
+
+ # fmt: on
+
+
+class TestMatmulTensorizeTooSmall(BaseBeforeAfter):
+ # fmt: off
+
+ @T.prim_func
+ def before(var_X: T.handle, W: T.Buffer((15, 256), "float16"),
var_compute: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.noalias": T.bool(True)})
+ m = T.int32()
+ X = T.match_buffer(var_X, (m, 256), "float16")
+ compute = T.match_buffer(var_compute, (m, 15))
+ # with T.block("root"):
+ for i, j, k in T.grid(m, 15, 256):
+ with T.block("compute"):
+ v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
+ T.reads(X[v_i, v_k], W[v_j, v_k])
+ T.writes(compute[v_i, v_j])
+ with T.init():
+ compute[v_i, v_j] = T.float32(0)
+ compute[v_i, v_j] = compute[v_i, v_j] + T.Cast("float32",
X[v_i, v_k]) * T.Cast("float32", W[v_j, v_k])
+
+ @T.prim_func
+ def expected(var_X: T.handle, W: T.Buffer((15, 256), "float16"),
var_compute: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ m = T.int32()
+ X = T.match_buffer(var_X, (m, 256), "float16")
+ compute = T.match_buffer(var_compute, (m, 15))
+ # with T.block("root"):
+ compute_reindex_pad_local = T.alloc_buffer((1, (T.Cast("int32",
T.Cast("int64", m)) + 31) // 32 * 32, 64), scope="local")
+ X_reindex_pad_shared = T.alloc_buffer((1, (T.Cast("int32",
T.Cast("int64", m)) + 31) // 32 * 32, 256), "float16", scope="shared")
+ W_reindex_pad_shared = T.alloc_buffer((1, 64, 256), "float16",
scope="shared")
+ for ax0_ax2_0_fused in T.thread_binding(T.int64(1),
thread="blockIdx.y"):
+ for ax1_0 in T.thread_binding((T.Cast("int32", T.Cast("int64", m))
+ 31) // 32, thread="blockIdx.x"):
+ for ax2_1 in T.thread_binding(1, thread="vthread.y"):
+ for ax1_1 in T.thread_binding(1, thread="vthread.x"):
+ for ax2_2 in T.thread_binding(16,
thread="threadIdx.y"):
+ for ax1_2 in T.thread_binding(8,
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 256,
"pragma_unroll_explicit": 1}):
+ for ax2_3_init, ax1_3_init in T.grid(4, 4):
+ with T.block("compute_init"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 = T.axis.spatial((T.Cast("int32",
T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 * 32 + ax1_1 * 32 + ax1_2 * 4 +
ax1_3_init)
+ v2 = T.axis.spatial(64, ax2_1 * 64 +
ax2_2 * 4 + ax2_3_init)
+ T.reads()
+ T.writes(compute_reindex_pad_local[0,
v1, v2])
+ compute_reindex_pad_local[0, v1, v2] =
T.float32(0)
+ for ax3_0 in range(16):
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(8, thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(2):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(2):
+ with
T.block("X_reindex_pad_shared"):
+ v0 = T.axis.spatial(1,
0)
+ v1 =
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 *
32 + (ax0_ax1_ax2_fused_0 * 32 + ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2
* 2 + ax0_ax1_ax2_fused_3) // 16)
+ v2 =
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 32 +
ax0_ax1_ax2_fused_1 * 4 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+ T.reads(X[v1, v2])
+
T.writes(X_reindex_pad_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
X_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < m, X[v1, v2],
T.float16(0))
+ for ax0_ax1_ax2_fused_0 in
T.thread_binding(16, thread="threadIdx.y"):
+ for ax0_ax1_ax2_fused_1 in
T.thread_binding(8, thread="threadIdx.x"):
+ for ax0_ax1_ax2_fused_2 in
range(4):
+ for ax0_ax1_ax2_fused_3 in
T.vectorized(2):
+ with
T.block("W_reindex_pad_shared"):
+ v0 = T.axis.spatial(1,
0)
+ v1 =
T.axis.spatial(64, (ax0_ax1_ax2_fused_0 * 64 + ax0_ax1_ax2_fused_1 * 8 +
ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) // 16)
+ v2 =
T.axis.spatial(256, ax3_0 * 16 + (ax0_ax1_ax2_fused_0 * 64 +
ax0_ax1_ax2_fused_1 * 8 + ax0_ax1_ax2_fused_2 * 2 + ax0_ax1_ax2_fused_3) % 16)
+ T.reads(W[v1, v2])
+
T.writes(W_reindex_pad_shared[v0, v1, v2])
+
T.block_attr({"buffer_dim_align": [[0, 1, 8, 2]]})
+
W_reindex_pad_shared[v0, v1, v2] = T.if_then_else(v1 < 15, W[v1, v2],
T.float16(0))
+ for ax3_1, ax2_3, ax1_3 in T.grid(16, 4,
4):
+ with T.block("compute_update"):
+ v0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v1 =
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 *
32 + ax1_1 * 32 + ax1_2 * 4 + ax1_3)
+ v2 = T.axis.spatial(64, ax2_1 * 64
+ ax2_2 * 4 + ax2_3)
+ v3 = T.axis.reduce(256, ax3_0 * 16
+ ax3_1)
+
T.reads(compute_reindex_pad_local[0, v1, v2], X_reindex_pad_shared[0, v1, v3],
W_reindex_pad_shared[0, v2, v3])
+
T.writes(compute_reindex_pad_local[0, v1, v2])
+ compute_reindex_pad_local[0, v1,
v2] = compute_reindex_pad_local[0, v1, v2] + T.Cast("float32",
X_reindex_pad_shared[0, v1, v3]) * T.Cast("float32", W_reindex_pad_shared[0,
v2, v3])
+ for ax0, ax1, ax2_0 in T.grid(1, 4, 2):
+ for ax2_1_1 in T.vectorized(2):
+ with
T.block("compute_reindex_pad_local"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 =
T.axis.spatial((T.Cast("int32", T.Cast("int64", m)) + 31) // 32 * 32, ax1_0 *
32 + ax1_2 * 4 + ax1)
+ v2 = T.axis.spatial(64, ax2_2 * 4
+ ax2_0 * 2 + ax2_1_1)
+
T.reads(compute_reindex_pad_local[v0, v1, v2])
+ T.writes(compute[v1, v2])
+ if v1 < m and v2 < 15:
+ compute[v1, v2] =
compute_reindex_pad_local[v0, v1, v2]
+ # fmt: on
+
+
+if __name__ == "__main__":
+ tvm.testing.main()