This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
     new 5bbe1aba6d [Dlight] LowBatchGemv rule only apply to function with spatial symbolic var (#16678)
5bbe1aba6d is described below
commit 5bbe1aba6d0ca0f7422299a7b34c9e1a4181288d
Author: Hongyi Jin <[email protected]>
AuthorDate: Sat Mar 9 08:12:14 2024 -0500
[Dlight] LowBatchGemv rule only apply to function with spatial symbolic var (#16678)
* squash
* fix
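
For context, a minimal TVMScript sketch (illustrative only, not part of this commit) of the case the rule still targets: a GEMV whose single symbolic variable is a spatial (data-parallel) batch axis. The function name and shapes below are hypothetical.

    # Illustrative sketch, not from the commit: the single symbolic var `n`
    # is bound to a spatial ("S") axis, so the tightened is_gemv check in
    # this patch still accepts it; a symbolic reduction axis (as in the new
    # test below) is now rejected. Name and shapes are hypothetical.
    from tvm.script import tir as T

    @T.prim_func(private=True)
    def spatial_batch_gemv(var_A: T.handle, B: T.Buffer((T.int64(4096), T.int64(4096)), "float32"), var_C: T.handle):
        n = T.int64()  # symbolic batch size
        A = T.match_buffer(var_A, (n, T.int64(4096)), "float32")
        C = T.match_buffer(var_C, (n, T.int64(4096)), "float32")
        for i, j, k in T.grid(n, T.int64(4096), T.int64(4096)):
            with T.block("matmul"):
                # `n` maps to a spatial axis; the reduction extent is constant
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                with T.init():
                    C[vi, vj] = T.float32(0)
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]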
---
python/tvm/dlight/gpu/low_batch_gemv.py | 12 ++++++++++--
tests/python/dlight/test_gpu_low_batch_gemv.py | 24 ++++++++++++++++++++++++
2 files changed, 34 insertions(+), 2 deletions(-)
diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
index dfed020853..1c27fdfb13 100644
--- a/python/tvm/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -98,7 +98,14 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe
         for iter_var in block_stmt.iter_vars
         if isinstance(iter_var.dom.extent, tir.IntImm)
     )
-    if len(const_iter_vars) == len(block_stmt.iter_vars):
+    if len(block_stmt.iter_vars) - len(const_iter_vars) != 1:
+        return None
+    symbolic_iter_var = list(
+        iter_var
+        for iter_var in block_stmt.iter_vars
+        if not isinstance(iter_var.dom.extent, tir.IntImm)
+    )[0]
+    if symbolic_iter_var.iter_type != tir.stmt.IterVar.DataPar:
         return None
     ret = [
         read.buffer
@@ -220,7 +227,8 @@ class LowBatchGEMV(GPUScheduleRule):
             return None
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
-
+        if block_infos is None:
+            return None
         reduction_block_infos = [
             block_info for block_info in block_infos if block_info.is_reduction()
         ]
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 5827b7b810..d3e635ddaa 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -251,5 +251,29 @@ def test_batch_gemv():
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_reduction_symbolic_var():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        kv_seq_len = T.int64()
+        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len))
+        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128)))
+        # with T.block("root"):
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()