This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5bbe1aba6d  [Dlight] LowBatchGemv rule only apply to function with spatial symbolic var (#16678)
5bbe1aba6d is described below

commit 5bbe1aba6d0ca0f7422299a7b34c9e1a4181288d
Author: Hongyi Jin <[email protected]>
AuthorDate: Sat Mar 9 08:12:14 2024 -0500

     [Dlight] LowBatchGemv rule only apply to function with spatial symbolic var (#16678)
    
    * squash
    
    * fix
---
 python/tvm/dlight/gpu/low_batch_gemv.py        | 12 ++++++++++--
 tests/python/dlight/test_gpu_low_batch_gemv.py | 24 ++++++++++++++++++++++++
 2 files changed, 34 insertions(+), 2 deletions(-)

diff --git a/python/tvm/dlight/gpu/low_batch_gemv.py b/python/tvm/dlight/gpu/low_batch_gemv.py
index dfed020853..1c27fdfb13 100644
--- a/python/tvm/dlight/gpu/low_batch_gemv.py
+++ b/python/tvm/dlight/gpu/low_batch_gemv.py
@@ -98,7 +98,14 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffer]]:
         for iter_var in block_stmt.iter_vars
         if isinstance(iter_var.dom.extent, tir.IntImm)
     )
-    if len(const_iter_vars) == len(block_stmt.iter_vars):
+    if len(block_stmt.iter_vars) - len(const_iter_vars) != 1:
+        return None
+    symbolic_iter_var = list(
+        iter_var
+        for iter_var in block_stmt.iter_vars
+        if not isinstance(iter_var.dom.extent, tir.IntImm)
+    )[0]
+    if symbolic_iter_var.iter_type != tir.stmt.IterVar.DataPar:
         return None
     ret = [
         read.buffer
@@ -220,7 +227,8 @@ class LowBatchGEMV(GPUScheduleRule):
             return None
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
-
+        if block_infos is None:
+            return None
         reduction_block_infos = [
             block_info for block_info in block_infos if block_info.is_reduction()
         ]
diff --git a/tests/python/dlight/test_gpu_low_batch_gemv.py b/tests/python/dlight/test_gpu_low_batch_gemv.py
index 5827b7b810..d3e635ddaa 100644
--- a/tests/python/dlight/test_gpu_low_batch_gemv.py
+++ b/tests/python/dlight/test_gpu_low_batch_gemv.py
@@ -251,5 +251,29 @@ def test_batch_gemv():
     tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_reduction_symbolic_var():
+    # fmt: off
+    @T.prim_func(private=True)
+    def before(var_A: T.handle, var_B: T.handle, matmul: T.Buffer((T.int64(1), T.int64(32), T.int64(1), T.int64(128)), "float32")):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        kv_seq_len = T.int64()
+        A = T.match_buffer(var_A, (T.int64(1), T.int64(32), T.int64(1), kv_seq_len))
+        B = T.match_buffer(var_B, (T.int64(1), T.int64(32), kv_seq_len, T.int64(128)))
+        # with T.block("root"):
+        for i0, i1, i2, i3, k in T.grid(T.int64(1), T.int64(32), T.int64(1), T.int64(128), kv_seq_len):
+            with T.block("matmul"):
+                v_i0, v_i1, v_i2, v_i3, v_k = T.axis.remap("SSSSR", [i0, i1, i2, i3, k])
+                T.reads(A[v_i0, v_i1, v_i2, v_k], B[v_i0, v_i1, v_k, v_i3])
+                T.writes(matmul[v_i0, v_i1, v_i2, v_i3])
+                with T.init():
+                    matmul[v_i0, v_i1, v_i2, v_i3] = T.float32(0)
+                matmul[v_i0, v_i1, v_i2, v_i3] = matmul[v_i0, v_i1, v_i2, v_i3] + A[v_i0, v_i1, v_i2, v_k] * B[v_i0, v_i1, v_k, v_i3]
+    # fmt: on
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.LowBatchGEMV(4))(mod)
+    tvm.ir.assert_structural_equal(mod["main"], before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to