This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e5420436a0 [Dlight] Skip GeMV when normalization fails (#16665)
e5420436a0 is described below
commit e5420436a0fa5ee60764b6c300dfd4ff93d7b069
Author: Ruihang Lai <[email protected]>
AuthorDate: Sun Mar 3 00:50:04 2024 -0500
[Dlight] Skip GeMV when normalization fails (#16665)
Prior to this PR, the GeMV rule did not skip functions for which
normalization fails, which led to an error. This PR fixes the issue by
returning None from the rule when no block information is available.
A unit test is added accordingly.
---
python/tvm/dlight/gpu/gemv.py | 2 ++
tests/python/dlight/test_gpu_gemv.py | 33 +++++++++++++++++++++++++++++++++
2 files changed, 35 insertions(+)
diff --git a/python/tvm/dlight/gpu/gemv.py b/python/tvm/dlight/gpu/gemv.py
index d453b84bc0..d1a195fbad 100644
--- a/python/tvm/dlight/gpu/gemv.py
+++ b/python/tvm/dlight/gpu/gemv.py
@@ -180,6 +180,8 @@ class GEMV(GPUScheduleRule):
sch = tir.Schedule(func)
block_infos = normalize_prim_func(sch)
block_infos = try_inline_contiguous_spatial(sch, block_infos)
+ if block_infos is None:
+ return None
if len(block_infos) == 1:
epilogue = None
elif len(block_infos) == 2:
diff --git a/tests/python/dlight/test_gpu_gemv.py b/tests/python/dlight/test_gpu_gemv.py
index b5e8b82ab7..8903babbc0 100644
--- a/tests/python/dlight/test_gpu_gemv.py
+++ b/tests/python/dlight/test_gpu_gemv.py
@@ -996,5 +996,38 @@ def test_blockized_gemv():
tvm.ir.assert_structural_equal(mod["main"], expected)
+def test_func_to_skip():
+ @T.prim_func
+ def before(var_A: T.handle, var_exclusive_scan_thrust: T.handle, seq_len:
T.int64):
+ data_buf = T.match_buffer(var_A, (seq_len * T.int64(8),), "int32",
align=8)
+ output_buf = T.match_buffer(
+ var_exclusive_scan_thrust, (seq_len * T.int64(8),), "int32",
align=8
+ )
+ with T.block("exclusive_scan_thrust"):
+ T.reads()
+ T.writes()
+ T.call_packed(
+ "tvm.contrib.thrust.sum_scan",
+ T.tvm_stack_make_array(
+ data_buf.data, T.tvm_stack_make_shape(seq_len *
T.int64(8)), 0, 1, 0, T.int64(0)
+ ),
+ T.tvm_stack_make_array(
+ output_buf.data,
+ T.tvm_stack_make_shape(seq_len * T.int64(8)),
+ 0,
+ 1,
+ 0,
+ T.int64(0),
+ ),
+ T.bool(False),
+ )
+
+ # This function should be skipped.
+ mod = tvm.IRModule({"main": before})
+ with Target("metal"):
+ mod = dl.ApplyDefaultSchedule(dl.gpu.GEMV())(mod)
+ tvm.ir.assert_structural_equal(mod["main"], before)
+
+
if __name__ == "__main__":
tvm.testing.main()