This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new e5f85c0e32 [DLIGHT][ADRENO] Fix for opencl adreno matmul schedule (#17259)
e5f85c0e32 is described below
commit e5f85c0e32046b6b1bdc5bd1a2485c645df4e730
Author: krishnaraj36 <[email protected]>
AuthorDate: Sat Aug 10 21:55:51 2024 +0530
[DLIGHT][ADRENO] Fix for opencl adreno matmul schedule (#17259)
Fixed the matmul schedule for the case of epilogue blocks (e.g. a bias add fused after the matmul).
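For context, a minimal sketch of why the previous block handling failed on such
workloads (block names are illustrative, mirroring the test below):

    # The rule collects the blocks of the fused function in order; a fused
    # epilogue such as a bias add appends a fourth block after the matmul.
    blocks = ["compute", "dequantize", "matmul", "T_add"]
    try:
        compute_block, dequant_block, matmul_block = blocks  # old fixed-arity unpack
    except ValueError as err:
        print(err)  # too many values to unpack (expected 3)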
---
python/tvm/dlight/gpu/matmul.py | 50 ++++++++++++++-----
tests/python/dlight/test_gpu_matmul.py | 89 ++++++++++++++++++----------------
2 files changed, 85 insertions(+), 54 deletions(-)
diff --git a/python/tvm/dlight/gpu/matmul.py b/python/tvm/dlight/gpu/matmul.py
index 25cc649b44..5fb8e2469d 100644
--- a/python/tvm/dlight/gpu/matmul.py
+++ b/python/tvm/dlight/gpu/matmul.py
@@ -941,7 +941,7 @@ class Matmul(GPUScheduleRule):
                 inner_x=False,
             )
         elif target.kind.name == "opencl" and (
-            ("android" in str(target.host)) or ("windows" in str(target.host))
+            ("android" in str(target.host)) or ("adreno" in str(target.attrs))
         ):
             return Matmul.Config(
                 block_size_x=32,
@@ -991,7 +991,10 @@ class Matmul(GPUScheduleRule):
             end_it = block_stmt.reads[-1].region[-1].min
             return {it.var: it.kind for it in iter_infos}.get(end_it, "O") == "R"
 
-        if target.kind.name == "opencl" and not is_inner_reduction(block_stmt, iter_infos):
+        if (
+            target.kind.name == "opencl"
+            and (("android" in str(target.host)) or ("adreno" in str(target.attrs)))
+        ) and not is_inner_reduction(block_stmt, iter_infos):
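+            # Only Adreno OpenCL targets (Android host or "adreno" in the target
+            # attrs) take the outer-reduction path below.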
             ret = self.sch_outer_reduction(sch, config, main_block, blocks)
             if ret is not None:
                 return ret
@@ -1122,6 +1125,16 @@ class Matmul(GPUScheduleRule):
         reduction_block: tir.schedule.BlockRV,
         blocks: List[tir.schedule.BlockRV],
     ) -> Optional[tir.Schedule]:
+
+ """Get vectorization factor"""
+
+ def get_max_factor(n, factors):
+ factors = sorted(factors, reverse=True)
+ for factor in factors:
+ if n % factor == 0:
+ return factor
+ return 1
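+        # e.g. get_max_factor(12, [1, 2, 4, 8]) == 4, since 8 does not divide 12;
+        # get_max_factor(7, [2, 4, 8]) == 1, as no candidate divides 7.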
+
         reduction_loops = sch.get_loops(reduction_block)
         if not len(reduction_loops) == 4:
             return None
@@ -1140,13 +1153,17 @@ class Matmul(GPUScheduleRule):
             config.vector_size,
             config.unroll,
         )
-
-        is_dequant_block = len(blocks) > 1
-        if is_dequant_block:
-            compute_block, dequant_block, matmul_block = blocks
-            sch.compute_inline(compute_block)
-        else:
-            (matmul_block,) = blocks
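+        # Use the widest vector lane count in {1, 2, 4, 8} that divides the
+        # columns handled per thread.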
+        VecSize = min(get_max_factor(sch.get(n).extent // Threads_X, [1, 2, 4, 8]), VecSize)
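+        # Classify the block list: a trailing non-matmul block is an epilogue
+        # (e.g. bias add), dequantize blocks are scheduled explicitly, and the
+        # rest are inlined.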
+        dequant_block = None
+        matmul_block = reduction_block
+        epilogue_block = None
+        if blocks[-1] is not matmul_block:
+            epilogue_block = blocks[-1]
+        for blk in blocks[:-1]:
+            if "dequantize" in sch.get(blk).name_hint:
+                dequant_block = blk
+            elif blk is not matmul_block:
+                sch.compute_inline(blk)
 
         m = sch.fuse(mb, ms)
@@ -1162,12 +1179,13 @@ class Matmul(GPUScheduleRule):
         sch.reorder(no, mo, ni, mi, k0, k1, k2, k3, mu, nv)
 
         sch.compute_at(rmat_block, k0)
-        if is_dequant_block:
+        if dequant_block is not None:
             sch.compute_at(dequant_block, k3)
         sch.reverse_compute_at(wmat_block, mi)
         sch.set_scope(rmat_block, 0, "shared")
         sch.set_scope(matmul_block, 0, "local")
-        if is_dequant_block:
+
+        if dequant_block is not None:
             sch.set_scope(dequant_block, 0, "local")
 
         sch.bind(mo, "blockIdx.y")
@@ -1175,7 +1193,7 @@ class Matmul(GPUScheduleRule):
         sch.bind(mi, "threadIdx.y")
         sch.bind(ni, "threadIdx.x")
         sch.vectorize(sch.get_loops(matmul_block)[-1])
-        if is_dequant_block:
+        if dequant_block is not None:
             sch.vectorize(sch.get_loops(dequant_block)[-1])
 
         # Co-operative Memory Fetch
@@ -1187,7 +1205,7 @@ class Matmul(GPUScheduleRule):
             sch.vectorize(wv)
 
         # Scale and Quant Cache
-        if is_dequant_block:
+        if dequant_block is not None:
             qb = sch.cache_read(dequant_block, 0, "local")
             sb = sch.cache_read(dequant_block, 1, "local")
             sch.compute_at(sb, k1)
@@ -1197,5 +1215,11 @@ class Matmul(GPUScheduleRule):
             sch.vectorize(sch.get_loops(qb)[-1])
             sch.vectorize(sch.get_loops(sb)[-1])
 
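+        # Attach the epilogue (e.g. bias add) at the output tile, fold the matmul
+        # write-back cache block into it, and vectorize its innermost loop.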
+        if epilogue_block is not None:
+            sch.reverse_compute_at(epilogue_block, mi, preserve_unit_loops=True)
+            sch.set_scope(wmat_block, 0, "local")
+            sch.compute_inline(wmat_block)
+            sch.vectorize(sch.get_loops(epilogue_block)[-1])
+
         sch.decompose_reduction(matmul_block, k0)
         return sch
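For reference, a hedged sketch of how a rule like this is typically applied via
dlight's default-schedule pass (target strings are illustrative of an Adreno
Android device; `mod` is assumed to be an existing TIR IRModule):

    import tvm
    from tvm import dlight as dl

    target = tvm.target.Target("opencl", host="llvm -mtriple=aarch64-linux-android")
    with target:
        mod = dl.ApplyDefaultSchedule(dl.gpu.Matmul())(mod)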
diff --git a/tests/python/dlight/test_gpu_matmul.py b/tests/python/dlight/test_gpu_matmul.py
index 4cef7f1c27..dc5276e62a 100644
--- a/tests/python/dlight/test_gpu_matmul.py
+++ b/tests/python/dlight/test_gpu_matmul.py
@@ -685,47 +685,54 @@ class TestMatmulAndroid(AndroidBeforeAfter):
 class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
     # fmt: off
     @T.prim_func
-    def before(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
+    def before(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
         T.func_attr({"tir.noalias": T.bool(True)})
         seq_len = T.int64()
-        rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len, T.int64(4096)), "float16")
-        matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
+        rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len, T.int64(4096)), "float16")
+        T_add_intermediate_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len, T.int64(12288)), "float16")
         # with T.block("root"):
         compute = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
         dequantize_intermediate_intermediate = T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16")
+        matmul_intermediate = T.alloc_buffer((T.int64(1), seq_len, T.int64(12288)), "float16")
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("compute"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(lv840[v_i0 // T.int64(8), v_i1])
+                T.reads(lv452[v_i0 // T.int64(8), v_i1])
                 T.writes(compute[v_i0, v_i1])
-                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv840[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
+                compute[v_i0, v_i1] = T.Cast("float16", T.bitwise_and(T.shift_right(lv452[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15)))
         for i0, i1 in T.grid(T.int64(4096), T.int64(12288)):
             with T.block("dequantize"):
                 v_i0, v_i1 = T.axis.remap("SS", [i0, i1])
-                T.reads(compute[v_i0, v_i1], lv841[v_i0 // T.int64(32), v_i1])
+                T.reads(compute[v_i0, v_i1], lv453[v_i0 // T.int64(32), v_i1])
                 T.writes(dequantize_intermediate_intermediate[v_i0, v_i1])
-                dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv841[v_i0 // T.int64(32), v_i1]
+                dequantize_intermediate_intermediate[v_i0, v_i1] = (compute[v_i0, v_i1] - T.float16(7)) * lv453[v_i0 // T.int64(32), v_i1]
         for i0, i1, i2, k in T.grid(T.int64(1), seq_len, T.int64(12288), T.int64(4096)):
             with T.block("matmul"):
                 v_i0, v_i1, v_i2, v_k = T.axis.remap("SSSR", [i0, i1, i2, k])
-                T.reads(rms_norm260[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
+                T.reads(rms_norm130[v_i0, v_i1, v_k], dequantize_intermediate_intermediate[v_k, v_i2])
                 T.writes(matmul_intermediate[v_i0, v_i1, v_i2])
                 with T.init():
                     matmul_intermediate[v_i0, v_i1, v_i2] = T.float16(0)
-                matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm260[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
+                matmul_intermediate[v_i0, v_i1, v_i2] = matmul_intermediate[v_i0, v_i1, v_i2] + rms_norm130[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate[v_k, v_i2]
+        for ax0, ax1, ax2 in T.grid(T.int64(1), seq_len, T.int64(12288)):
+            with T.block("T_add"):
+                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
+                T.reads(matmul_intermediate[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
+                T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
 
     @T.prim_func
-    def expected(lv840: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv841: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm260: T.handle, p_output0: T.handle):
+    def expected(lv452: T.Buffer((T.int64(512), T.int64(12288)), "uint32"), lv453: T.Buffer((T.int64(128), T.int64(12288)), "float16"), p_rms_norm130: T.handle, transformer_h_0_attn_c_attn_bias3: T.Buffer((T.int64(12288),), "float16"), p_output0: T.handle):
T.func_attr({"global_symbol": "main", "tir.is_scheduled": 1,
"tir.noalias": T.bool(True)})
seq_len = T.int64()
- rms_norm260 = T.match_buffer(p_rms_norm260, (T.int64(1), seq_len,
T.int64(4096)), "float16")
- matmul_intermediate = T.match_buffer(p_output0, (T.int64(1), seq_len,
T.int64(12288)), "float16")
+ rms_norm130 = T.match_buffer(p_rms_norm130, (T.int64(1), seq_len,
T.int64(4096)), "float16")
+ T_add_intermediate_intermediate = T.match_buffer(p_output0,
(T.int64(1), seq_len, T.int64(12288)), "float16")
# with T.block("root"):
dequantize_intermediate_intermediate_local =
T.alloc_buffer((T.int64(4096), T.int64(12288)), "float16", scope="local")
- rms_norm260_pad_shared = T.alloc_buffer((T.int64(1), (seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16",
scope="shared")
+ rms_norm130_pad_shared = T.alloc_buffer((T.int64(1), (seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(4096)), "float16",
scope="shared")
matmul_intermediate_pad_local = T.alloc_buffer((T.int64(1), (seq_len +
T.int64(31)) // T.int64(32) * T.int64(32), T.int64(12288)), "float16",
scope="local")
- lv840_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32",
scope="local")
- lv841_local = T.alloc_buffer((T.int64(128), T.int64(12288)),
"float16", scope="local")
+ lv452_local = T.alloc_buffer((T.int64(512), T.int64(12288)), "uint32",
scope="local")
+ lv453_local = T.alloc_buffer((T.int64(128), T.int64(12288)),
"float16", scope="local")
for i2_0 in T.thread_binding(T.int64(48), thread="blockIdx.x"):
for i0_i1_fused_0 in T.thread_binding((seq_len + T.int64(31)) //
T.int64(32), thread="blockIdx.y"):
for i2_1 in T.thread_binding(T.int64(32),
thread="threadIdx.x"):
@@ -743,37 +750,37 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                             for ax0 in range(T.int64(4)):
                                 for ax1_0 in T.thread_binding(T.int64(32), thread="threadIdx.x"):
                                     for ax1_1 in T.vectorized(T.int64(8)):
-                                        with T.block("rms_norm260_pad"):
+                                        with T.block("rms_norm130_pad"):
                                             v0 = T.axis.spatial(T.int64(1), T.int64(0))
                                             v1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
                                             v2 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + ax1_0 * T.int64(8) + ax1_1)
-                                            T.reads(rms_norm260[v0, v1, v2])
-                                            T.writes(rms_norm260_pad_shared[v0, v1, v2])
-                                            rms_norm260_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm260[v0, v1, v2], T.float16(0))
+                                            T.reads(rms_norm130[v0, v1, v2])
+                                            T.writes(rms_norm130_pad_shared[v0, v1, v2])
+                                            rms_norm130_pad_shared[v0, v1, v2] = T.if_then_else(v1 < seq_len, rms_norm130[v0, v1, v2], T.float16(0))
                             for k_1 in range(T.int64(8)):
                                 for ax0 in T.vectorized(T.int64(8)):
-                                    with T.block("lv841_local"):
+                                    with T.block("lv453_local"):
                                         v0 = T.axis.spatial(T.int64(128), k_0 * T.int64(8) + k_1)
                                         v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                        T.reads(lv841[v0, v1])
-                                        T.writes(lv841_local[v0, v1])
-                                        lv841_local[v0, v1] = lv841[v0, v1]
+                                        T.reads(lv453[v0, v1])
+                                        T.writes(lv453_local[v0, v1])
+                                        lv453_local[v0, v1] = lv453[v0, v1]
                                 for k_2 in range(T.int64(4)):
                                     for ax0 in T.vectorized(T.int64(8)):
-                                        with T.block("lv840_local"):
+                                        with T.block("lv452_local"):
                                             v0 = T.axis.spatial(T.int64(512), k_0 * T.int64(32) + k_1 * T.int64(4) + k_2)
                                             v1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                            T.reads(lv840[v0, v1])
-                                            T.writes(lv840_local[v0, v1])
-                                            lv840_local[v0, v1] = lv840[v0, v1]
+                                            T.reads(lv452[v0, v1])
+                                            T.writes(lv452_local[v0, v1])
+                                            lv452_local[v0, v1] = lv452[v0, v1]
                                     for k_3 in range(T.int64(8)):
                                         for ax0 in T.vectorized(T.int64(8)):
                                             with T.block("dequantize"):
                                                 v_i0 = T.axis.spatial(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
                                                 v_i1 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax0)
-                                                T.reads(lv840_local[v_i0 // T.int64(8), v_i1], lv841_local[v_i0 // T.int64(32), v_i1])
+                                                T.reads(lv452_local[v_i0 // T.int64(8), v_i1], lv453_local[v_i0 // T.int64(32), v_i1])
                                                 T.writes(dequantize_intermediate_intermediate_local[v_i0, v_i1])
-                                                dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv840_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv841_local[v_i0 // T.int64(32), v_i1]
+                                                dequantize_intermediate_intermediate_local[v_i0, v_i1] = (T.Cast("float16", T.bitwise_and(T.shift_right(lv452_local[v_i0 // T.int64(8), v_i1], T.Cast("uint32", v_i0 % T.int64(8) * T.int64(4))), T.uint32(15))) - T.float16(7)) * lv453_local[v_i0 // T.int64(32), v_i1]
                                         for i0_i1_fused_2 in range(T.int64(4)):
                                             for i2_2 in T.vectorized(T.int64(8)):
                                                 with T.block("matmul_update"):
@@ -781,19 +788,19 @@ class TestFusedDequantMatmulAndroid(AndroidBeforeAfter):
                                                     v_i1 = T.axis.spatial((seq_len + T.int64(31)) // T.int64(32) * T.int64(32), i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + i0_i1_fused_2)
                                                     v_i2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + i2_2)
                                                     v_k = T.axis.reduce(T.int64(4096), k_0 * T.int64(256) + k_1 * T.int64(32) + k_2 * T.int64(8) + k_3)
-                                                    T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm260_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
+                                                    T.reads(matmul_intermediate_pad_local[v_i0, v_i1, v_i2], rms_norm130_pad_shared[v_i0, v_i1, v_k], dequantize_intermediate_intermediate_local[v_k, v_i2])
                                                     T.writes(matmul_intermediate_pad_local[v_i0, v_i1, v_i2])
-                                                    matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm260_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
-                        for ax0 in range(T.int64(4)):
-                            for ax1 in T.vectorized(T.int64(8)):
-                                with T.block("matmul_intermediate_pad"):
-                                    v0 = T.axis.spatial(T.int64(1), T.int64(0))
-                                    v1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0)
-                                    v2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax1)
-                                    T.where((i0_i1_fused_0 - (seq_len + T.int64(31)) // T.int64(32) < T.int64(0) or i0_i1_fused_0 == T.int64(0)) and i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax0 < seq_len)
-                                    T.reads(matmul_intermediate_pad_local[v0, v1, v2])
-                                    T.writes(matmul_intermediate[v0, v1, v2])
-                                    matmul_intermediate[v0, v1, v2] = matmul_intermediate_pad_local[v0, v1, v2]
+                                                    matmul_intermediate_pad_local[v_i0, v_i1, v_i2] = matmul_intermediate_pad_local[v_i0, v_i1, v_i2] + rms_norm130_pad_shared[v_i0, v_i1, v_k] * dequantize_intermediate_intermediate_local[v_k, v_i2]
+                        for ax0, ax1 in T.grid(T.int64(1), T.int64(4)):
+                            for ax2 in T.vectorized(T.int64(8)):
+                                with T.block("T_add"):
+                                    v_ax0 = T.axis.spatial(T.int64(1), ax0)
+                                    v_ax1 = T.axis.spatial(seq_len, i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1)
+                                    v_ax2 = T.axis.spatial(T.int64(12288), i2_0 * T.int64(256) + i2_1 * T.int64(8) + ax2)
+                                    T.where(i0_i1_fused_0 * T.int64(32) + i0_i1_fused_1 * T.int64(4) + ax1 < seq_len)
+                                    T.reads(matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2], transformer_h_0_attn_c_attn_bias3[v_ax2])
+                                    T.writes(T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2])
+                                    T_add_intermediate_intermediate[v_ax0, v_ax1, v_ax2] = matmul_intermediate_pad_local[v_ax0, v_ax1, v_ax2] + transformer_h_0_attn_c_attn_bias3[v_ax2]
 
     # fmt: on