This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 50d1c97dc9 [DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187)
50d1c97dc9 is described below
commit 50d1c97dc982c6ddfe089852d1fbbac3ea629851
Author: krishnaraj36 <[email protected]>
AuthorDate: Tue Jul 23 20:57:53 2024 +0530
[DLIGHT][GPU] Add OpenCL dequant matmul schedule (#17187)
* [DLIGHT][GPU] Add OpenCL dequant matmul schedule
1. Enhanced the GPU matmul schedule for the OpenCL Android and Windows backends.
2. It achieves a ~2X performance gain for the Llama-2-7B prefill process.
Model: Llama-2-7B-chat-hf | Device: Snapdragon® 8 Gen 3
Earlier prefill perf: 27 tok/sec
Optimized prefill perf: 50 tok/sec
* Update matmul.py
---
python/tvm/dlight/gpu/matmul.py | 144 +++++++++++++++++++++++--
tests/python/dlight/test_gpu_matmul.py | 192 +++++++++++++++++++++++++++------
2 files changed, 292 insertions(+), 44 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index a5759941ca..25cc649b44 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -27,7 +27,7 @@ from tvm.tir import IterVar, PrimExpr, Var
from tvm.tir.analysis import undefined_vars
from tvm.tir.schedule.schedule import BlockRV
-from ..base import analysis
+from ..base import analysis, BlockInfo, IterInfo
from .base import GPUScheduleRule
@@ -273,6 +273,32 @@ def get_index_map(block: tir.Block) ->
Optional[Tuple[tir.IndexMap, ...]]:
)
+def get_block_info(sch: tir.Schedule, block: tir.schedule.BlockRV) ->
BlockInfo:
+ def _iter_kind(loop: tir.IterVar) -> str:
+ return {tir.IterVar.DataPar: "S", tir.IterVar.CommReduce:
"R"}.get(loop.iter_type, "O")
+
+ def _is_reduction_block(block: tir.schedule.BlockRV):
+ for iter_var in sch.get(block).iter_vars:
+ if _iter_kind(iter_var) == "R":
+ return True
+ return False
+
+ return BlockInfo(
+ name=sch.get(block).name_hint,
+ iters=[
+ IterInfo(
+ kind=_iter_kind(iter_var),
+ var=iter_var.var,
+ dom=iter_var.dom.extent,
+ loop_rv=loop_rv,
+ )
+ for loop_rv, iter_var in zip(sch.get_loops(block),
sch.get(block).iter_vars)
+ ],
+ block_rv=block,
+ reduction_block=_is_reduction_block(block),
+ )
+
+
def get_reduction_blocks(sch, blocks) -> bool:
# Get the main computation block
def is_reduction(block: BlockRV) -> bool:
@@ -914,17 +940,19 @@ class Matmul(GPUScheduleRule):
storage_align=True,
inner_x=False,
)
- elif target.kind.name == "opencl" and "android" in str(target.host):
+ elif target.kind.name == "opencl" and (
+ ("android" in str(target.host)) or ("windows" in str(target.host))
+ ):
return Matmul.Config(
- block_size_x=8,
- block_size_y=16,
+ block_size_x=32,
+ block_size_y=8,
vthread_x=1,
vthread_y=1,
micro_size_x=8,
micro_size_y=2,
micro_size_k=16,
vector_size=8,
- unroll=64,
+ unroll=4,
use_shared=False,
storage_align=False,
inner_x=True,
@@ -941,6 +969,7 @@ class Matmul(GPUScheduleRule):
if not isinstance(func, tir.PrimFunc) or not
self.is_target_available(target):
return None
sch = tir.Schedule(func)
+ config = self.get_configs(target)
root_block = analysis.get_root_block(sch)
blocks = sch.get_child_blocks(root_block)
@@ -953,9 +982,22 @@ class Matmul(GPUScheduleRule):
index_maps = get_index_map(block_stmt)
if index_maps is None:
return None
- matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
+
+ main_block_info = get_block_info(sch, main_block)
+ iter_infos = main_block_info.iters
+
+    # Checks if it's an inner reduction by getting the last matrix's inner
Index
+ def is_inner_reduction(block_stmt, iter_infos):
+ end_it = block_stmt.reads[-1].region[-1].min
+ return {it.var: it.kind for it in iter_infos}.get(end_it, "O") ==
"R"
+
+ if target.kind.name == "opencl" and not is_inner_reduction(block_stmt,
iter_infos):
+ ret = self.sch_outer_reduction(sch, config, main_block, blocks)
+ if ret is not None:
+ return ret
# Step 0. Normalize generic matmul to C[S, I, J] += A[S, I, K] * B[S,
J, K]
+ matmul_index_map, a_index_map, b_index_map, c_index_map = index_maps
block = sch.reindex(main_block, ("read", 0))
sch.transform_layout(block, ("write", 0), a_index_map)
block = sch.reindex(main_block, ("read", 1))
@@ -994,10 +1036,7 @@ class Matmul(GPUScheduleRule):
except: # pylint: disable=bare-except
pass
- # Step 2. Get schedule config.
- config = self.get_configs(target)
-
- # Step 3. Schedule matmul
+ # Step 2. Schedule matmul
y_kernel_size = config.vthread_y * config.block_size_y *
config.micro_size_y
x_kernel_size = config.vthread_x * config.block_size_x *
config.micro_size_x
if config.inner_x:
@@ -1075,3 +1114,88 @@ class Matmul(GPUScheduleRule):
sch.decompose_reduction(main_block, ko)
return sch
+
+ def sch_outer_reduction(
+ self,
+ sch: tir.Schedule,
+ config: Config,
+ reduction_block: tir.schedule.BlockRV,
+ blocks: List[tir.schedule.BlockRV],
+ ) -> Optional[tir.Schedule]:
+ reduction_loops = sch.get_loops(reduction_block)
+ if not len(reduction_loops) == 4:
+ return None
+
+ mb, ms, n, k = reduction_loops
+ if not (
+ isinstance(sch.get(n).extent, tir.IntImm)
+ and isinstance(sch.get(mb).extent, tir.IntImm)
+ and isinstance(sch.get(ms).extent, tir.Var)
+ ):
+ return None
+
+ Threads_X, Threads_Y, VecSize, Unroll_M = (
+ config.block_size_x,
+ config.block_size_y,
+ config.vector_size,
+ config.unroll,
+ )
+
+ is_dequant_block = len(blocks) > 1
+ if is_dequant_block:
+ compute_block, dequant_block, matmul_block = blocks
+ sch.compute_inline(compute_block)
+ else:
+ (matmul_block,) = blocks
+
+ m = sch.fuse(mb, ms)
+
+ sch.pad_einsum(matmul_block, [1, Threads_Y * Unroll_M, Threads_X *
VecSize, 1])
+
+ rmat_block, wmat_block = (
+ sch.get_producers(matmul_block)[0],
+ sch.get_consumers(matmul_block)[0],
+ )
+ mo, mi, mu = sch.split(m, [None, Threads_Y, Unroll_M])
+ no, ni, nv = sch.split(n, [None, Threads_X, VecSize])
+ k0, k1, k2, k3 = sch.split(k, [None, (Threads_X * VecSize) // 32, 4,
8])
+ sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
+
+ sch.compute_at(rmat_block, k0)
+ if is_dequant_block:
+ sch.compute_at(dequant_block, k3)
+ sch.reverse_compute_at(wmat_block, mi)
+ sch.set_scope(rmat_block, 0, "shared")
+ sch.set_scope(matmul_block, 0, "local")
+ if is_dequant_block:
+ sch.set_scope(dequant_block, 0, "local")
+
+ sch.bind(mo, "blockIdx.y")
+ sch.bind(no, "blockIdx.x")
+ sch.bind(mi, "threadIdx.y")
+ sch.bind(ni, "threadIdx.x")
+ sch.vectorize(sch.get_loops(matmul_block)[-1])
+ if is_dequant_block:
+ sch.vectorize(sch.get_loops(dequant_block)[-1])
+
+ # Co-operative Memory Fetch
+ ro, rv = sch.split(sch.get_loops(rmat_block)[-1], [None, VecSize])
+ sch.bind(ro, "threadIdx.x")
+ sch.vectorize(rv)
+
+ wv = sch.get_loops(wmat_block)[-1]
+ sch.vectorize(wv)
+
+ # Scale and Quant Cache
+ if is_dequant_block:
+ qb = sch.cache_read(dequant_block, 0, "local")
+ sb = sch.cache_read(dequant_block, 1, "local")
+ sch.compute_at(sb, k1)
+ sch.compute_at(qb, k2)
+ sch.set_scope(sb, 0, "local")
+ sch.set_scope(qb, 0, "local")
+ sch.vectorize(sch.get_loops(qb)[-1])
+ sch.vectorize(sch.get_loops(sb)[-1])
+
+ sch.decompose_reduction(matmul_block, k0)
+ return sch
diff --git a/tests/python/dlight/test_gpu_matmul.py
b/tests/python/dlight/test_gpu_matmul.py
index ca32c286ab..4cef7f1c27 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -634,42 +634,166 @@ class TestMatmulAndroid(AndroidBeforeAfter):
inp0 = T.match_buffer(var_inp0, (T.int64(1), m, T.int64(4096)))
matmul = T.match_buffer(var_matmul, (T.int64(1), m, T.int64(4096)))
# with T.block("root"):
- matmul_reindex_pad_local = T.alloc_buffer((T.int64(1), (m +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
- for ax0_ax1_0_fused in T.thread_binding((m + T.int64(31)) //
T.int64(32), thread="blockIdx.y"):
- for ax2_0 in T.thread_binding(T.int64(64), thread="blockIdx.x"):
- for ax1_1 in T.thread_binding(T.int64(1), thread="vthread.y"):
- for ax2_1 in T.thread_binding(T.int64(1),
thread="vthread.x"):
- for ax1_2 in T.thread_binding(T.int64(16),
thread="threadIdx.y"):
- for ax2_2 in T.thread_binding(T.int64(8),
thread="threadIdx.x", annotations={"pragma_auto_unroll_max_step": 64,
"pragma_unroll_explicit": 1}):
- for ax1_3_init, ax2_3_0_init in
T.grid(T.int64(2), T.int64(1)):
- for ax2_3_1_init in
T.vectorized(T.int64(8)):
- with T.block("matmul_init"):
+ inp0_pad_shared = T.alloc_buffer((T.int64(1), (m + T.int64(31)) //
T.int64(32) * T.int64(32), T.int64(4096)), scope="shared")
+ matmul_pad_local = T.alloc_buffer((T.int64(1), (m + T.int64(31)) //
T.int64(32) * T.int64(32), T.int64(4096)), scope="local")
+ for i2_0 in T.thread_binding(T.int64(16), thread="blockIdx.x"):
+ for i0_i1_fused_0 in T.thread_binding((m + T.int64(31)) //
T.int64(32), thread="blockIdx.y"):
+ for i2_1 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for i0_i1_fused_2_init in range(T.int64(4)):
+ for i2_2_init in T.vectorized(T.int64(8)):
+ with T.block("matmul_init"):
+ v_i0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v_i1 = T.axis.spatial((m + T.int64(31)) //
T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 *
T.int64(4) + i0_i1_fused_2_init)
+ v_i2 = T.axis.spatial(T.int64(4096), i2_0
* T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
+ T.reads()
+ T.writes(matmul_pad_local[v_i0, v_i1,
v_i2])
+ matmul_pad_local[v_i0, v_i1, v_i2] =
T.float32(0)
+ for k_0 in range(T.int64(16)):
+ for ax0 in range(T.int64(4)):
+ for ax1_0 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ for ax1_1 in T.vectorized(T.int64(8)):
+ with T.block("inp0_pad"):
v0 = T.axis.spatial(T.int64(1),
T.int64(0))
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) +
ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3_init)
- v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0_init *
T.int64(8) + ax2_3_1_init)
- T.reads()
-
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-
matmul_reindex_pad_local[T.int64(0), v1, v2] = T.float32(0)
- for ax3_0, ax3_1, ax1_3, ax2_3_0 in
T.grid(T.int64(256), T.int64(16), T.int64(2), T.int64(1)):
- for ax2_3_1 in T.vectorized(T.int64(8)):
- with T.block("matmul_update"):
+ v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) +
i0_i1_fused_1 * T.int64(4) + ax0)
+ v2 = T.axis.spatial(T.int64(4096),
k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
+ T.reads(inp0[v0, v1, v2])
+ T.writes(inp0_pad_shared[v0, v1,
v2])
+ inp0_pad_shared[v0, v1, v2] =
T.if_then_else(v1 < m, inp0[v0, v1, v2], T.float32(0))
+ for k_1, k_2, k_3, i0_i1_fused_2 in
T.grid(T.int64(8), T.int64(4), T.int64(8), T.int64(4)):
+ for i2_2 in T.vectorized(T.int64(8)):
+ with T.block("matmul_update"):
+ v_i0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v_i1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) +
i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
+ v_i2 = T.axis.spatial(T.int64(4096),
i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
+ v_k = T.axis.reduce(T.int64(4096), k_0
* T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
+ T.reads(matmul_pad_local[v_i0, v_i1,
v_i2], inp0_pad_shared[v_i0, v_i1, v_k], inp1[v_k, v_i2])
+ T.writes(matmul_pad_local[v_i0, v_i1,
v_i2])
+ matmul_pad_local[v_i0, v_i1, v_i2] =
matmul_pad_local[v_i0, v_i1, v_i2] + inp0_pad_shared[v_i0, v_i1, v_k] *
inp1[v_k, v_i2]
+ for ax0 in range(T.int64(4)):
+ for ax1 in T.vectorized(T.int64(8)):
+ with T.block("matmul_pad"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(m, i0_i1_fused_0 *
T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
+ v2 = T.axis.spatial(T.int64(4096), i2_0 *
T.int64(256) + i2_1 * T.int64(8) + ax1)
+ T.where((i0_i1_fused_0 - (m + T.int64(31))
// T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 *
T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < m)
+ T.reads(matmul_pad_local[v0, v1, v2])
+ T.writes(matmul[v0, v1, v2])
+ matmul[v0, v1, v2] = matmul_pad_local[v0,
v1, v2]
+
+
+class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
+ # fmt: off
+ @T.prim_func
+ def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"),
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260:
T.handle, p_output0: T.handle):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ seq_len = T.int64()
+ rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len,
T.int64(4096)), "float16")
+ matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len,
T.int64(12288)), "float16")
+ # with T.block("root"):
+ compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
+ dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096),
T.int64(12288)), "float16")
+ for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
+ with T.block("compute"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(lv840[v_i0 // T.int64(8), v_i1])
+ T.writes(compute[v_i0, v_i1])
+ compute[v_i0, v_i1] = T.Cast("float16",
T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32",
v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
+ for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
+ with T.block("dequantize"):
+ v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
+ T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
+ T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
+ dequantize_intermediate_intermediate[v_i0, v_i1] =
(compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
+ for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288),
T.int64(4096)):
+ with T.block("matmul"):
+ v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
+ T.reads(rms_norm260[v_i0, v_i1, v_k],
dequantize_intermediate_intermediate[v_k, v_i2])
+ T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
+ with T.init():
+ matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
+ matmul_intermediate[v_i0, v_i1, v_i2] =
matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] *
dequantize_intermediate_intermediate[v_k, v_i2]
+
+ @T.prim_func
+ def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"),
lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260:
T.handle, p_output0: T.handle):
+ T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
+ seq_len = T.int64()
+ rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len,
T.int64(4096)), "float16")
+ matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len,
T.int64(12288)), "float16")
+ # with T.block("root"):
+ dequantize_intermediate_intermediate_local =
T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
+ rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16",
scope="shared")
+ matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16",
scope="local")
+ lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32",
scope="local")
+ lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)),
"float16", scope="local")
+ for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
+ for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) //
T.int64(32), thread="blockIdx.y"):
+ for i2_1 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ for i0_i1_fused_1 in T.thread_binding(T.int64(8),
thread="threadIdx.y"):
+ for i0_i1_fused_2_init in range(T.int64(4)):
+ for i2_2_init in T.vectorized(T.int64(8)):
+ with T.block("matmul_init"):
+ v_i0 = T.axis.spatial(T.int64(1),
T.int64(0))
+ v_i1 = T.axis.spatial((seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) +
i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2_init)
+ v_i2 = T.axis.spatial(T.int64(12288), i2_0
* T.int64(256) + i2_1 * T.int64(8) + i2_2_init)
+ T.reads()
+
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
+ matmul_intermediate_pad_local[v_i0, v_i1,
v_i2] = T.float16(0)
+ for k_0 in range(T.int64(16)):
+ for ax0 in range(T.int64(4)):
+ for ax1_0 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
+ for ax1_1 in T.vectorized(T.int64(8)):
+ with T.block("rms_norm260_pad"):
v0 = T.axis.spatial(T.int64(1),
T.int64(0))
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) +
ax1_1 * T.int64(32) + ax1_2 * T.int64(2) + ax1_3)
- v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_1 * T.int64(64) + ax2_2 * T.int64(8) + ax2_3_0 *
T.int64(8) + ax2_3_1)
- v3 = T.axis.reduce(T.int64(4096),
ax3_0 * T.int64(16) + ax3_1)
-
T.reads(matmul_reindex_pad_local[T.int64(0), v1, v2], inp0[T.int64(0), v1, v3],
inp1[v3, v2])
-
T.writes(matmul_reindex_pad_local[T.int64(0), v1, v2])
-
matmul_reindex_pad_local[T.int64(0), v1, v2] =
matmul_reindex_pad_local[T.int64(0), v1, v2] + T.if_then_else(v1 < m,
inp0[T.int64(0), v1, v3], T.float32(0)) * inp1[v3, v2]
- for ax0, ax1, ax2_0_1 in T.grid(T.int64(1),
T.int64(2), T.int64(1)):
- for ax2_1_1 in T.vectorized(T.int64(8)):
- with
T.block("matmul_reindex_pad_local"):
- v0 = T.axis.spatial(T.int64(1),
ax0)
- v1 = T.axis.spatial((m +
T.int64(31)) // T.int64(32) * T.int64(32), ax0_ax1_0_fused * T.int64(32) +
ax1_2 * T.int64(2) + ax1)
- v2 = T.axis.spatial(T.int64(4096),
ax2_0 * T.int64(64) + ax2_2 * T.int64(8) + ax2_0_1 * T.int64(8) + ax2_1_1)
- T.where(ax0_ax1_0_fused *
T.int64(32) + ax1_2 * T.int64(2) + ax1 < m)
-
T.reads(matmul_reindex_pad_local[v0, v1, v2])
- T.writes(matmul[T.int64(0), v1,
v2])
- matmul[T.int64(0), v1, v2] =
matmul_reindex_pad_local[v0, v1, v2]
+ v1 = T.axis.spatial((seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) +
i0_i1_fused_1 * T.int64(4) + ax0)
+ v2 = T.axis.spatial(T.int64(4096),
k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
+ T.reads(rms_norm260[v0, v1, v2])
+
T.writes(rms_norm260_pad_shared[v0, v1, v2])
+ rms_norm260_pad_shared[v0, v1, v2]
= T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
+ for k_1 in range(T.int64(8)):
+ for ax0 in T.vectorized(T.int64(8)):
+ with T.block("lv841_local"):
+ v0 = T.axis.spatial(T.int64(128), k_0
* T.int64(8) + k_1)
+ v1 = T.axis.spatial(T.int64(12288),
i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+ T.reads(lv841[v0, v1])
+ T.writes(lv841_local[v0, v1])
+ lv841_local[v0, v1] = lv841[v0, v1]
+ for k_2 in range(T.int64(4)):
+ for ax0 in T.vectorized(T.int64(8)):
+ with T.block("lv840_local"):
+ v0 = T.axis.spatial(T.int64(512),
k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
+ v1 =
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+ T.reads(lv840[v0, v1])
+ T.writes(lv840_local[v0, v1])
+ lv840_local[v0, v1] = lv840[v0, v1]
+ for k_3 in range(T.int64(8)):
+ for ax0 in T.vectorized(T.int64(8)):
+ with T.block("dequantize"):
+ v_i0 =
T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 *
T.int64(8) + k_3)
+ v_i1 =
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
+ T.reads(lv840_local[v_i0 //
T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
+
T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
+
dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16",
T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1],
T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) -
T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
+ for i0_i1_fused_2 in range(T.int64(4)):
+ for i2_2 in
T.vectorized(T.int64(8)):
+ with T.block("matmul_update"):
+ v_i0 =
T.axis.spatial(T.int64(1), T.int64(0))
+ v_i1 =
T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32),
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
+ v_i2 =
T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
+ v_k =
T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 *
T.int64(8) + k_3)
+
T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2],
rms_norm260_pad_shared[v_i0, v_i1, v_k],
dequantize_intermediate_intermediate_local[v_k, v_i2])
+
T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
+
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] =
matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0,
v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
+ for ax0 in range(T.int64(4)):
+ for ax1 in T.vectorized(T.int64(8)):
+ with T.block("matmul_intermediate_pad"):
+ v0 = T.axis.spatial(T.int64(1), T.int64(0))
+ v1 = T.axis.spatial(seq_len, i0_i1_fused_0
* T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
+ v2 = T.axis.spatial(T.int64(12288), i2_0 *
T.int64(256) + i2_1 * T.int64(8) + ax1)
+ T.where((i0_i1_fused_0 - (seq_len +
T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and
i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
+ T.reads(matmul_intermediate_pad_local[v0,
v1, v2])
+ T.writes(matmul_intermediate[v0, v1, v2])
+ matmul_intermediate[v0, v1, v2] =
matmul_intermediate_pad_local[v0, v1, v2]
# fmt: on