This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ef39f37f49 [Unity][BYOC] Support variable-length attention by flash attention (#15959)
ef39f37f49 is described below
commit ef39f37f49c8dbb885d56a92e97427d3e4ec10c4
Author: masahi <[email protected]>
AuthorDate: Thu Oct 26 07:43:25 2023 +0900
[Unity][BYOC] Support variable-length attention by flash attention (#15959)
* works
* add test
* fix tests
---
3rdparty/libflash_attn | 2 +-
python/tvm/contrib/cutlass/attention_operation.py | 56 +++++++
python/tvm/contrib/cutlass/gen_tensor_op.py | 16 +-
tests/python/relax/test_codegen_cutlass.py | 190 ++++++++++++++++------
4 files changed, 206 insertions(+), 58 deletions(-)
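
This patch routes Relax's R.nn.attention_var_len through the flash attention BYOC path. In the variable-length ("ragged batch") layout, all sequences are packed along one token axis and the batch structure is carried by a cumulative-offset array (seqstart_q/seqstart_k) plus the maximum sequence length, which is exactly what the template and tests below construct. A minimal NumPy sketch of that bookkeeping (illustrative only; these names are not part of the patch):

    import numpy as np

    seq_lens = np.array([5, 3, 8], dtype=np.int32)   # per-sequence token counts
    num_tokens = int(seq_lens.sum())                  # 16 packed tokens

    # Exclusive prefix sum: seqstart[i] marks where sequence i starts in the
    # packed tensor; the Relax graph builds this with R.concat([zeros, cumsum]).
    seqstart = np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)

    max_seqlen = int(seq_lens.max())   # copied to host as _max_seqlen_q/_max_seqlen_k
    batch_size = len(seqstart) - 1     # recovered in the template as shape[0] - 1

    assert seqstart.tolist() == [0, 5, 8, 16]
    assert (batch_size, max_seqlen, num_tokens) == (3, 8, 16)
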
diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index c1d793ad93..55d3603f74 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit c1d793ad939c8ec3cec351db84bc80808e4d34c3
+Subproject commit 55d3603f741eb68e82640ff55ccea4b17dd8053e
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 5579819001..7084a105c8 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -279,3 +279,59 @@ def instantiate_flash_attention_template(attrs):
return substitute_template(template_stacked, attrs)
return substitute_template(template, attrs)
+
+
+def instantiate_flash_attention_var_len_template(attrs):
+ """Return host code for flash attention with variable sequence lengths."""
+
+ template = """
+ int _max_seqlen_q, _max_seqlen_k;
+ cudaMemcpy(&_max_seqlen_q, (int32_t*)${max_seqlen_q}->data, sizeof(int32_t),
+ cudaMemcpyDeviceToHost);
+ cudaMemcpy(&_max_seqlen_k, (int32_t*)${max_seqlen_k}->data, sizeof(int32_t),
+ cudaMemcpyDeviceToHost);
+
+ int batch_size = ${seqstart_q}->shape[0] - 1;
+
+ int q_head_stride = ${head_dim};
+ int k_head_stride = ${head_dim};
+ int v_head_stride = ${head_dim};
+ int o_head_stride = ${head_dim};
+ int q_row_stride = q_head_stride * ${num_q_heads};
+ int k_row_stride = k_head_stride * ${num_kv_heads};
+ int v_row_stride = v_head_stride * ${num_kv_heads};
+ int o_row_stride = o_head_stride * ${num_q_heads};
+
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ cudaStream_t stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ flash_attn::flash_attention_var_len_forward(
+ static_cast<const cutlass::half_t*>(${query}->data),
+ static_cast<const cutlass::half_t*>(${key}->data),
+ static_cast<const cutlass::half_t*>(${value}->data),
+ static_cast<const int*>(${seqstart_q}->data),
+ static_cast<const int*>(${seqstart_k}->data),
+ static_cast<cutlass::half_t*>(out0->data),
+ batch_size,
+ _max_seqlen_q,
+ _max_seqlen_k,
+ ${num_q_heads},
+ ${num_kv_heads},
+ ${head_dim},
+ q_head_stride,
+ k_head_stride,
+ v_head_stride,
+ o_head_stride,
+ q_row_stride,
+ k_row_stride,
+ v_row_stride,
+ o_row_stride,
+ ${scale},
+ ${is_causal},
+ ${is_causal} ? _max_seqlen_k : -1,
+ ${window_size_right},
+ stream);
+ """
+
+ return substitute_template(template, attrs)
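
The ${...} placeholders in the template above are filled in by substitute_template with the attribute dictionary assembled in gen_tensor_op.py (head counts, head_dim, scale, the causal flag, and the names of the input buffers). A rough stand-in for that substitution step, assuming simple ${name} string replacement (the actual helper lives in the cutlass contrib package and may differ in details):

    import re

    def substitute_template_sketch(template: str, values: dict) -> str:
        # Replace each ${name} occurrence with str(values[name]).
        return re.sub(r"\$\{(\w+)\}", lambda m: str(values[m.group(1)]), template)

    attrs = {"head_dim": 32, "num_q_heads": 128, "num_kv_heads": 16}  # hypothetical values
    line = "int q_row_stride = q_head_stride * ${num_q_heads};"
    print(substitute_template_sketch(line, attrs))
    # -> int q_row_stride = q_head_stride * 128;
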
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index e86a02df60..d42791d71b 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -32,6 +32,7 @@ from . import _ffi_api as ffi
from .attention_operation import (
instantiate_attention_template,
instantiate_flash_attention_template,
+ instantiate_flash_attention_var_len_template,
)
from .conv2d_operation import instantiate_conv2d_template
from .gemm_operation import instantiate_gemm_template, emit_fp16A_intB_matmul
@@ -778,7 +779,6 @@ def instantiate_template(func_name, annotations, func_args):
)
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
- and not is_var_len
)
if "window_size" in annotations:
@@ -789,15 +789,23 @@ def instantiate_template(func_name, annotations, func_args):
attrs["window_size_left"] = int(annotations["window_size"]) - 1
attrs["window_size_right"] = 0
else:
- attrs["window_size_left"] = -1
- attrs["window_size_right"] = -1
+ if int(annotations["custom_mask_type"]) == 2:
+ attrs["window_size_left"] = attrs["num_keys"]
+ attrs["window_size_right"] = 0
+ else:
+ attrs["window_size_left"] = -1
+ attrs["window_size_right"] = -1
if use_flash:
headers.append("flash.h")
attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
attrs["num_q_heads"] = annotations["num_q_heads"]
attrs["num_kv_heads"] = annotations["num_kv_heads"]
- code = instantiate_flash_attention_template(attrs)
+
+ if is_var_len:
+ code = instantiate_flash_attention_var_len_template(attrs)
+ else:
+ code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")
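
The new custom_mask_type == 2 branch above expresses causal masking in flash attention's sliding-window terms: a left window spanning the entire key sequence with a right window of 0 is equivalent to the bottom-right-aligned causal mask (np.tril with k = s_kv - s, as used by the reference in the tests). A small NumPy check of that equivalence, assuming the usual bottom-right window alignment:

    import numpy as np

    def window_mask(s_q, s_kv, left, right):
        # Sliding-window visibility aligned to the bottom-right corner:
        # query i may attend to key j iff
        #   i + s_kv - s_q - left <= j <= i + s_kv - s_q + right
        i = np.arange(s_q)[:, None]
        j = np.arange(s_kv)[None, :]
        return (j >= i + s_kv - s_q - left) & (j <= i + s_kv - s_q + right)

    s_q, s_kv = 4, 7
    causal = np.tril(np.ones((s_q, s_kv), dtype=bool), k=s_kv - s_q)  # "BottomRight" causal
    assert np.array_equal(window_mask(s_q, s_kv, left=s_kv, right=0), causal)
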
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 151e05e9b6..f07a0dfcbb 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -611,13 +611,35 @@ def get_relax_attention_module(
return tvm.IRModule({"main": func})
-@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
def get_numpy_attention_ref(
- b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal, dtype, window_size=None
+ b,
+ s,
+ s_kv,
+ n,
+ h,
+ h_v,
+ bias_shape,
+ qk_scale,
+ causal,
+ dtype,
+ window_size=None,
+ num_kv_head=None,
):
+ if num_kv_head is None:
+ num_kv_head = n
+
q = np.random.randn(b, s, n, h).astype(dtype)
- k = np.random.randn(b, s_kv, n, h).astype(dtype)
- v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
+ k_orig = np.random.randn(b, s_kv, num_kv_head, h).astype(dtype)
+ v_orig = np.random.randn(b, s_kv, num_kv_head, h_v).astype(dtype)
+
+ if num_kv_head is None:
+ k = k_orig
+ v = v_orig
+ else:
+ factor = n // num_kv_head
+ k = np.repeat(k_orig, factor, axis=2)
+ v = np.repeat(v_orig, factor, axis=2)
+
qt = q.transpose(0, 2, 1, 3) # b, n, s, h
kt = k.transpose(0, 2, 3, 1) # b, n, h, s_kv
if not qk_scale == "none":
@@ -655,7 +677,7 @@ def get_numpy_attention_ref(
vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
ref = attn @ vt # b, n, s, h_v
- return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
+ return q, k_orig, v_orig, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
def test_attention_offload(attention_size, attention_dtype):
@@ -1191,6 +1213,7 @@ def test_attention_rewrite_fp16():
v: R.Tensor((4, 8, 32, 16), dtype="float16"),
bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ R.func_attr({"num_input": 4})
with R.dataflow():
lv = R.permute_dims(q, axes=[0, 2, 1, 3])
lv1 = R.reshape(lv, R.shape([128, 16, 8]))
@@ -1262,7 +1285,8 @@ def test_attention_rewrite_fp16():
k: R.Tensor((4, 8, 32, 8), dtype="float16"),
v: R.Tensor((4, 8, 32, 16), dtype="float16"),
bias: R.Tensor((4, 32, 16, 8), dtype="float16"),
- ) -> R.Tensor((4, 16, 32, 16), dtype="float16"):
+ ) -> R.Tensor((4, 16, 32, 16), dtype="float32"):
+ R.func_attr({"num_input": 4})
cls = Expected
with R.dataflow():
lv = R.vm.alloc_storage(R.shape([65536]), R.prim_value(0), R.dtype("uint8"))
@@ -2016,49 +2040,10 @@ def test_attention_rewrite_multi_query():
tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-def test_batched_var_len_attention():
+def _test_batched_var_len_attention(mod, seq_lens, num_head, num_kv_head, head_size):
if not tvm.get_global_func("tvm.contrib.thrust.sum_scan", True):
return
- @I.ir_module
- class Module:
- @R.function
- def main(
- queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
- keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
- values: R.Tensor(("num_tokens", 4096), dtype="float16"),
- seq_lens: R.Tensor(("num_seq",), dtype="int32"),
- ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
- cls = Module
- num_tokens = T.int64()
- num_seq = T.int64()
-
- with R.dataflow():
- # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
- # https://github.com/apache/tvm/issues/15851
- cumsum = R.call_dps_packed(
- "tvm.contrib.thrust.sum_scan", seq_lens,
out_sinfo=seq_lens.struct_info
- )
- max_seqlen_q = R.max(seq_lens)
- seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
- q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
- k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
- v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
- attn_out = R.nn.attention_var_len(
- q,
- k,
- v,
- seqstart_q,
- max_seqlen_q,
- causal_mask="BottomRight",
- )
- out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
- R.output(out)
- return out
-
- seq_lens = [5, 3, 8]
- num_head = 128
- head_size = 32
hidden_size = num_head * head_size
batched_queries = []
@@ -2068,11 +2053,21 @@ def test_batched_var_len_attention():
for s in seq_lens:
q, k, v, _, ref = get_numpy_attention_ref(
- 1, s, s, num_head, head_size, head_size, "none", "none", "BottomRight", "float16"
+ 1,
+ s,
+ s,
+ num_head,
+ head_size,
+ head_size,
+ "none",
+ "none",
+ "BottomRight",
+ "float16",
+ num_kv_head=num_kv_head,
)
batched_queries.append(np.reshape(q, [-1, hidden_size]))
- batched_keys.append(np.reshape(k, [-1, hidden_size]))
- batched_values.append(np.reshape(v, [-1, hidden_size]))
+ batched_keys.append(np.reshape(k, [-1, num_kv_head * head_size]))
+ batched_values.append(np.reshape(v, [-1, num_kv_head * head_size]))
batched_refs.append(np.reshape(ref, [-1, hidden_size]))
batched_queries = np.vstack(batched_queries)
@@ -2080,7 +2075,7 @@ def test_batched_var_len_attention():
batched_values = np.vstack(batched_values)
ref = np.vstack(batched_refs)
- mod = partition_for_cutlass(Module)
+ mod = partition_for_cutlass(mod)
codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80}})
mod = codegen_pass(mod)
@@ -2099,8 +2094,6 @@ def test_batched_var_len_attention():
"cuda",
)
- tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
-
############# xformer reference for verification #############
# attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens)
@@ -2115,7 +2108,98 @@ def test_batched_var_len_attention():
# ).cpu().numpy()[0]
# out = np.reshape(out, [-1, hidden_size])
- # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
+def test_batched_var_len_attention():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+ keys: R.Tensor(("num_tokens", 4096), dtype="float16"),
+ values: R.Tensor(("num_tokens", 4096), dtype="float16"),
+ seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+ ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+ R.func_attr({"num_input": 4})
+ cls = Module
+ num_tokens = T.int64()
+ num_seq = T.int64()
+
+ with R.dataflow():
+ # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+ # https://github.com/apache/tvm/issues/15851
+ cumsum = R.call_dps_packed(
+ "tvm.contrib.thrust.sum_scan", seq_lens,
out_sinfo=seq_lens.struct_info
+ )
+ max_seqlen_q = R.max(seq_lens)
+ seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+ q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+ k = R.reshape(keys, R.shape([1, num_tokens, 128, 32]))
+ v = R.reshape(values, R.shape([1, num_tokens, 128, 32]))
+ attn_out = R.nn.attention_var_len(
+ q,
+ k,
+ v,
+ seqstart_q,
+ max_seqlen_q,
+ causal_mask="BottomRight",
+ )
+ out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+ R.output(out)
+ return out
+
+ seq_lens = [5, 3, 8]
+ num_head = 128
+ head_size = 32
+
+ _test_batched_var_len_attention(Module, seq_lens, num_head, num_head, head_size)
+
+
+def test_batched_var_len_multi_query_attention():
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ queries: R.Tensor(("num_tokens", 4096), dtype="float16"),
+ keys: R.Tensor(("num_tokens", 512), dtype="float16"),
+ values: R.Tensor(("num_tokens", 512), dtype="float16"),
+ seq_lens: R.Tensor(("num_seq",), dtype="int32"),
+ ) -> R.Tensor(("num_tokens", 4096), dtype="float16"):
+ R.func_attr({"num_input": 4})
+ cls = Module
+ num_tokens = T.int64()
+ num_seq = T.int64()
+
+ with R.dataflow():
+ # TODO(masahi): Workaround for the broken Relax cumsum op on GPU.
+ # https://github.com/apache/tvm/issues/15851
+ cumsum = R.call_dps_packed(
+ "tvm.contrib.thrust.sum_scan", seq_lens,
out_sinfo=seq_lens.struct_info
+ )
+ max_seqlen_q = R.max(seq_lens)
+ seqstart_q = R.concat([R.zeros((1,), "int32"), cumsum])
+ q = R.reshape(queries, R.shape([1, num_tokens, 128, 32]))
+ k = R.reshape(keys, R.shape([1, num_tokens, 16, 32]))
+ v = R.reshape(values, R.shape([1, num_tokens, 16, 32]))
+ attn_out = R.nn.attention_var_len(
+ q,
+ k,
+ v,
+ seqstart_q,
+ max_seqlen_q,
+ causal_mask="BottomRight",
+ )
+ out = R.reshape(attn_out, R.shape([num_tokens, 4096]))
+ R.output(out)
+ return out
+
+ seq_lens = [5, 3, 8]
+ num_head = 128
+ num_kv_head = 16
+ head_size = 32
+
+ _test_batched_var_len_attention(Module, seq_lens, num_head, num_kv_head, head_size)
def test_sliding_window():