This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bd359685c7 [Unity][DLight] Update gemv rule (#15490)
bd359685c7 is described below
commit bd359685c7d2b6cb666dfa4eca1b6902b22d2a4f
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Aug 8 05:02:24 2023 -0700
[Unity][DLight] Update gemv rule (#15490)
---
python/tvm/dlight/gpu/gemv.py | 380 ++++++++++++++++++-------
python/tvm/dlight/gpu/utils.py | 2 +
tests/python/dlight/test_gpu_gemv.py | 536 ++++++++++++++++++++++++-----------
3 files changed, 653 insertions(+), 265 deletions(-)
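
Context for the diff below: the updated rule replaces the old fixed len_tx/len_ty split with per-target scheduling constants (thread tags TAG_S/TAG_R, thread counts TS/TR, tiles TILE_S/TILE_R, vector widths VEC_*), and derives the reduction tile from the largest candidate factor that evenly divides the reduction length. A minimal standalone sketch of that derivation (illustrative only, not part of the commit; the constants mirror the CUDA branch and the len_c == 1 case of the TILE_R computation):

    def get_max_factor(n, factors):
        # Largest candidate in `factors` that divides n, else 1.
        for factor in sorted(factors, reverse=True):
            if n % factor == 0:
                return factor
        return 1

    TR = 32       # threads bound to the reduction axis
    len_r = 4096  # reduction length of the GEMV
    TILE_R = max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1)
    VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), 4)
    print(TILE_R, VEC_C)  # -> 8 4 for this example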
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index 13dee1cd54..b063883800 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -16,6 +16,7 @@
# under the License.
"""A rule for GEMV and DecodeGEMV."""
import re
+from functools import reduce
from typing import List, Optional, Union
from tvm import DataType, arith, ir, tir
@@ -124,6 +125,8 @@ def normalize(
if c_loops:
return None
loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
+            # we expect the innermost dim to be grouped at the moment
+ assert not (is_reduction ^ is_inner_reduction)
c_loops.append(c_loop)
if is_reduction:
r_loops.append(loop)
@@ -169,6 +172,10 @@ class GEMV(ScheduleRule):
return None
block_info = block_infos[0]
+ if len(block_info.iters) not in [2, 3]:
+ # either [B, S, R] = [B, S, R] * [B, R]
+ # or [S, R] = [S, R] * [R]
+ return None
block = block_info.block_rv
vector_input_buffers = is_gemv(sch, block_info)
if vector_input_buffers is None:
@@ -179,14 +186,13 @@ class GEMV(ScheduleRule):
# Step 2. Do the scheduling
if is_inner_reduction:
- # print(func)
self.sch_inner_reduction(sch, target, block, vector_input_buffers, epilogue)
return sch
else:
# TODO: Need to handle GEMV with KN layout
return None
- def sch_inner_reduction( # pylint: disable=too-many-arguments
+    def sch_inner_reduction(  # pylint: disable=too-many-arguments, invalid-name, unused-argument
self,
sch: tir.Schedule,
target: Target,
@@ -195,106 +201,282 @@ class GEMV(ScheduleRule):
epilogue_info: Optional[BlockInfo],
):
"""Schedule the inner reduction block."""
- # pylint: disable=invalid-name
- _, s, r, _ = sch.get_loops(block)
- # TODO: make it tunable
- vec_bytes = 16 if target.kind.name == "cuda" else 8
- unroll_number = 256 if target.kind.name == "cuda" else 64
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
+
+ def apply(
+ sch: tir.Schedule,
+ gemv,
+ TAG_S,
+ TAG_R,
+ TS,
+ TR,
+ TILE_S,
+ TILE_R,
+ VEC_LOAD,
+ VEC_C,
+ LOAD_V_SHARED,
+ LOAD_V_VEC,
+ UNROLL,
+ ):
+ # rfactor: reduce to tx * vec_c
+ _, s, r, c = sch.get_loops(block=gemv)
+ s = sch.fuse(_, s)
+ r = sch.fuse(r, c)
+            bx, ts, tile_s = sch.split(s, factors=[None, TS, TILE_S], preserve_unit_iters=True)
+ r, tr, tile_r_vec_n, vec_c = sch.split(
+                r, factors=[None, TR, TILE_R // VEC_C, VEC_C], preserve_unit_iters=True
+ )
+ sch.reorder(r, tile_r_vec_n, tr, vec_c)
+ tr_vec_c = sch.fuse(tr, vec_c)
+ rf = sch.rfactor(tr_vec_c, 0)
+
+ # rfactor: reduce to tx
+ bx, ts, tile_s, tr_vec_c = sch.get_loops(block=gemv)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True)
+ rf2 = sch.rfactor(tr, 0)
+
+ # bind, vectorize compute
+ bx, ts, tile_s, r, tile_r_vec_n, tr_vec_c = sch.get_loops(block=rf)
+            tr, vec_c = sch.split(tr_vec_c, factors=[TR, None], preserve_unit_iters=True)
+ sch.reorder(bx, ts, tr, r, tile_s, tile_r_vec_n, vec_c)
+ sch.bind(bx, "blockIdx.x")
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_c)
+
+ shared_mem_usage = 0
+ for buf in vector_input_buffers:
+ buf_size = reduce(
+                    lambda x, y: x * y, buf.shape, tir.IntImm(buf.shape[0].dtype, 1)
+ ) * get_bytes(buf.dtype)
+ shared_mem_usage += buf_size
+ LOAD_V_SHARED = (
+ LOAD_V_SHARED
+ and isinstance(shared_mem_usage, tir.IntImm)
+                and shared_mem_usage.value <= target.max_shared_memory_per_block
+ )
+
+ # vectorize load A
+            # (TODO) this is now actually problematic since the number of loops is dependent on the
+            # number of dimensions of A_q
+            Aq_local = sch.cache_read(rf, read_buffer_index=1, storage_scope="local")
+ sch.compute_at(Aq_local, r, preserve_unit_loops=True)
+ s_local, r_local = sch.get_loops(block=Aq_local)[-2:]
+ s_local, vec_load = sch.split(
+ s_local, factors=[None, VEC_LOAD], preserve_unit_iters=True
+ )
+            sch.reorder(s_local, r_local, vec_load)  # either s_local or r_local should be 1
+ sch.vectorize(vec_load)
+
+ # load vector into shared memory, shape should be the whole vector
+ if LOAD_V_SHARED:
+ assert len(vector_input_buffers) == 1
+                V_shared = sch.cache_read(rf, read_buffer_index=0, storage_scope="shared")
+ sch.compute_at(V_shared, tr, preserve_unit_loops=True)
+ l = sch.get_loops(block=V_shared)[-1]
+ loop: tir.For = sch.get(l)
+ if isinstance(loop.extent, tir.IntImm):
+                    # avoid introducing predicates when vector length is too large
+ vec_length = max(
+ min(
+ get_max_factor(
+ (int)(loop.extent),
+                                [TS * TR * 1, TS * TR * 2, TS * TR * 4, TS * TR * 8],
+ )
+ // TS
+ // TR,
+ LOAD_V_VEC,
+ ),
+ 1,
+ )
+ else:
+ vec_length = LOAD_V_VEC
+ if TAG_R == "threadIdx.x":
+ _, ty, tx, vec = sch.split(
+                        l, factors=[None, TS, TR, vec_length], preserve_unit_iters=True
+ )
+ else:
+ _, ty, tx, vec = sch.split(
+                        l, factors=[None, TR, TS, vec_length], preserve_unit_iters=True
+ )
+ sch.bind(ty, "threadIdx.y")
+ sch.bind(tx, "threadIdx.x")
+ sch.vectorize(vec)
+
+ # reduce tile_s * tr * vec to tile_s * tr
+ sch.reverse_compute_at(rf2, loop=bx, preserve_unit_loops=True)
+ tr, vec_c, *ts_tile_s = sch.get_loops(block=rf2)[1:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ tile_s, vec_s = sch.split(
+ tile_s,
+ factors=[None, get_max_factor(TILE_S, [1, 2, 4, 8])],
+ preserve_unit_iters=True,
+ )
+ sch.reorder(ts, tr, tile_s, vec_s, vec_c)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+ sch.vectorize(vec_s)
+
+ # reduce tile_s * tr to tile_s
+ sch.reverse_compute_at(gemv, loop=bx, preserve_unit_loops=True)
+ tr, *ts_tile_s = sch.get_loops(block=gemv)[1:]
+ ts_tile_s = sch.fuse(*ts_tile_s)
+            ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ sch.reorder(tile_s, ts, tr)
+ sch.bind(ts, TAG_S)
+ sch.bind(tr, TAG_R)
+
+ sch.decompose_reduction(rf, loop=sch.get_loops(block=rf)[3])
+ sch.decompose_reduction(rf2, loop=sch.get_loops(block=rf2)[-1])
+
+ sch.set_scope(rf, buffer_index=0, storage_scope="local")
+ sch.set_scope(rf2, buffer_index=0, storage_scope="local")
+
+ unroll_factor = UNROLL
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                block_or_loop=sch.get_loops(rf)[3], ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ sch.annotate(
+ block_or_loop=sch.get_loops(rf2)[3],
+ ann_key="pragma_auto_unroll_max_step",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                block_or_loop=sch.get_loops(rf2)[3], ann_key="pragma_unroll_explicit", ann_val=1
+ )
+
+ if LOAD_V_SHARED:
+ sch.annotate(
+ block_or_loop=sch.get_loops(V_shared)[-4],
+ ann_key="pragma_unroll_explicit",
+ ann_val=unroll_factor,
+ )
+ sch.annotate(
+                    block_or_loop=sch.get_loops(V_shared)[-4], ann_key="pragma_vectorize", ann_val=1
+ )
+
+ # Schedule epilogue
+ if epilogue_info is not None:
+ epilogue = epilogue_info.block_rv
+ if is_broadcast_epilogue(sch, block, epilogue):
+ sch.reverse_compute_at(epilogue, bx)
+ sch.set_scope(block, 0, "shared")
+                    _, _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
+ _, tx = sch.split(sch.fuse(*s), factors=[None, TX])
+ sch.bind(tx, "threadIdx.x")
+ else:
+                    sch.reverse_compute_at(epilogue, bx, preserve_unit_loops=True)
+ ts_tile_s = sch.fuse(*sch.get_loops(epilogue)[1:])
+ ts_tile_s = sch.get_loops(epilogue)[-1]
+                    ts, tile_s = sch.split(ts_tile_s, factors=[TS, None], preserve_unit_iters=True)
+ sch.bind(ts, TAG_S)
+ sch.set_scope(block, 0, "local")
+ # pylint: enable=invalid-name
+ return sch
def get_extent(loop_rv: tir.schedule.LoopRV):
loop: tir.For = sch.get(loop_rv)
-        return loop.extent.value if isinstance(loop.extent, tir.IntImm) else 1
+        return loop.extent.value if isinstance(loop.extent, tir.IntImm) else loop.extent
# Specify the `len_tx` and `len_ty` according to the loop extent
- len_s, len_r = get_extent(s), get_extent(r)
- if len_r >= 4096 and len_r % 128 == 0:
- len_tx = 128
- elif 1024 < len_r <= 2048 and len_r % 64 == 0:
- len_tx = 64
- else:
- len_tx = 32
-
- if len_s >= 4096:
- len_ty = 8
- else:
- len_ty = min(len_s, 4)
-
- # Use `split_k` to prevent too large shared memory usage
- split_k: int = 4
-
- _, tx = sch.split(r, [None, len_tx], preserve_unit_iters=True)
- # Schedule the RF block
- rf = sch.rfactor(tx, 0)
- batch, bx, r, tx, _ = sch.get_loops(rf)
- sch.reorder(bx, tx, r)
- ro, ri = sch.split(r, [split_k, None], preserve_unit_iters=True)
- bx, ty = sch.split(bx, [None, len_ty], preserve_unit_iters=True)
-
- sch.bind(batch, "blockIdx.y")
- sch.bind(bx, "blockIdx.x")
- sch.bind(ty, "threadIdx.y")
- sch.bind(tx, "threadIdx.x")
- sch.annotate(ro, "pragma_auto_unroll_max_step", unroll_number)
- sch.annotate(ro, "pragma_unroll_explicit", 1)
-
+ batch, s, r, c = sch.get_loops(block=block)
+ len_batch, len_s, len_r, len_c = (
+ get_extent(batch),
+ get_extent(s),
+ get_extent(r),
+ get_extent(c),
+ )
+ len_S = len_batch * len_s
+ len_R = len_r * len_c
+
+ TAG_S, TAG_R = "threadIdx.y", "threadIdx.x"
if target.kind.name == "cuda":
- # Cache read the vector
- def cache_shared(index: int):
- block: tir.Block = sch.get(rf)
- type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
- cache = sch.cache_read(rf, index, "shared")
- sch.compute_at(cache, ro, preserve_unit_loops=True)
- fused = sch.fuse(*sch.get_loops(cache)[5:])
- loop: tir.For = sch.get(fused)
- vec_length = vec_bytes // type_bytes
- if isinstance(loop.extent, tir.IntImm):
-                # avoid introducing predicates when vector length is too large
-                vec_length = min(loop.extent // len_ty // len_tx, vec_length)
-            _, _ty, _tx, _vec = sch.split(fused, [None, len_ty, len_tx, vec_length])
- sch.bind(_ty, "threadIdx.y")
- sch.bind(_tx, "threadIdx.x")
- sch.vectorize(_vec)
-
- def cache_local(index: int):
- block: tir.Block = sch.get(rf)
- type_bytes: int = get_bytes(block.reads[index].buffer.dtype)
- vec_length = vec_bytes // type_bytes
- cache = sch.cache_read(rf, index, "local")
- sch.compute_at(cache, ri, preserve_unit_loops=True)
- fused = sch.fuse(*sch.get_loops(cache)[6:])
- loop: tir.For = sch.get(fused)
-            if isinstance(loop.extent, tir.IntImm) and loop.extent.value % vec_length == 0:
- _, _vec = sch.split(fused, [None, vec_length])
- sch.vectorize(_vec)
-            elif isinstance(loop.extent, tir.IntImm) and loop.extent.value < vec_length:
- sch.vectorize(fused)
-
- for buffer in vector_input_buffers:
- index = vector_input_buffers.index(buffer)
- cache_shared(index)
- cache_local(index)
-
- # TODO: cache scale buffer in Decode-GEMV to shared memory
-
- sch.set_scope(rf, 0, "local")
- sch.decompose_reduction(rf, ro)
- # Schedule the write back block
- sch.reverse_compute_at(block, ty, preserve_unit_loops=True)
- _, _, _, tx, *s = sch.get_loops(block)
- s = sch.fuse(*s)
- sch.reorder(s, tx)
- sch.bind(tx, "threadIdx.x")
- # Schedule epilogue
- if epilogue_info is not None:
- epilogue = epilogue_info.block_rv
- if is_broadcast_epilogue(sch, block, epilogue):
- sch.reverse_compute_at(epilogue, bx)
- sch.set_scope(block, 0, "shared")
-                _, _, *s = sch.get_loops(epilogue)  # pylint: disable=invalid-name
- _, tx = sch.split(sch.fuse(*s), factors=[None, len_tx])
- sch.bind(tx, "threadIdx.x")
- else:
-                # NOTE: Need to ensure tx_len == 32, so that can use `local` stage here
- sch.reverse_compute_at(epilogue, ty)
- sch.set_scope(block, 0, "local")
- # pylint: enable=invalid-name
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 8
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 4, 64
+ else:
+ TS, TR = 16, 32
+ elif target.kind.name == "metal":
+ VEC_C = 2
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 4
+ UNROLL = 256
+ TS, TR = 64, 8
+ elif target.kind.name == "rocm":
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 8
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 1, 128
+ else:
+ TS, TR = 8, 64
+ elif target.kind.name == "opencl" and "android" in str(target.host):
+ TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
+ VEC_C = 8
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 8
+ TS, TR = 2, 32
+ elif target.kind.name == "vulkan":
+ VEC_C = 4
+ LOAD_V_SHARED = True
+ LOAD_V_VEC = 4
+ UNROLL = 256
+ if isinstance(len_S, int):
+ if len_S > len_R:
+ TS, TR = 4, 32
+ else:
+ TS, TR = 16, 32
+ else:
+ VEC_C = 1
+ LOAD_V_SHARED = False
+ LOAD_V_VEC = -1
+ UNROLL = 64
+ TS, TR = 1, 64
+
+ if not isinstance(len_S, int):
+ TS, TR = 1, 64
+ TILE_S, TILE_R = (
+ 1,
+ len_c
+ if len_c > 1
+            else max(get_max_factor(len_r, [TR * 1, TR * 2, TR * 4, TR * 8]) // TR, 1),
+ )
+ VEC_C = min(get_max_factor(TILE_R, [1, 2, 4, 8]), VEC_C)
+ VEC_LOAD = 1
+
+ return apply(
+ sch,
+ gemv=block,
+ TAG_S=TAG_S,
+ TAG_R=TAG_R,
+ TS=TS,
+ TR=TR,
+ TILE_S=TILE_S,
+ TILE_R=TILE_R,
+ VEC_LOAD=VEC_LOAD,
+ VEC_C=VEC_C,
+ LOAD_V_SHARED=LOAD_V_SHARED,
+ LOAD_V_VEC=LOAD_V_VEC,
+ UNROLL=UNROLL,
+ )
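
Note on the apply() helper above: it also decides whether to stage the vector operand through shared memory by multiplying the buffer shape together with functools.reduce and keeping LOAD_V_SHARED only when the resulting byte size is a constant that fits within target.max_shared_memory_per_block. A pure-Python analogue of that guard (a sketch assuming a statically known shape; the committed code works on tir.IntImm extents):

    from functools import reduce

    def fits_in_shared(shape, dtype_bytes, max_shared_memory_per_block):
        # Total byte size of the buffer vs. the per-block shared memory budget.
        num_elems = reduce(lambda x, y: x * y, shape, 1)
        return num_elems * dtype_bytes <= max_shared_memory_per_block

    # Example: a (1, 1, 4096) float16 vector against a 48 KiB budget.
    print(fits_in_shared((1, 1, 4096), 2, 48 * 1024))  # True (8 KiB <= 48 KiB)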
diff --git a/python/tvm/dlight/gpu/utils.py b/python/tvm/dlight/gpu/utils.py
index 4fcc762942..9f9a9c5ae4 100644
--- a/python/tvm/dlight/gpu/utils.py
+++ b/python/tvm/dlight/gpu/utils.py
@@ -51,6 +51,8 @@ def suggest_threads_per_block(
) -> List[int]:
if target.kind.name == "cuda":
threads = 256
+ elif target.kind.name == "rocm":
+ threads = 256
else:
threads = 64
results: List[Optional[int]] = []
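
The utils.py hunk simply gives ROCm the same default thread count that CUDA already used in suggest_threads_per_block. A minimal sketch of the resulting dispatch (the helper name below is illustrative, not the real API):

    def default_threads_per_block(target_kind: str) -> int:
        # CUDA and ROCm both start from 256 threads; other targets stay at 64.
        return 256 if target_kind in ("cuda", "rocm") else 64

    assert default_threads_per_block("rocm") == 256
    assert default_threads_per_block("opencl") == 64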
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index 6cb5cceb43..fd6850ac60 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -90,61 +90,92 @@ class TestGEMV(BaseBeforeAfter):
var_compute_intermediate = T.match_buffer(p_output0, (1, 32, 1, n))
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 32, 1, n),
"float16", scope="local")
- var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 32, 1,
n), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local = T.alloc_buffer((128, 1, 32, 1,
n), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 32, 1,
n), "float16", scope="local")
+ lv1638_local = T.alloc_buffer((1, 32, n, 128), "float16",
scope="local")
lv1637_shared = T.alloc_buffer((1, 32, 1, 128), "float16",
scope="shared")
- lv1637_shared_local = T.alloc_buffer((1, 32, 1, 128), "float16",
scope="local")
- for ax0_fused in T.thread_binding(32, thread="blockIdx.y"):
- for ax1_fused_0 in T.thread_binding(n, thread="blockIdx.x"):
- for ax1_fused_1 in T.thread_binding(1, thread="threadIdx.y"):
- for ax2_fused_1 in T.thread_binding(32,
thread="threadIdx.x"):
- with T.block("NT_matmul_rf_init"):
- vax2_fused_1, v0 = T.axis.remap("SS",
[ax2_fused_1, ax0_fused])
- v1 = T.axis.spatial(n, ax1_fused_0 + ax1_fused_1)
- T.reads()
-
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
- var_NT_matmul_intermediate_rf_local[vax2_fused_1,
0, v0, 0, v1] = T.float16(0)
- for ax2_fused_0_0 in T.serial(4,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_ax1_ax2_ax3_fused_0 in range(1):
- for ax0_ax1_ax2_ax3_fused_1 in
T.thread_binding(1, thread="threadIdx.y"):
- for ax0_ax1_ax2_ax3_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
- for ax0_ax1_ax2_ax3_fused_3 in
T.vectorized(1):
- with T.block("lv1637_shared"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(32,
ax0_fused)
- v2 = T.axis.spatial(1, 0)
- v3 = T.axis.spatial(128,
ax2_fused_0_0 * 32 + ax0_ax1_ax2_ax3_fused_0 * 32 + ax0_ax1_ax2_ax3_fused_1 *
32 + ax0_ax1_ax2_ax3_fused_2 + ax0_ax1_ax2_ax3_fused_3)
- T.reads(lv1637[v0, v1, v2, v3])
- T.writes(lv1637_shared[v0, v1,
v2, v3])
- lv1637_shared[v0, v1, v2, v3]
= lv1637[v0, v1, v2, v3]
- for ax2_fused_0_1 in range(1):
- for ax0_ax1_ax2_ax3_fused in T.vectorized(1):
- with T.block("lv1637_shared_local"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(32, ax0_fused)
- v2 = T.axis.spatial(1, 0)
- v3 = T.axis.spatial(128, ax2_fused_0_0
* 32 + ax2_fused_1)
- T.reads(lv1637_shared[v0, v1, v2, v3])
- T.writes(lv1637_shared_local[v0, v1,
v2, v3])
- lv1637_shared_local[v0, v1, v2, v3] =
lv1637_shared[v0, v1, v2, v3]
- for u in range(1):
- with T.block("NT_matmul_rf_update"):
- vax2_fused_1, v0 = T.axis.remap("SS",
[ax2_fused_1, ax0_fused])
- v1 = T.axis.spatial(n, ax1_fused_0 +
ax1_fused_1)
- vax2_fused_0 = T.axis.reduce(4,
ax2_fused_0_0 + ax2_fused_0_1)
-
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1],
lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1], lv1638[0, v0,
v1, vax2_fused_0 * 32 + vax2_fused_1])
-
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
-
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] =
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1] +
lv1637_shared_local[0, v0, 0, vax2_fused_0 * 32 + vax2_fused_1] * lv1638[0, v0,
v1, vax2_fused_0 * 32 + vax2_fused_1]
- for ax1_ax2_fused in range(1):
- for ax0 in T.thread_binding(32, thread="threadIdx.x"):
- with T.block("NT_matmul"):
- vax2_fused_1, v0, v1 = T.axis.remap("RSS",
[ax0, ax0_fused, ax1_fused_0])
-
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1])
- T.writes(var_NT_matmul_intermediate_local[0,
v0, 0, v1])
- with T.init():
- var_NT_matmul_intermediate_local[0, v0, 0,
v1] = T.float16(0)
- var_NT_matmul_intermediate_local[0, v0, 0, v1]
= var_NT_matmul_intermediate_local[0, v0, 0, v1] +
var_NT_matmul_intermediate_rf_local[vax2_fused_1, 0, v0, 0, v1]
+ for ax0_fused_ax1_fused_fused_0 in T.thread_binding(n * 32,
thread="blockIdx.x"):
+ for ax0_fused_ax1_fused_fused_1 in T.thread_binding(1,
thread="threadIdx.y"):
+ for ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 in
T.thread_binding(64, thread="threadIdx.x"):
+ for ax0, ax1, ax2 in T.grid(1, 1, 1):
+ for ax3_0 in T.serial(1,
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+ for ax3_1 in T.thread_binding(1,
thread="threadIdx.y"):
+ for ax3_2 in T.thread_binding(64,
thread="threadIdx.x"):
+ for ax3_3 in T.vectorized(2):
+ with T.block("lv1637_shared"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n + ax1)
+ v2 = T.axis.spatial(1, ax2)
+ v3 = T.axis.spatial(128, ax3_0 *
128 + ax3_1 * 128 + ax3_2 * 2 + ax3_3)
+ T.reads(lv1637[v0, v1, v2, v3])
+ T.writes(lv1637_shared[v0, v1, v2,
v3])
+ lv1637_shared[v0, v1, v2, v3] =
lv1637[v0, v1, v2, v3]
+ for ax0_fused_ax1_fused_fused_2_init in range(1):
+ for
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init in T.vectorized(2):
+ with T.block("NT_matmul_rf_init"):
+ vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused
= T.axis.spatial(128, ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 +
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1_init)
+ v0 = T.axis.spatial(32,
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 +
ax0_fused_ax1_fused_fused_2_init) // n)
+ v1 = T.axis.spatial(n,
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 +
ax0_fused_ax1_fused_fused_2_init) % n)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1])
+
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1] = T.float16(0)
+ for ax2_fused_u_fused_0 in T.serial(1,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0, ax1, ax2_0, ax3 in T.grid(1, 1, 1, 2):
+ for ax2_1 in T.vectorized(1):
+ with T.block("lv1638_local"):
+ v0 = T.axis.spatial(1, ax0)
+ v1 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n + ax1)
+ v2 = T.axis.spatial(n,
ax0_fused_ax1_fused_fused_0 % n + ax2_0 + ax2_1)
+ v3 = T.axis.spatial(128,
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 + ax3)
+ T.reads(lv1638[v0, v1, v2, v3])
+ T.writes(lv1638_local[v0, v1, v2, v3])
+ lv1638_local[v0, v1, v2, v3] = lv1638[v0,
v1, v2, v3]
+ for ax0_fused_ax1_fused_fused_2, ax2_fused_u_fused_2
in T.grid(1, 1):
+ for
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 in T.vectorized(2):
+ with T.block("NT_matmul_rf_update"):
+
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused = T.axis.spatial(128,
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 * 2 +
ax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1)
+ v0 = T.axis.spatial(32,
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 +
ax0_fused_ax1_fused_fused_2) // n)
+ v1 = T.axis.spatial(n,
(ax0_fused_ax1_fused_fused_0 + ax0_fused_ax1_fused_fused_1 +
ax0_fused_ax1_fused_fused_2) % n)
+ vax2_fused_u_fused_2, vax2_fused_u_fused_0
= T.axis.remap("RR", [ax2_fused_u_fused_2, ax2_fused_u_fused_0])
+
T.reads(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1], lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 +
vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused],
lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 +
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused])
+
T.writes(var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1])
+
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1] =
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused,
0, v0, 0, v1] + lv1637_shared[0, v0, 0, vax2_fused_u_fused_0 * 128 +
vax2_fused_u_fused_2 * 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused] *
lv1638_local[0, v0, v1, vax2_fused_u_fused_0 * 128 + vax2_fused_u_fused_2 * 2 +
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused]
+ for ax2_ax3_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ for ax2_ax3_fused_1_0 in T.serial(1,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_ax3_fused_1_1 in T.vectorized(1):
+ with T.block("NT_matmul_rf_init"):
+
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 = T.axis.spatial(64, ax0)
+ v0 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n)
+ v1 = T.axis.spatial(n,
ax0_fused_ax1_fused_fused_0 % n)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1])
+
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1] = T.float16(0)
+ for ax1 in range(2):
+ with T.block("NT_matmul_rf_update"):
+
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1 = T.axis.remap("SR", [ax0,
ax1])
+ v0 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n)
+ v1 = T.axis.spatial(n,
ax0_fused_ax1_fused_fused_0 % n)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1],
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
* 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1])
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1])
+
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1] =
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1] +
var_NT_matmul_intermediate_rf_local[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0
* 2 + vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_1, 0, v0, 0, v1]
+ for ax1_ax2_fused_1 in range(1):
+ for ax1_ax2_fused_0 in T.thread_binding(1,
thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("NT_matmul"):
+ vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0 =
T.axis.reduce(64, ax0)
+ v0 = T.axis.spatial(32,
ax0_fused_ax1_fused_fused_0 // n)
+ v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0
% n)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1])
+ T.writes(var_NT_matmul_intermediate_local[0, v0,
0, v1])
+ with T.init():
+ var_NT_matmul_intermediate_local[0, v0, 0, v1]
= T.float16(0)
+ var_NT_matmul_intermediate_local[0, v0, 0, v1] =
var_NT_matmul_intermediate_local[0, v0, 0, v1] +
var_NT_matmul_intermediate_rf_local_1[vax2_fused_u_fused_1_ax2_fused_u_fused_3_fused_0,
0, v0, 0, v1]
+ for ax0_ax1_fused_0 in T.thread_binding(1, thread="threadIdx.y"):
+ for ax0_ax1_fused_1 in range(1):
with T.block("compute"):
- v0, v1 = T.axis.remap("SS", [ax0_fused, ax1_fused_0])
+ v0 = T.axis.spatial(32, ax0_fused_ax1_fused_fused_0 //
n)
+ v1 = T.axis.spatial(n, ax0_fused_ax1_fused_fused_0 % n)
T.reads(var_NT_matmul_intermediate_local[0, v0, 0,
v1], lv1614[0, 0, 0, v1])
T.writes(var_compute_intermediate[0, v0, 0, v1])
var_compute_intermediate[0, v0, 0, v1] =
T.Cast("float32", T.min(T.max(var_NT_matmul_intermediate_local[0, v0, 0, v1] *
T.float16(0.088397790055248615), T.float16(-65504)), lv1614[0, 0, 0, v1]))
@@ -152,10 +183,10 @@ class TestGEMV(BaseBeforeAfter):
# fmt: on
-class TestDecodeGEMV1(BaseBeforeAfter):
+def test_decode_gemv1():
# fmt: off
- @T.prim_func
+ @T.prim_func(private=True)
def before(lv571: T.Buffer((22016, 512), "uint32"), lv572:
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"),
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
@@ -175,72 +206,95 @@ class TestDecodeGEMV1(BaseBeforeAfter):
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] =
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv1654[v_i0, v_i1, v_k] *
p_output0_intermediate[v_i2, v_k]
- @T.prim_func
+ @T.prim_func(private=True)
def expected(lv571: T.Buffer((22016, 512), "uint32"), lv572:
T.Buffer((22016, 128), "float16"), lv1654: T.Buffer((1, 1, 4096), "float16"),
var_NT_matmul_intermediate: T.Buffer((1, 1, 22016), "float16")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
- var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1,
22016), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1,
22016), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1,
22016), "float16", scope="local")
+ lv571_local = T.alloc_buffer((22016, 512), "uint32", scope="local")
lv1654_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
- lv1654_shared_local = T.alloc_buffer((1, 1, 4096), "float16",
scope="local")
- for u_fused in T.thread_binding(1, thread="blockIdx.y"):
- for ax0_fused_0 in T.thread_binding(2752, thread="blockIdx.x"):
- for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
- for ax1_0_fused_1 in T.thread_binding(32,
thread="threadIdx.x"):
- with T.block("NT_matmul_rf_init"):
- vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
- v0 = T.axis.spatial(22016, ax0_fused_0 * 8 +
ax0_fused_1)
- T.reads()
-
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
- for ax1_0_fused_0_0 in T.serial(4,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_ax1_ax2_fused_0 in range(1):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(8,
thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(4):
- with T.block("lv1654_shared"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(1, 0)
- v2 = T.axis.spatial(4096,
ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128
+ ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
- T.reads(lv1654[v0, v1, v2])
- T.writes(lv1654_shared[v0, v1,
v2])
- lv1654_shared[v0, v1, v2] =
lv1654[v0, v1, v2]
- for ax1_0_fused_0_1 in range(4):
- for ax0_ax1_ax2_fused_0 in range(1):
- for ax0_ax1_ax2_fused_1 in T.vectorized(8):
- with T.block("lv1654_shared_local"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(1, 0)
- v2 = T.axis.spatial(4096,
ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 +
ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
- T.reads(lv1654_shared[v0, v1, v2])
- T.writes(lv1654_shared_local[v0,
v1, v2])
- lv1654_shared_local[v0, v1, v2] =
lv1654_shared[v0, v1, v2]
- for ax1_1 in range(8):
- with T.block("NT_matmul_rf_update"):
- vax1_0_fused_1 = T.axis.spatial(32,
ax1_0_fused_1)
- v0 = T.axis.spatial(22016, ax0_fused_0
* 8 + ax0_fused_1)
- vax1_0_fused_0 = T.axis.reduce(16,
ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
- vax1_1 = T.axis.reduce(8, ax1_1)
-
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0],
lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1],
lv571[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv572[v0,
(vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
-
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] =
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] +
lv1654_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(lv571[v0, (vax1_0_fused_0 * 256
+ vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 +
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * lv [...]
- for ax1_fused in range(1):
- for ax0 in T.thread_binding(32, thread="threadIdx.x"):
- with T.block("NT_matmul"):
- vax1_0_fused_1 = T.axis.reduce(32, ax0)
- v0 = T.axis.spatial(22016, ax0_fused_0 * 8 +
ax0_fused_1)
-
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
- T.writes(var_NT_matmul_intermediate[0, 0, v0])
- with T.init():
- var_NT_matmul_intermediate[0, 0, v0] =
T.float16(0)
- var_NT_matmul_intermediate[0, 0, v0] =
var_NT_matmul_intermediate[0, 0, v0] +
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+ for u_fused_ax0_fused_fused_0 in T.thread_binding(5504,
thread="blockIdx.x"):
+ for u_fused_ax0_fused_fused_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in
T.thread_binding(64, thread="threadIdx.x"):
+ for ax0, ax1 in T.grid(1, 1):
+ for ax2_0 in T.serial(2,
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+ for ax2_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax2_2 in T.thread_binding(64,
thread="threadIdx.x"):
+ for ax2_3 in T.vectorized(8):
+ with T.block("lv1654_shared"):
+ v0, v1 = T.axis.remap("SS", [ax0,
ax1])
+ v2 = T.axis.spatial(4096, ax2_0 *
2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3)
+ T.reads(lv1654[v0, v1, v2])
+ T.writes(lv1654_shared[v0, v1, v2])
+ lv1654_shared[v0, v1, v2] =
lv1654[v0, v1, v2]
+ for u_fused_ax0_fused_fused_2_init in range(1):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in
T.vectorized(4):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2_init)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] = T.float16(0)
+ for ax1_0_fused_ax1_1_fused_0 in T.serial(8,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0_0, ax1 in T.grid(1, 1):
+ for ax0_1 in T.vectorized(1):
+ with T.block("lv571_local"):
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+ v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ T.reads(lv571[v0, v1])
+ T.writes(lv571_local[v0, v1])
+ lv571_local[v0, v1] = lv571[v0, v1]
+ for u_fused_ax0_fused_fused_2,
ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2)
+ vax1_0_fused_ax1_1_fused_0,
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0,
ax1_0_fused_ax1_1_fused_2])
+
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0], lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 +
vax1_0_fused_ax1_1_fused_2 * 4 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4],
lv571_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + va [...]
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] =
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] + lv1654_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 +
vax1_0_fused_ax1_1_fused_2 * 4 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] *
((T.Cast("float1 [...]
+ for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ for ax2_fused_1_0 in T.serial(1,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_fused_1_1 in T.vectorized(1):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.spatial(64, ax0)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] = T.float16(0)
+ for ax1 in range(4):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 =
T.axis.remap("SR", [ax0, ax1])
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0],
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] =
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] +
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+ for ax1_fused_1 in range(1):
+ for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("NT_matmul"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.reduce(64, ax0)
+ v0 = T.axis.spatial(22016,
u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+ T.writes(var_NT_matmul_intermediate[0, 0, v0])
+ with T.init():
+ var_NT_matmul_intermediate[0, 0, v0] =
T.float16(0)
+ var_NT_matmul_intermediate[0, 0, v0] =
var_NT_matmul_intermediate[0, 0, v0] +
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0]
# fmt: on
+ mod = tvm.IRModule({"main": before})
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
-class TestDecodeGEMV2(BaseBeforeAfter):
+
+def test_decode_gemv2():
# fmt: off
- @T.prim_func
+ @T.prim_func(private=True)
def before(lv771: T.Buffer((32000, 512), "uint32"), lv772:
T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"),
p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
@@ -267,73 +321,223 @@ class TestDecodeGEMV2(BaseBeforeAfter):
T.writes(p_output0_intermediate[v_i0, v_i1, v_i2])
p_output0_intermediate[v_i0, v_i1, v_i2] = T.Cast("float32",
var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
- @T.prim_func
+ @T.prim_func(private=True)
def expected(lv771: T.Buffer((32000, 512), "uint32"), lv772:
T.Buffer((32000, 128), "float16"), lv3216: T.Buffer((1, 1, 4096), "float16"),
p_output0_intermediate: T.Buffer((1, 1, 32000), "float32")):
T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
# with T.block("root"):
var_NT_matmul_intermediate_local = T.alloc_buffer((1, 1, 32000),
"float16", scope="local")
- var_NT_matmul_intermediate_rf_local = T.alloc_buffer((32, 1, 1,
32000), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local = T.alloc_buffer((256, 1, 1,
32000), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((64, 1, 1,
32000), "float16", scope="local")
+ lv771_local = T.alloc_buffer((32000, 512), "uint32", scope="local")
lv3216_shared = T.alloc_buffer((1, 1, 4096), "float16", scope="shared")
- lv3216_shared_local = T.alloc_buffer((1, 1, 4096), "float16",
scope="local")
- for u_fused in T.thread_binding(1, thread="blockIdx.y"):
- for ax0_fused_0 in T.thread_binding(4000, thread="blockIdx.x"):
- for ax0_fused_1 in T.thread_binding(8, thread="threadIdx.y"):
- for ax1_0_fused_1 in T.thread_binding(32,
thread="threadIdx.x"):
- with T.block("NT_matmul_rf_init"):
- vax1_0_fused_1 = T.axis.spatial(32, ax1_0_fused_1)
- v0 = T.axis.spatial(32000, ax0_fused_0 * 8 +
ax0_fused_1)
- T.reads()
-
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] = T.float16(0)
- for ax1_0_fused_0_0 in T.serial(4,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
- for ax0_ax1_ax2_fused_0 in range(1):
- for ax0_ax1_ax2_fused_1 in T.thread_binding(8,
thread="threadIdx.y"):
- for ax0_ax1_ax2_fused_2 in
T.thread_binding(32, thread="threadIdx.x"):
- for ax0_ax1_ax2_fused_3 in
T.vectorized(4):
- with T.block("lv3216_shared"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(1, 0)
- v2 = T.axis.spatial(4096,
ax1_0_fused_0_0 * 1024 + ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 * 128
+ ax0_ax1_ax2_fused_2 * 4 + ax0_ax1_ax2_fused_3)
- T.reads(lv3216[v0, v1, v2])
- T.writes(lv3216_shared[v0, v1,
v2])
- lv3216_shared[v0, v1, v2] =
lv3216[v0, v1, v2]
- for ax1_0_fused_0_1 in range(4):
- for ax0_ax1_ax2_fused_0 in range(1):
- for ax0_ax1_ax2_fused_1 in T.vectorized(8):
- with T.block("lv3216_shared_local"):
- v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(1, 0)
- v2 = T.axis.spatial(4096,
ax1_0_fused_0_0 * 1024 + ax1_0_fused_0_1 * 256 + ax1_0_fused_1 * 8 +
ax0_ax1_ax2_fused_0 * 8 + ax0_ax1_ax2_fused_1)
- T.reads(lv3216_shared[v0, v1, v2])
- T.writes(lv3216_shared_local[v0,
v1, v2])
- lv3216_shared_local[v0, v1, v2] =
lv3216_shared[v0, v1, v2]
- for ax1_1 in range(8):
- with T.block("NT_matmul_rf_update"):
- vax1_0_fused_1 = T.axis.spatial(32,
ax1_0_fused_1)
- v0 = T.axis.spatial(32000, ax0_fused_0
* 8 + ax0_fused_1)
- vax1_0_fused_0 = T.axis.reduce(16,
ax1_0_fused_0_0 * 4 + ax1_0_fused_0_1)
- vax1_1 = T.axis.reduce(8, ax1_1)
-
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0],
lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1],
lv771[v0, (vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 8], lv772[v0,
(vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1) // 32])
-
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
-
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] =
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0] +
lv3216_shared_local[0, 0, vax1_0_fused_0 * 256 + vax1_0_fused_1 * 8 + vax1_1] *
((T.Cast("float16", T.bitwise_and(T.shift_right(lv771[v0, (vax1_0_fused_0 * 256
+ vax1_0_fused_1 * 8 + vax1_1) // 8], T.Cast("uint32", (vax1_0_fused_0 * 256 +
vax1_0_fused_1 * 8 + vax1_1) % 8) * T.uint32(4)), T.uint32(15))) -
T.float16(7)) * lv [...]
- for ax1_fused in range(1):
- for ax0 in T.thread_binding(32, thread="threadIdx.x"):
- with T.block("NT_matmul"):
- vax1_0_fused_1 = T.axis.reduce(32, ax0)
- v0 = T.axis.spatial(32000, ax0_fused_0 * 8 +
ax0_fused_1)
-
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0])
- T.writes(var_NT_matmul_intermediate_local[0,
0, v0])
- with T.init():
- var_NT_matmul_intermediate_local[0, 0, v0]
= T.float16(0)
- var_NT_matmul_intermediate_local[0, 0, v0] =
var_NT_matmul_intermediate_local[0, 0, v0] +
var_NT_matmul_intermediate_rf_local[vax1_0_fused_1, 0, 0, v0]
+ for u_fused_ax0_fused_fused_0 in T.thread_binding(8000,
thread="blockIdx.x"):
+ for u_fused_ax0_fused_fused_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in
T.thread_binding(64, thread="threadIdx.x"):
+ for ax0, ax1 in T.grid(1, 1):
+ for ax2_0 in T.serial(2,
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+ for ax2_1 in T.thread_binding(4,
thread="threadIdx.y"):
+ for ax2_2 in T.thread_binding(64,
thread="threadIdx.x"):
+ for ax2_3 in T.vectorized(8):
+ with T.block("lv3216_shared"):
+ v0, v1 = T.axis.remap("SS", [ax0,
ax1])
+ v2 = T.axis.spatial(4096, ax2_0 *
2048 + ax2_1 * 512 + ax2_2 * 8 + ax2_3)
+ T.reads(lv3216[v0, v1, v2])
+ T.writes(lv3216_shared[v0, v1, v2])
+ lv3216_shared[v0, v1, v2] =
lv3216[v0, v1, v2]
+ for u_fused_ax0_fused_fused_2_init in range(1):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in
T.vectorized(4):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2_init)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] = T.float16(0)
+ for ax1_0_fused_ax1_1_fused_0 in T.serial(8,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0_0, ax1 in T.grid(1, 1):
+ for ax0_1 in T.vectorized(1):
+ with T.block("lv771_local"):
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 + ax0_0 + ax0_1)
+ v1 = T.axis.spatial(512,
ax1_0_fused_ax1_1_fused_0 * 64 +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ T.reads(lv771[v0, v1])
+ T.writes(lv771_local[v0, v1])
+ lv771_local[v0, v1] = lv771[v0, v1]
+ for u_fused_ax0_fused_fused_2,
ax1_0_fused_ax1_1_fused_2 in T.grid(1, 2):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in T.vectorized(4):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(256, ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2)
+ vax1_0_fused_ax1_1_fused_0,
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0,
ax1_0_fused_ax1_1_fused_2])
+
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0], lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 +
vax1_0_fused_ax1_1_fused_2 * 4 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4],
lv771_local[v0, vax1_0_fused_ax1_1_fused_0 * 64 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 + va [...]
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] =
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
0, 0, v0] + lv3216_shared[0, 0, vax1_0_fused_ax1_1_fused_0 * 512 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // 4 * 8 +
vax1_0_fused_ax1_1_fused_2 * 4 +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % 4] *
((T.Cast("float1 [...]
+ for ax2_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ for ax2_fused_1_0 in T.serial(1,
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_fused_1_1 in T.vectorized(1):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.spatial(64, ax0)
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] = T.float16(0)
+ for ax1 in range(4):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 =
T.axis.remap("SR", [ax0, ax1])
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + ax2_fused_0 + ax2_fused_1_0 + ax2_fused_1_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0],
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0])
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] =
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0] +
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* 4 + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1, 0, 0, v0]
+ for ax1_fused_1 in range(1):
+ for ax1_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+ for ax0 in T.thread_binding(64, thread="threadIdx.x"):
+ with T.block("NT_matmul"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.reduce(64, ax0)
+ v0 = T.axis.spatial(32000,
u_fused_ax0_fused_fused_0 * 4 + ax1_fused_0 + ax1_fused_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0])
+ T.writes(var_NT_matmul_intermediate_local[0, 0,
v0])
+ with T.init():
+ var_NT_matmul_intermediate_local[0, 0, v0] =
T.float16(0)
+ var_NT_matmul_intermediate_local[0, 0, v0] =
var_NT_matmul_intermediate_local[0, 0, v0] +
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
0, 0, v0]
+ for ax0_fused_0 in T.thread_binding(4, thread="threadIdx.y"):
+ for ax0_fused_1 in range(1):
with T.block("compute"):
- v0 = T.axis.spatial(32000, ax0_fused_0 * 8 +
ax0_fused_1)
+ v0 = T.axis.spatial(32000, u_fused_ax0_fused_fused_0 *
4 + ax0_fused_0 + ax0_fused_1)
T.reads(var_NT_matmul_intermediate_local[0, 0, v0])
T.writes(p_output0_intermediate[0, 0, v0])
p_output0_intermediate[0, 0, v0] = T.Cast("float32",
var_NT_matmul_intermediate_local[0, 0, v0])
# fmt: on
+ mod = tvm.IRModule({"main": before})
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
+
+def test_decode_gemv3():
+ # fmt: off
+
+ @T.prim_func(private=True)
+ def before(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"),
lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574:
T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570:
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)),
"float16")):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ p_output0_intermediate_1 = T.alloc_buffer((T.int64(4096),
T.int64(11008)), "float16")
+ var_NT_matmul_intermediate = T.alloc_buffer((T.int64(1), T.int64(1),
T.int64(4096)), "float16")
+ for i, j in T.grid(T.int64(4096), T.int64(11008)):
+ with T.block("decode"):
+ v_i, v_j = T.axis.remap("SS", [i, j])
+ T.reads(lv575[v_i, v_j // T.int64(8)], lv576[v_i, v_j //
T.int64(32)])
+ T.writes(p_output0_intermediate_1[v_i, v_j])
+ p_output0_intermediate_1[v_i, v_j] = (T.Cast("float16",
T.bitwise_and(T.shift_right(lv575[v_i, v_j // T.int64(8)], T.Cast("uint32", v_j
% T.int64(8)) * T.uint32(4)), T.uint32(15))) - T.float16(7)) * lv576[v_i, v_j
// T.int64(32)]
+ for i0, i1, i2, k in T.grid(T.int64(1), T.int64(1), T.int64(4096),
T.int64(11008)):
+ with T.block("NT_matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(lv574[v_i0, v_i1, v_k], p_output0_intermediate_1[v_i2,
v_k])
+ T.writes(var_NT_matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ var_NT_matmul_intermediate[v_i0, v_i1, v_i2] =
var_NT_matmul_intermediate[v_i0, v_i1, v_i2] + lv574[v_i0, v_i1, v_k] *
p_output0_intermediate_1[v_i2, v_k]
+ for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(1), T.int64(4096)):
+ with T.block("T_add"):
+ v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+ T.reads(lv570[v_ax0, v_ax1, v_ax2],
var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2])
+ T.writes(p_output0_intermediate[v_ax0, v_ax1, v_ax2])
+ p_output0_intermediate[v_ax0, v_ax1, v_ax2] = lv570[v_ax0,
v_ax1, v_ax2] + var_NT_matmul_intermediate[v_ax0, v_ax1, v_ax2]
+
+ @T.prim_func(private=True)
+ def expected(lv575: T.Buffer((T.int64(4096), T.int64(1376)), "uint32"),
lv576: T.Buffer((T.int64(4096), T.int64(344)), "float16"), lv574:
T.Buffer((T.int64(1), T.int64(1), T.int64(11008)), "float16"), lv570:
T.Buffer((T.int64(1), T.int64(1), T.int64(4096)), "float16"),
p_output0_intermediate: T.Buffer((T.int64(1), T.int64(1), T.int64(4096)),
"float16")):
+ T.func_attr({"tir.is_scheduled": 1, "tir.noalias": T.bool(True)})
+ # with T.block("root"):
+ var_NT_matmul_intermediate_local = T.alloc_buffer((T.int64(1),
T.int64(1), T.int64(4096)), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local = T.alloc_buffer((T.int64(128),
T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+ var_NT_matmul_intermediate_rf_local_1 = T.alloc_buffer((T.int64(32),
T.int64(1), T.int64(1), T.int64(4096)), "float16", scope="local")
+ lv575_local = T.alloc_buffer((T.int64(4096), T.int64(1376)), "uint32",
scope="local")
+ lv574_shared = T.alloc_buffer((T.int64(1), T.int64(1),
T.int64(11008)), "float16", scope="shared")
+ for u_fused_ax0_fused_fused_0 in T.thread_binding(T.int64(256),
thread="blockIdx.x"):
+ for u_fused_ax0_fused_fused_1 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 in
T.thread_binding(T.int64(32), thread="threadIdx.x"):
+ for ax0, ax1 in T.grid(T.int64(1), T.int64(1)):
+ for ax2_0 in T.serial(T.int64(22),
annotations={"pragma_unroll_explicit": 256, "pragma_vectorize": 1}):
+ for ax2_1 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax2_2 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ for ax2_3 in T.vectorized(T.int64(1)):
+ with T.block("lv574_shared"):
+ v0, v1 = T.axis.remap("SS", [ax0,
ax1])
+ v2 =
T.axis.spatial(T.int64(11008), ax2_0 * T.int64(512) + ax2_1 * T.int64(32) +
ax2_2 + ax2_3)
+ T.where((ax2_0 * T.int64(16) +
ax2_1) * T.int64(32) + ax2_2 + ax2_3 < T.int64(11008))
+ T.reads(lv574[v0, v1, v2])
+ T.writes(lv574_shared[v0, v1, v2])
+ lv574_shared[v0, v1, v2] =
lv574[v0, v1, v2]
+ for u_fused_ax0_fused_fused_2_init in range(T.int64(1)):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init in
T.vectorized(T.int64(4)):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(T.int64(128),
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1_init)
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2_init)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0] = T.float16(0)
+ for ax1_0_fused_ax1_1_fused_0 in T.serial(T.int64(43),
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax0_0, ax1 in T.grid(T.int64(1), T.int64(1)):
+ for ax0_1 in T.vectorized(T.int64(1)):
+ with T.block("lv575_local"):
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 + ax0_0 +
ax0_1)
+ v1 = T.axis.spatial(T.int64(1376),
ax1_0_fused_ax1_1_fused_0 * T.int64(32) +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 + ax1)
+ T.reads(lv575[v0, v1])
+ T.writes(lv575_local[v0, v1])
+ lv575_local[v0, v1] = lv575[v0, v1]
+ for u_fused_ax0_fused_fused_2,
ax1_0_fused_ax1_1_fused_2 in T.grid(T.int64(1), T.int64(2)):
+ for
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 in
T.vectorized(T.int64(4)):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused =
T.axis.spatial(T.int64(128),
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 * T.int64(4) +
ax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1)
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + u_fused_ax0_fused_fused_1 +
u_fused_ax0_fused_fused_2)
+ vax1_0_fused_ax1_1_fused_0,
vax1_0_fused_ax1_1_fused_2 = T.axis.remap("RR", [ax1_0_fused_ax1_1_fused_0,
ax1_0_fused_ax1_1_fused_2])
+
T.reads(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0], lv574_shared[T.int64(0), T.int64(0),
vax1_0_fused_ax1_1_fused_0 * T.int64(256) +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) *
T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int64(4) +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused % T.int64(4)],
lv575_local[v0, vax1_0_fused_ax1_1_fus [...]
+
T.writes(var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0])
+
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0] =
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused,
T.int64(0), T.int64(0), v0] + lv574_shared[T.int64(0), T.int64(0),
vax1_0_fused_ax1_1_fused_0 * T.int64(256) +
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused // T.int64(4) *
T.int64(8) + vax1_0_fused_ax1_1_fused_2 * T.int6 [...]
+ for ax2_fused_0 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
+ for ax2_fused_1_0 in T.serial(T.int64(1),
annotations={"pragma_auto_unroll_max_step": 256, "pragma_unroll_explicit": 1}):
+ for ax2_fused_1_1 in T.vectorized(T.int64(1)):
+ with T.block("NT_matmul_rf_init"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.spatial(T.int64(32), ax0)
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 +
ax2_fused_1_1)
+ T.reads()
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0] = T.float16(0)
+ for ax1 in range(T.int64(4)):
+ with T.block("NT_matmul_rf_update"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1 =
T.axis.remap("SR", [ax0, ax1])
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + ax2_fused_0 + ax2_fused_1_0 +
ax2_fused_1_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0],
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1,
T.int64(0), T.int64(0), v0])
+
T.writes(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0])
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0] =
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0] +
var_NT_matmul_intermediate_rf_local[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0
* T.int64(4) + vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_1,
T.int64(0), T.int64(0), v0]
+ for ax1_fused_1 in range(T.int64(1)):
+ for ax1_fused_0 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax0 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ with T.block("NT_matmul"):
+
vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0 =
T.axis.reduce(T.int64(32), ax0)
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + ax1_fused_0 + ax1_fused_1)
+
T.reads(var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0])
+
T.writes(var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+ with T.init():
+ var_NT_matmul_intermediate_local[T.int64(0),
T.int64(0), v0] = T.float16(0)
+ var_NT_matmul_intermediate_local[T.int64(0),
T.int64(0), v0] = var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]
+
var_NT_matmul_intermediate_rf_local_1[vax1_0_fused_ax1_1_fused_1_ax1_0_fused_ax1_1_fused_3_fused_0,
T.int64(0), T.int64(0), v0]
+ for ax0_fused_0 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
+ for ax0_fused_1 in range(T.int64(1)):
+ with T.block("T_add"):
+ v0 = T.axis.spatial(T.int64(4096),
u_fused_ax0_fused_fused_0 * T.int64(16) + ax0_fused_0 + ax0_fused_1)
+ T.reads(lv570[T.int64(0), T.int64(0), v0],
var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0])
+ T.writes(p_output0_intermediate[T.int64(0),
T.int64(0), v0])
+ p_output0_intermediate[T.int64(0), T.int64(0), v0] =
lv570[T.int64(0), T.int64(0), v0] +
var_NT_matmul_intermediate_local[T.int64(0), T.int64(0), v0]
+
+ # fmt: on
+
+ mod = tvm.IRModule({"main": before})
+ with Target("nvidia/geforce-rtx-3090-ti"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+ mod.show(black_format=False)
+ tvm.ir.assert_structural_equal(mod["main"], expected)
+
if __name__ == "__main__":
tvm.testing.main()