This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity in repository https://gitbox.apache.org/repos/asf/tvm.git
commit abde85edbcc7845ca3b3d2cd7045b52c306d61dc Author: Masahiro Masuda <[email protected]> AuthorDate: Wed Nov 29 11:46:03 2023 +0000 fix window_size_left param for var len attention --- python/tvm/contrib/cutlass/attention_operation.py | 2 +- python/tvm/contrib/cutlass/gen_tensor_op.py | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py index da998db6c0..eacc0ec37a 100644 --- a/python/tvm/contrib/cutlass/attention_operation.py +++ b/python/tvm/contrib/cutlass/attention_operation.py @@ -324,7 +324,7 @@ def instantiate_flash_attention_var_len_template(attrs): o_row_stride, ${scale}, ${is_causal}, - ${is_causal} ? _max_seqlen_k : -1, + ${window_size_left}, ${window_size_right}, stream); """ diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 15629ddd3d..3e2f6175ed 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -782,6 +782,7 @@ def instantiate_template(func_name, annotations, func_args): and int(annotations["arch"]) >= 80 ) + # See https://github.com/Dao-AILab/flash-attention/blob/main/csrc/flash_attn/flash_api.cpp#L111-L116 if "window_size" in annotations: assert use_flash, "Sliding-window attention is supported only by Flash Attention." assert ( @@ -789,17 +790,19 @@ def instantiate_template(func_name, annotations, func_args): ), "Sliding-window attention is only supported for causal with bottom right mask." 
attrs["window_size_left"] = int(annotations["window_size"]) - 1 attrs["window_size_right"] = 0 + attrs["is_causal"] = False else: if int(annotations["custom_mask_type"]) == 2: attrs["window_size_left"] = attrs["num_keys"] attrs["window_size_right"] = 0 + attrs["is_causal"] = True else: attrs["window_size_left"] = -1 attrs["window_size_right"] = -1 + attrs["is_causal"] = False if use_flash: headers.append("flash.h") - attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2 attrs["num_q_heads"] = annotations["num_q_heads"] attrs["num_kv_heads"] = annotations["num_kv_heads"]
