This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 484bd4421b [Unity][BYOC] Do not use cudaMemcpy for max_seqlen in var len attention (#15989)
484bd4421b is described below

commit 484bd4421b71effcad64b87543f75528b01d1709
Author: masahi <[email protected]>
AuthorDate: Fri Oct 27 04:58:11 2023 +0900

    [Unity][BYOC] Do not use cudaMemcpy for max_seqlen in var len attention (#15989)
    
    * Do not use cudaMemcpy for max_seqlen in cutlass var len attention
    
    * black
---
 python/tvm/contrib/cutlass/attention_operation.py | 11 +++--------
 tests/python/relax/test_codegen_cutlass.py        | 20 ++++++++++++++++++--
 2 files changed, 21 insertions(+), 10 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 7084a105c8..da998db6c0 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -35,9 +35,7 @@ def instantiate_attention_template(attrs):
     var_len_template = """
   p.seqstart_q_ptr = (int32_t*)${seqstart_q}->data;
   p.seqstart_k_ptr = (int32_t*)${seqstart_k}->data;
-  // TODO(masahi): Pass max_seqlen_q as an integer
-  cudaMemcpy(&p.num_queries, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
-             cudaMemcpyDeviceToHost);
+  p.num_queries = ((int32_t*)${max_seqlen_q}->data)[0];
   p.num_batches = ${seqstart_q}->shape[0] - 1;
 """
 
@@ -285,11 +283,8 @@ def instantiate_flash_attention_var_len_template(attrs):
     """Return host code for flash attention with variable sequence lengths."""
 
     template = """
-    int _max_seqlen_q, _max_seqlen_k;
-    cudaMemcpy(&_max_seqlen_q, (int32_t*)${max_seqlen_q}->data, 
sizeof(int32_t),
-               cudaMemcpyDeviceToHost);
-    cudaMemcpy(&_max_seqlen_k, (int32_t*)${max_seqlen_k}->data, 
sizeof(int32_t),
-               cudaMemcpyDeviceToHost);
+    int _max_seqlen_q = ((int32_t*)${max_seqlen_q}->data)[0];
+    int _max_seqlen_k = ((int32_t*)${max_seqlen_k}->data)[0];
 
     int batch_size = ${seqstart_q}->shape[0] - 1;
 
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index bd647486d6..9bec214ab9 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -2116,6 +2116,14 @@ def _test_batched_var_len_attention(mod, seq_lens, num_head, num_kv_head, head_s
 def test_batched_var_len_attention():
     @I.ir_module
     class Module:
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                ]
+            }
+        )
+
         @R.function
         def main(
             queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
@@ -2134,7 +2142,7 @@ def test_batched_var_len_attention():
                 cumsum = R.call_dps_packed(
                     "tvm.contrib.thrust.sum_scan", seq_lens, 
out_sinfo=seq_lens.struct_info
                 )
-                max_seqlen_q = R.max(seq_lens)
+                max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
                 seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
                 q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
                 k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
@@ -2161,6 +2169,14 @@ def test_batched_var_len_attention():
 def test_batched_var_len_multi_query_attention():
     @I.ir_module
     class Module:
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                ]
+            }
+        )
+
         @R.function
         def main(
             queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
@@ -2179,7 +2195,7 @@ def test_batched_var_len_multi_query_attention():
                 cumsum = R.call_dps_packed(
                     "tvm.contrib.thrust.sum_scan", seq_lens, 
out_sinfo=seq_lens.struct_info
                 )
-                max_seqlen_q = R.max(seq_lens)
+                max_seqlen_q = R.to_vdevice(R.max(seq_lens), "llvm:0")
                 seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
                 q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
                 k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))

Reply via email to