This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e5420436a0 [Dlight] Skip GeMV when normalization fails (#16665)
e5420436a0 is described below

commit e5420436a0fa5ee60764b6c300dfd4ff93d7b069
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 3 00:50:04 2024 -0500

    [Dlight] Skip GeMV when normalization fails (#16665)
    
    Prior to this PR, GeMV does not skip the cases of normalization
    failure, which leads to error. This PR fixes this issue.
    
    A unit test is added accordingly.
---
 python/tvm/dlight/gpu/gemv.py        |  2 ++
 tests/python/dlight/test_gpu_gemv.py | 33 +++++++++++++++++++++++++++++++++
 2 files changed, 35 insertions(+)

diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index d453b84bc0..d1a195fbad 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -180,6 +180,8 @@ class GEMV(GPUScheduleRule):
         sch = tir.Schedule(func)
         block_infos = normalize_prim_func(sch)
         block_infos = try_inline_contiguous_spatial(sch, block_infos)
+        if block_infos is None:
+            return None
         if len(block_infos) == 1:
             epilogue = None
         elif len(block_infos) == 2:
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index b5e8b82ab7..8903babbc0 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -996,5 +996,38 @@ def test_blockized_gemv():
         tvm.ir.assert_structural_equal(mod["main"], expected)
 
 
+def test_func_to_skip():
+    @T.prim_func
+    def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len: T.int64):
+        data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32", align=8)
+        output_buf = T.match_buffer(
+            var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32", align=8
+        )
+        with T.block("exclusive_scan_thrust"):
+            T.reads()
+            T.writes()
+            T.call_packed(
+                "tvm.contrib.thrust.sum_scan",
+                T.tvm_stack_make_array(
+                    data_buf.data, T.tvm_stack_make_shape(seq_len * T.int64(8)), 0, 1, 0, T.int64(0)
+                ),
+                T.tvm_stack_make_array(
+                    output_buf.data,
+                    T.tvm_stack_make_shape(seq_len * T.int64(8)),
+                    0,
+                    1,
+                    0,
+                    T.int64(0),
+                ),
+                T.bool(False),
+            )
+
+    # This function should be skipped.
+    mod = tvm.IRModule({"main": before})
+    with Target("metal"):
+        mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+        tvm.ir.assert_structural_equal(mod["main"], before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to