This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 1ba11f69a7 [Unity][BYOC] Add support for sliding window in attention op (#15951)
1ba11f69a7 is described below
commit 1ba11f69a786f83187cd158cb03f28659649ab26
Author: masahi <[email protected]>
AuthorDate: Fri Oct 20 06:03:51 2023 +0900
[Unity][BYOC] Add support for sliding window in attention op (#15951)
* update flash rev
* import fix
* update
* add window_size attribute
* add byoc support
* wip test
* wip
* wip
* wip
* wip
* numpy ref and cutlass res match
* works
* doc
* update rev
* minor
* lint
---
3rdparty/libflash_attn | 2 +-
include/tvm/relax/attrs/nn.h | 2 +
python/tvm/contrib/cutlass/attention_operation.py | 4 ++
python/tvm/contrib/cutlass/build.py | 3 +
python/tvm/contrib/cutlass/gen_tensor_op.py | 13 +++++
python/tvm/relax/op/nn/__init__.py | 1 +
python/tvm/relax/op/nn/nn.py | 11 +++-
python/tvm/relax/transform/legalize_ops/nn.py | 6 ++
src/relax/op/nn/attention.cc | 8 ++-
src/relax/op/nn/attention.h | 2 +-
tests/python/relax/test_codegen_cutlass.py | 69 ++++++++++++++++++++++-
11 files changed, 112 insertions(+), 9 deletions(-)
diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index 63cce0ca8f..c1d793ad93 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b
+Subproject commit c1d793ad939c8ec3cec351db84bc80808e4d34c3
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 0d895dccb1..424874bd75 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -379,12 +379,14 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
Optional<FloatImm> scale;
Optional<String> causal_mask;
+ Optional<IntImm> window_size;
TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
TVM_ATTR_FIELD(scale).describe(
"The custom scale applied before the softmax. The default value is 1 /
sqrt(head_dim).");
TVM_ATTR_FIELD(causal_mask)
.describe("The type of the causal mask, i.e. 'TopLeft' and
'BottomRight'.");
+ TVM_ATTR_FIELD(window_size).describe("The size of the window for
sliding-window attention.");
}
}; // struct AttentionAttrs
diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 5766e2cb2d..5579819001 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -218,6 +218,8 @@ def instantiate_flash_attention_template(attrs):
o_row_stride,
${scale},
${is_causal},
+ ${window_size_left},
+ ${window_size_right},
stream);
"""
@@ -268,6 +270,8 @@ def instantiate_flash_attention_template(attrs):
o_row_stride,
${scale},
${is_causal},
+ ${window_size_left},
+ ${window_size_right},
stream);
"""
diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py
index ba3ecedb5b..671bca7d02 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -950,6 +950,9 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
if arg in arg_idx:
attrs[arg + "_idx"] = arg_idx[arg]
+ if op_attrs.window_size:
+ attrs["window_size"] = op_attrs.window_size
+
return f.with_attrs(attrs)
def handle_norm(self, f, op_type):
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index dedb392fcb..e86a02df60 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -770,15 +770,28 @@ def instantiate_template(func_name, annotations, func_args):
# For the causal case (custom mask = "BottomRight"), only use
flash for multi-query
# attention workloads. Otherwise, CUTLASS fMHA seems faster for
causal attention
# with a single query.
+ # In addition, sliding-window attention is only supported by flash.
and (
int(annotations["custom_mask_type"]) == 0
or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+ or (int(annotations["custom_mask_type"]) == 2 and
"window_size" in annotations)
)
# Flash v2 is currently not supported for sm < 80
and int(annotations["arch"]) >= 80
and not is_var_len
)
+ if "window_size" in annotations:
+ assert use_flash, "Sliding-window attention is supported only by
Flash Attention."
+ assert (
+ int(annotations["custom_mask_type"]) == 2
+ ), "Sliding-window attention is only supported for causal with
bottom right mask."
+ attrs["window_size_left"] = int(annotations["window_size"]) - 1
+ attrs["window_size_right"] = 0
+ else:
+ attrs["window_size_left"] = -1
+ attrs["window_size_right"] = -1
+
if use_flash:
headers.append("flash.h")
attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
diff --git a/python/tvm/relax/op/nn/__init__.py b/python/tvm/relax/op/nn/__init__.py
index d1569e11cb..9f01086a69 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -18,6 +18,7 @@
from .nn import (
adaptive_avg_pool2d,
attention,
+ attention_var_len,
avg_pool2d,
batch_norm,
conv1d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 3eddd5f591..5adf38d7d6 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1194,6 +1194,7 @@ def attention(
bias: Optional[Expr] = None,
scale: Optional[FloatImm] = None,
causal_mask: Optional[str] = None,
+ window_size: Optional[int] = None,
) -> Expr:
r"""Computes fused multi head attention.
@@ -1265,6 +1266,8 @@ def attention(
[[1, 1, 1, 0],
[1, 1, 1, 1]]
+ window_size: Optional[int]
+ The size of the window for sliding-window attention.
Returns
-------
@@ -1272,7 +1275,9 @@ def attention(
The computed result. The layout of the output should be
(batch_size, seq_len, num_head, head_dim_v).
"""
- return _ffi_api.attention(query, key, value, bias, scale, causal_mask) #
type: ignore
+ return _ffi_api.attention(
+ query, key, value, bias, scale, causal_mask, window_size
+ ) # type: ignore
def attention_var_len(
@@ -1285,6 +1290,7 @@ def attention_var_len(
max_seqlen_k: Optional[Expr] = None,
scale: Optional[FloatImm] = None,
causal_mask: Optional[str] = None,
+ window_size: Optional[int] = None,
) -> Expr:
r"""Computes fused multi head attention over batched sequences of variable
lengths.
@@ -1348,6 +1354,8 @@ def attention_var_len(
[[1, 1, 1, 0],
[1, 1, 1, 1]]
+ window_size: Optional[int]
+ The size of the window for sliding-window attention.
Returns
-------
@@ -1368,4 +1376,5 @@ def attention_var_len(
max_seqlen_k,
scale,
causal_mask,
+ window_size,
) # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py b/python/tvm/relax/transform/legalize_ops/nn.py
index 894cbad346..a82f54b84c 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -456,6 +456,9 @@ def _te_attention(
@register_legalize("relax.nn.attention")
def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+ assert (
+ call.attrs.window_size is None
+ ), "Legalization for sliding-window attention is not supported yet."
return bb.call_te(
_te_attention,
call.args[0],
@@ -470,6 +473,9 @@ def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
@register_legalize("relax.nn.attention_bias")
def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+ assert (
+ call.attrs.window_size is None
+ ), "Legalization for sliding-window attention is not supported yet."
return bb.call_te(
_te_attention,
call.args[0],
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index a44169f64d..c6aed941b6 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -29,10 +29,11 @@ namespace relax {
TVM_REGISTER_NODE_TYPE(AttentionAttrs);
Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias,
Optional<FloatImm> scale,
- Optional<String> causal_mask) {
+ Optional<String> causal_mask, Optional<IntImm> window_size) {
ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
attrs->scale = scale;
attrs->causal_mask = causal_mask;
+ attrs->window_size = window_size;
if (bias) {
return Call(Op::Get("relax.nn.attention_bias"),
@@ -45,10 +46,11 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<F
Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr
seqstart_k,
Expr max_seqlen_q, Expr max_seqlen_k,
Optional<FloatImm> scale,
- Optional<String> causal_mask) {
+ Optional<String> causal_mask, Optional<IntImm>
window_size) {
ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
attrs->scale = scale;
attrs->causal_mask = causal_mask;
+ attrs->window_size = window_size;
return Call(Op::Get("relax.nn.attention_var_len"),
{query, key, value, seqstart_q, seqstart_k, max_seqlen_q,
max_seqlen_k}, Attrs(attrs),
@@ -139,7 +141,7 @@ StructInfo InferStructInfoAttention(const Call& call, const BlockBuilder& ctx) {
Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype)
{
return Downcast<Call>(
- attention(call->args[0], call->args[1], call->args[2], NullOpt, NullOpt,
NullOpt));
+ attention(call->args[0], call->args[1], call->args[2], NullOpt, NullOpt,
NullOpt, NullOpt));
}
TVM_REGISTER_OP("relax.nn.attention")
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 8bbf2596ce..346907f8e9 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -34,7 +34,7 @@ namespace relax {
/*! \brief fused multi head attention */
Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias,
Optional<FloatImm> scale,
- Optional<String> causal_mask);
+ Optional<String> causal_mask, Optional<IntImm> window_size);
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 41eabe0600..151e05e9b6 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -571,7 +571,15 @@ def attention_size(request):
def get_relax_attention_module(
- q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None,
causal_mask=None
+ q_shape,
+ k_shape,
+ v_shape,
+ *,
+ dtype,
+ bias_shape=None,
+ qk_scale=None,
+ causal_mask=None,
+ window_size=None,
):
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
@@ -580,6 +588,9 @@ def get_relax_attention_module(
if qk_scale is not None:
qk_scale = T.FloatImm("float32", qk_scale)
+ if window_size is not None:
+ window_size = T.IntImm("int32", window_size)
+
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
@@ -591,7 +602,7 @@ def get_relax_attention_module(
bias = R.arg("bias", R.Tensor(bias_shape, dtype))
with R.dataflow() as frame:
- result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal_mask))
+ result = R.emit(R.nn.attention(q, k, v, bias, qk_scale,
causal_mask, window_size))
R.output(result)
R.func_ret_value(frame.output_vars[0])
@@ -601,7 +612,9 @@ def get_relax_attention_module(
@memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
-def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale,
causal, dtype):
+def get_numpy_attention_ref(
+ b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal, dtype,
window_size=None
+):
q = np.random.randn(b, s, n, h).astype(dtype)
k = np.random.randn(b, s_kv, n, h).astype(dtype)
v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
@@ -626,11 +639,20 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal,
else:
raise NotImplementedError()
score_masked = np.tril(score, k=offset)
+
+ if window_size:
+ score_masked = np.triu(score_masked, -window_size + 1)
+
score_masked_exp = np.tril(
np.exp(score_masked - np.max(score_masked, axis=-1,
keepdims=True)), k=offset
)
+
+ if window_size:
+ score_masked_exp = np.triu(score_masked_exp, -window_size + 1)
+
score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
attn = np.divide(score_masked_exp, score_masked_sum)
+
vt = v.transpose(0, 2, 1, 3) # b, n, s_kv, h_v
ref = attn @ vt # b, n, s, h_v
return q, k, v, bias, ref.transpose(0, 2, 1, 3) # b, s, n, h_v
@@ -2096,5 +2118,46 @@ def test_batched_var_len_attention():
# tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+def test_sliding_window():
+ q_shape = (1, 64, 16, 8)
+ k_shape = v_shape = q_shape
+ window_size = 8
+ causal = "BottomRight"
+
+ mod = get_relax_attention_module(
+ q_shape,
+ k_shape,
+ v_shape,
+ dtype="float16",
+ causal_mask=causal,
+ window_size=window_size,
+ )
+
+ q, k, v, _, ref = get_numpy_attention_ref(
+ 1, 64, 64, 16, 8, 8, "none", "none", causal, "float16",
window_size=window_size
+ )
+
+ out = get_result_with_relax_cutlass_offload(mod, q, k, v,
num_final_bindings=3)
+
+ tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+ ############# xformer reference for verification #############
+
+ # attn_bias = BlockDiagonalCausalMask.from_seqlens([64])
+
+ # if window_size > 0:
+ # attn_bias = attn_bias.make_local_attention(window_size)
+
+ # query = torch.from_numpy(q).to("cuda")
+ # key = torch.from_numpy(k).to("cuda")
+ # value = torch.from_numpy(v).to("cuda")
+
+ # ref = xops.memory_efficient_attention_forward(
+ # query, key, value, attn_bias=attn_bias,
+ # ).cpu().numpy()
+
+ # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
if __name__ == "__main__":
tvm.testing.main()