This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new cd14d4d7e2 [Unity][Hot fix] Flash attention offload bug due to typo (#15512)
cd14d4d7e2 is described below

commit cd14d4d7e2eee87e616dd074f6b8a7770fe8854a
Author: masahi <[email protected]>
AuthorDate: Thu Aug 10 02:02:22 2023 +0900

    [Unity][Hot fix] Flash attention offload bug due to typo (#15512)
    
    * flash attention hot fix
    
    * also add cuda graph support
---
 python/tvm/contrib/cutlass/attention_operation.py | 14 +++++++++++---
 python/tvm/contrib/cutlass/gen_tensor_op.py       |  2 +-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 9b4fa78127..67a68df442 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -178,6 +178,10 @@ def instantiate_flash_attention_template(attrs):
     int v_batch_stride = v_row_stride * ${num_keys};
     int o_batch_stride = o_row_stride * ${num_queries};
 
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
     flash_attn::flash_attention_forward(
                            static_cast<const cutlass::half_t*>(${query}->data),
                            static_cast<const cutlass::half_t*>(${key}->data),
@@ -203,7 +207,7 @@ def instantiate_flash_attention_template(attrs):
                            o_row_stride,
                            ${scale},
                            ${is_causal},
-                           nullptr);
+                           stream);
     """
 
     template_stacked = """
@@ -224,8 +228,12 @@ def instantiate_flash_attention_template(attrs):
     int v_batch_stride = v_row_stride * ${num_keys};
     int o_batch_stride = o_row_stride * ${num_queries};
 
+    auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+    ICHECK(func != nullptr);
+    cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
     flash_attn::flash_attention_forward(
-    static_cast<const cutlass::half_t*>(${qkv}->data),
+                            static_cast<const cutlass::half_t*>(${qkv}->data),
                           static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
                           static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
                            static_cast<cutlass::half_t*>(out0->data),
@@ -249,7 +257,7 @@ def instantiate_flash_attention_template(attrs):
                            o_row_stride,
                            ${scale},
                            ${is_causal},
-                           nullptr);
+                           stream);
     """
 
     if "qkv" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 7133193c1e..317030b6ff 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -760,7 +760,7 @@ def instantiate_template(func_name, annotations, func_args):
 
         if use_flash:
             headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
+            attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
             code = instantiate_flash_attention_template(attrs)
         else:
             headers.append("kernel_forward.h")

Reply via email to