This is an automated email from the ASF dual-hosted git repository.
kparzysz pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cd14d4d7e2 [Unity][Hot fix] Flash attention offload bug due to typo (#15512)
cd14d4d7e2 is described below
commit cd14d4d7e2eee87e616dd074f6b8a7770fe8854a
Author: masahi <[email protected]>
AuthorDate: Thu Aug 10 02:02:22 2023 +0900
[Unity][Hot fix] Flash attention offload bug due to typo (#15512)
* flash attention hot fix
* also add cuda graph support
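Both templates below now fetch the runtime's CUDA stream through the TVM
registry instead of passing nullptr. The following is a minimal standalone
sketch of that pattern (the helper name is ours for illustration; it assumes
the TVM C++ headers, and "runtime.get_cuda_stream" is the packed function the
patch itself calls):

    #include <cuda_runtime.h>
    #include <tvm/runtime/logging.h>
    #include <tvm/runtime/registry.h>

    // Fetch the CUDA stream TVM's runtime is currently using, so offloaded
    // kernels launch on it rather than on the legacy default stream (nullptr).
    cudaStream_t GetTVMCudaStream() {
      const auto* func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
      ICHECK(func != nullptr) << "runtime.get_cuda_stream is not registered";
      return static_cast<cudaStream_t>((*func)().operator void*());
    }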
---
python/tvm/contrib/cutlass/attention_operation.py | 14 +++++++++++---
python/tvm/contrib/cutlass/gen_tensor_op.py | 2 +-
2 files changed, 12 insertions(+), 4 deletions(-)
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 9b4fa78127..67a68df442 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -178,6 +178,10 @@ def instantiate_flash_attention_template(attrs):
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${query}->data),
static_cast<const cutlass::half_t*>(${key}->data),
@@ -203,7 +207,7 @@ def instantiate_flash_attention_template(attrs):
o_row_stride,
${scale},
${is_causal},
- nullptr);
+ stream);
"""
template_stacked = """
@@ -224,8 +228,12 @@ def instantiate_flash_attention_template(attrs):
int v_batch_stride = v_row_stride * ${num_keys};
int o_batch_stride = o_row_stride * ${num_queries};
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
flash_attn::flash_attention_forward(
- static_cast<const cutlass::half_t*>(${qkv}->data),
+ static_cast<const cutlass::half_t*>(${qkv}->data),
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
static_cast<cutlass::half_t*>(out0->data),
@@ -249,7 +257,7 @@ def instantiate_flash_attention_template(attrs):
o_row_stride,
${scale},
${is_causal},
- nullptr);
+ stream);
"""
if "qkv" in attrs:
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index 7133193c1e..317030b6ff 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -760,7 +760,7 @@ def instantiate_template(func_name, annotations, func_args):
if use_flash:
headers.append("flash.h")
- attrs["is_causal"] = int(annotations["custom_mask_type"]) == 0
+ attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")