This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 484bd4421b [Unity][BYOC] Do not use cudaMemcpy for max_seqlen in var len attention (#15989)
484bd4421b is described below
commit 484bd4421b71effcad64b87543f75528b01d1709
Author: masahi <[email protected]>
AuthorDate: Fri Oct 27 04:58:11 2023 +0900
[Unity][BYOC] Do not use cudaMemcpy for max_seqlen in var len attention (#15989)
* Do not use cudaMemcpy for max_seqlen in cutlass var len attention
* black
---
python/tvm/contrib/cutlass/attention_operation.py | 11 +++--------
tests/python/relax/test_codegen_cutlass.py | 20 ++++++++++++++++++--
2 files changed, 21 insertions(+), 10 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 7084a105c8..da998db6c0 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -35,9 +35,7 @@ def instantiate_attention_template(attrs):
var_len_template = """
p.seqstart_q_ptr = (int32_t*)${seqstart_q}->data;
p.seqstart_k_ptr = (int32_t*)${seqstart_k}->data;
- // TODO(masahi): Pass max_seqlen_q as an integer
- cudaMemcpy(&p.num_queries, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
- cudaMemcpyDeviceToHost);
+ p.num_queries = ((int32_t*)${max_seqlen_q}->data)[0];
p.num_batches = ${seqstart_q}->shape[0] - 1;
"""
@@ -285,11 +283,8 @@ def instantiate_flash_attention_var_len_template(attrs):
"""Return host code for flash attention with variable sequence lengths."""
template = """
- int _max_seqlen_q, _max_seqlen_k;
- cudaMemcpy(&_max_seqlen_q, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
- cudaMemcpyDeviceToHost);
- cudaMemcpy(&_max_seqlen_k, (int32_t*)${max_seqlen_k}->data, sizeof(int32_t),
- cudaMemcpyDeviceToHost);
+ int _max_seqlen_q = ((int32_t*)${max_seqlen_q}->data)[0];
+ int _max_seqlen_k = ((int32_t*)${max_seqlen_k}->data)[0];
int batch_size = ${seqstart_q}->shape[0] - 1;
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index bd647486d6..9bec214ab9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -2116,6 +2116,14 @@ def _test_batched_var_len_attention(mod, seq_lens, num_head, num_kv_head, head_s
def test_batched_var_len_attention():
@I.ir_module
class Module:
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ ]
+ }
+ )
+
@R.function
def main(
queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
@@ -2134,7 +2142,7 @@ def test_batched_var_len_attention():
cumsum = R.call_dps_packed(
"tvm.contrib.thrust.sum_scan", seq_lens,
out_sinfo=seq_lens.struct_info
)
- max_seqlen_q = R.max(seq_lens)
+ max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
@@ -2161,6 +2169,14 @@ def test_batched_var_len_attention():
def test_batched_var_len_multi_query_attention():
@I.ir_module
class Module:
+ I.module_global_infos(
+ {
+ "vdevice": [
+ I.vdevice("llvm"),
+ ]
+ }
+ )
+
@R.function
def main(
queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
@@ -2179,7 +2195,7 @@ def test_batched_var_len_multi_query_attention():
cumsum = R.call_dps_packed(
"tvm.contrib.thrust.sum_scan", seq_lens,
out_sinfo=seq_lens.struct_info
)
- max_seqlen_q = R.max(seq_lens)
+ max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))