This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit abde85edbcc7845ca3b3d2cd7045b52c306d61dc
Author: Masahiro Masuda <[email protected]>
AuthorDate: Wed Nov 29 11:46:03 2023 +0000

    fix window_size_left param for var len attention
---
 python/tvm/contrib/cutlass/attention_operation.py | 2 +-
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 5 ++++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index da998db6c0..eacc0ec37a 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -324,7 +324,7 @@ def instantiate_flash_attention_var_len_template(attrs):
                            o_row_stride,
                            ${scale},
                            ${is_causal},
-                            ${is_causal} ? _max_seqlen_k : -1,
+                            ${window_size_left},
                             ${window_size_right},
                            stream);
     """
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 15629ddd3d..3e2f6175ed 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -782,6 +782,7 @@ def instantiate_template(func_name, annotations, func_args):
             and int(annotations["arch"]) >= 80
         )
 
+        # See https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L111-L116
         if "window_size" in annotations:
             assert use_flash, "Sliding-window attention is supported only by Flash Attention."
             assert (
 @@ -789,17 +790,19 @@ def instantiate_template(func_name, annotations, func_args):
             ), "Sliding-window attention is only supported for causal with bottom right mask."
             attrs["window_size_left"] = int(annotations["window_size"]) - 1
             attrs["window_size_right"] = 0
+            attrs["is_causal"] = False
         else:
             if int(annotations["custom_mask_type"]) == 2:
                 attrs["window_size_left"] = attrs["num_keys"]
                 attrs["window_size_right"] = 0
+                attrs["is_causal"] = True
             else:
                 attrs["window_size_left"] = -1
                 attrs["window_size_right"] = -1
+                attrs["is_causal"] = False
 
         if use_flash:
             headers.append("flash.h")
-            attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
             attrs["num_q_heads"] = annotations["num_q_heads"]
             attrs["num_kv_heads"] = annotations["num_kv_heads"]
 

Reply via email to