This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 1ba11f69a7 [Unity][BYOC] Add support for sliding window in attention op (#15951)
1ba11f69a7 is described below

commit 1ba11f69a786f83187cd158cb03f28659649ab26
Author: masahi <[email protected]>
AuthorDate: Fri Oct 20 06:03:51 2023 +0900

    [Unity][BYOC] Add support for sliding window in attention op (#15951)
    
    * update flash rev
    
    * import fix
    
    * update
    
    * add window_size attribute
    
    * add byoc support
    
    * wip test
    
    * wip
    
    * wip
    
    * wip
    
    * wip
    
    * numpy ref and cutlass res match
    
    * works
    
    * doc
    
    * update rev
    
    * minor
    
    * lint
---
 3rdparty/libflash_attn                            |  2 +-
 include/tvm/relax/attrs/nn.h                      |  2 +
 python/tvm/contrib/cutlass/attention_operation.py |  4 ++
 python/tvm/contrib/cutlass/build.py               |  3 +
 python/tvm/contrib/cutlass/gen_tensor_op.py       | 13 +++++
 python/tvm/relax/op/nn/__init__.py                |  1 +
 python/tvm/relax/op/nn/nn.py                      | 11 +++-
 python/tvm/relax/transform/legalize_ops/nn.py     |  6 ++
 src/relax/op/nn/attention.cc                      |  8 ++-
 src/relax/op/nn/attention.h                       |  2 +-
 tests/python/relax/test_codegen_cutlass.py        | 69 ++++++++++++++++++++++-
 11 files changed, 112 insertions(+), 9 deletions(-)

diff --git a/3rdparty/libflash_attn b/3rdparty/libflash_attn
index 63cce0ca8f..c1d793ad93 160000
--- a/3rdparty/libflash_attn
+++ b/3rdparty/libflash_attn
@@ -1 +1 @@
-Subproject commit 63cce0ca8fa6bfca1982b342588273641cc5b86b
+Subproject commit c1d793ad939c8ec3cec351db84bc80808e4d34c3
diff --git a/include/tvm/relax/attrs/nn.h b/include/tvm/relax/attrs/nn.h
index 0d895dccb1..424874bd75 100644
--- a/include/tvm/relax/attrs/nn.h
+++ b/include/tvm/relax/attrs/nn.h
@@ -379,12 +379,14 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> 
{
 struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
   Optional<FloatImm> scale;
   Optional<String> causal_mask;
+  Optional<IntImm> window_size;
 
   TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
     TVM_ATTR_FIELD(scale).describe(
         "The custom scale applied before the softmax. The default value is 1 / 
sqrt(head_dim).");
     TVM_ATTR_FIELD(causal_mask)
         .describe("The type of the causal mask, i.e. 'TopLeft' and 
'BottomRight'.");
+    TVM_ATTR_FIELD(window_size).describe("The size of the window for 
sliding-window attention.");
   }
 };  // struct AttentionAttrs
 
diff --git a/python/tvm/contrib/cutlass/attention_operation.py 
b/python/tvm/contrib/cutlass/attention_operation.py
index 5766e2cb2d..5579819001 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -218,6 +218,8 @@ def instantiate_flash_attention_template(attrs):
                            o_row_stride,
                            ${scale},
                            ${is_causal},
+                            ${window_size_left},
+                            ${window_size_right},
                            stream);
     """
 
@@ -268,6 +270,8 @@ def instantiate_flash_attention_template(attrs):
                            o_row_stride,
                            ${scale},
                            ${is_causal},
+                            ${window_size_left},
+                            ${window_size_right},
                            stream);
     """
 
diff --git a/python/tvm/contrib/cutlass/build.py 
b/python/tvm/contrib/cutlass/build.py
index ba3ecedb5b..671bca7d02 100644
--- a/python/tvm/contrib/cutlass/build.py
+++ b/python/tvm/contrib/cutlass/build.py
@@ -950,6 +950,9 @@ class CutlassRelaxFunctionAnnotator(relax.PyExprMutator):
                 if arg in arg_idx:
                     attrs[arg + "_idx"] = arg_idx[arg]
 
+        if op_attrs.window_size:
+            attrs["window_size"] = op_attrs.window_size
+
         return f.with_attrs(attrs)
 
     def handle_norm(self, f, op_type):
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py 
b/python/tvm/contrib/cutlass/gen_tensor_op.py
index dedb392fcb..e86a02df60 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -770,15 +770,28 @@ def instantiate_template(func_name, annotations, 
func_args):
             # For the causal case (custom mask = "BottomRight"), only use flash for multi-query
             # attention workloads. Otherwise, CUTLASS fMHA seems faster for causal attention
             # with a single query.
+            # In addition, sliding-window attention is only supported by flash.
             and (
                 int(annotations["custom_mask_type"]) == 0
                 or (int(annotations["custom_mask_type"]) == 2 and is_mqa)
+                or (int(annotations["custom_mask_type"]) == 2 and 
"window_size" in annotations)
             )
             # Flash v2 is currently not supported for sm < 80
             and int(annotations["arch"]) >= 80
             and not is_var_len
         )
 
+        if "window_size" in annotations:
+            assert use_flash, "Sliding-window attention is supported only by 
Flash Attention."
+            assert (
+                int(annotations["custom_mask_type"]) == 2
+            ), "Sliding-window attention is only supported for causal with 
bottom right mask."
+            attrs["window_size_left"] = int(annotations["window_size"]) - 1
+            attrs["window_size_right"] = 0
+        else:
+            attrs["window_size_left"] = -1
+            attrs["window_size_right"] = -1
+
         if use_flash:
             headers.append("flash.h")
             attrs["is_causal"] = int(annotations["custom_mask_type"]) == 2
diff --git a/python/tvm/relax/op/nn/__init__.py 
b/python/tvm/relax/op/nn/__init__.py
index d1569e11cb..9f01086a69 100644
--- a/python/tvm/relax/op/nn/__init__.py
+++ b/python/tvm/relax/op/nn/__init__.py
@@ -18,6 +18,7 @@
 from .nn import (
     adaptive_avg_pool2d,
     attention,
+    attention_var_len,
     avg_pool2d,
     batch_norm,
     conv1d,
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 3eddd5f591..5adf38d7d6 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -1194,6 +1194,7 @@ def attention(
     bias: Optional[Expr] = None,
     scale: Optional[FloatImm] = None,
     causal_mask: Optional[str] = None,
+    window_size: Optional[int] = None,
 ) -> Expr:
     r"""Computes fused multi head attention.
 
@@ -1265,6 +1266,8 @@ def attention(
             [[1, 1, 1, 0],
             [1, 1, 1, 1]]
 
+    window_size: Optional[int]
+        The size of the window for sliding-window attention.
 
     Returns
     -------
@@ -1272,7 +1275,9 @@ def attention(
         The computed result. The layout of the output should be
         (batch_size, seq_len, num_head, head_dim_v).
     """
-    return _ffi_api.attention(query, key, value, bias, scale, causal_mask)  # 
type: ignore
+    return _ffi_api.attention(
+        query, key, value, bias, scale, causal_mask, window_size
+    )  # type: ignore
 
 
 def attention_var_len(
@@ -1285,6 +1290,7 @@ def attention_var_len(
     max_seqlen_k: Optional[Expr] = None,
     scale: Optional[FloatImm] = None,
     causal_mask: Optional[str] = None,
+    window_size: Optional[int] = None,
 ) -> Expr:
     r"""Computes fused multi head attention over batched sequences of variable 
lengths.
 
@@ -1348,6 +1354,8 @@ def attention_var_len(
         [[1, 1, 1, 0],
          [1, 1, 1, 1]]
 
+    window_size: Optional[int]
+        The size of the window for sliding-window attention.
 
     Returns
     -------
@@ -1368,4 +1376,5 @@ def attention_var_len(
         max_seqlen_k,
         scale,
         causal_mask,
+        window_size,
     )  # type: ignore
diff --git a/python/tvm/relax/transform/legalize_ops/nn.py 
b/python/tvm/relax/transform/legalize_ops/nn.py
index 894cbad346..a82f54b84c 100644
--- a/python/tvm/relax/transform/legalize_ops/nn.py
+++ b/python/tvm/relax/transform/legalize_ops/nn.py
@@ -456,6 +456,9 @@ def _te_attention(
 
 @register_legalize("relax.nn.attention")
 def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
+    assert (
+        call.attrs.window_size is None
+    ), "Legalization for sliding-window attention is not supported yet."
     return bb.call_te(
         _te_attention,
         call.args[0],
@@ -470,6 +473,9 @@ def _nn_attention(bb: BlockBuilder, call: Call) -> Expr:
 
 @register_legalize("relax.nn.attention_bias")
 def _nn_attention_bias(bb: BlockBuilder, call: Call) -> Expr:
+    assert (
+        call.attrs.window_size is None
+    ), "Legalization for sliding-window attention is not supported yet."
     return bb.call_te(
         _te_attention,
         call.args[0],
diff --git a/src/relax/op/nn/attention.cc b/src/relax/op/nn/attention.cc
index a44169f64d..c6aed941b6 100644
--- a/src/relax/op/nn/attention.cc
+++ b/src/relax/op/nn/attention.cc
@@ -29,10 +29,11 @@ namespace relax {
 TVM_REGISTER_NODE_TYPE(AttentionAttrs);
 
 Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, 
Optional<FloatImm> scale,
-               Optional<String> causal_mask) {
+               Optional<String> causal_mask, Optional<IntImm> window_size) {
   ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
+  attrs->window_size = window_size;
 
   if (bias) {
     return Call(Op::Get("relax.nn.attention_bias"),
@@ -45,10 +46,11 @@ Expr attention(Expr query, Expr key, Expr value, 
Optional<Expr> bias, Optional<F
 
 Expr attention_var_len(Expr query, Expr key, Expr value, Expr seqstart_q, Expr 
seqstart_k,
                        Expr max_seqlen_q, Expr max_seqlen_k, 
Optional<FloatImm> scale,
-                       Optional<String> causal_mask) {
+                       Optional<String> causal_mask, Optional<IntImm> 
window_size) {
   ObjectPtr<AttentionAttrs> attrs = make_object<AttentionAttrs>();
   attrs->scale = scale;
   attrs->causal_mask = causal_mask;
+  attrs->window_size = window_size;
 
   return Call(Op::Get("relax.nn.attention_var_len"),
               {query, key, value, seqstart_q, seqstart_k, max_seqlen_q, 
max_seqlen_k}, Attrs(attrs),
@@ -139,7 +141,7 @@ StructInfo InferStructInfoAttention(const Call& call, const 
BlockBuilder& ctx) {
 
 Call InferMixedPrecisionAttention(const Call& call, const DataType& out_dtype) 
{
   return Downcast<Call>(
-      attention(call->args[0], call->args[1], call->args[2], NullOpt, NullOpt, 
NullOpt));
+      attention(call->args[0], call->args[1], call->args[2], NullOpt, NullOpt, 
NullOpt, NullOpt));
 }
 
 TVM_REGISTER_OP("relax.nn.attention")
diff --git a/src/relax/op/nn/attention.h b/src/relax/op/nn/attention.h
index 8bbf2596ce..346907f8e9 100644
--- a/src/relax/op/nn/attention.h
+++ b/src/relax/op/nn/attention.h
@@ -34,7 +34,7 @@ namespace relax {
 
 /*! \brief fused multi head attention */
 Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, 
Optional<FloatImm> scale,
-               Optional<String> causal_mask);
+               Optional<String> causal_mask, Optional<IntImm> window_size);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/tests/python/relax/test_codegen_cutlass.py 
b/tests/python/relax/test_codegen_cutlass.py
index 41eabe0600..151e05e9b6 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -571,7 +571,15 @@ def attention_size(request):
 
 
 def get_relax_attention_module(
-    q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None, 
causal_mask=None
+    q_shape,
+    k_shape,
+    v_shape,
+    *,
+    dtype,
+    bias_shape=None,
+    qk_scale=None,
+    causal_mask=None,
+    window_size=None,
 ):
     from tvm.script.ir_builder import IRBuilder
     from tvm.script.ir_builder import relax as relax_builder
@@ -580,6 +588,9 @@ def get_relax_attention_module(
     if qk_scale is not None:
         qk_scale = T.FloatImm("float32", qk_scale)
 
+    if window_size is not None:
+        window_size = T.IntImm("int32", window_size)
+
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
@@ -591,7 +602,7 @@ def get_relax_attention_module(
                 bias = R.arg("bias", R.Tensor(bias_shape, dtype))
 
             with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, 
causal_mask))
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, 
causal_mask, window_size))
                 R.output(result)
 
             R.func_ret_value(frame.output_vars[0])
@@ -601,7 +612,9 @@ def get_relax_attention_module(
 
 
 @memoize("topi.tests.test_codegen_cutlass.test_attention_offload")
-def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale, 
causal, dtype):
+def get_numpy_attention_ref(
+    b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal, dtype, 
window_size=None
+):
     q = np.random.randn(b, s, n, h).astype(dtype)
     k = np.random.randn(b, s_kv, n, h).astype(dtype)
     v = np.random.randn(b, s_kv, n, h_v).astype(dtype)
@@ -626,11 +639,20 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, 
bias_shape, qk_scale, causal,
         else:
             raise NotImplementedError()
         score_masked = np.tril(score, k=offset)
+
+        if window_size:
+            score_masked = np.triu(score_masked, -window_size + 1)
+
         score_masked_exp = np.tril(
             np.exp(score_masked - np.max(score_masked, axis=-1, 
keepdims=True)), k=offset
         )
+
+        if window_size:
+            score_masked_exp = np.triu(score_masked_exp, -window_size + 1)
+
         score_masked_sum = np.sum(score_masked_exp, axis=-1, keepdims=True)
         attn = np.divide(score_masked_exp, score_masked_sum)
+
     vt = v.transpose(0, 2, 1, 3)  # b, n, s_kv, h_v
     ref = attn @ vt  # b, n, s, h_v
     return q, k, v, bias, ref.transpose(0, 2, 1, 3)  # b, s, n, h_v
@@ -2096,5 +2118,46 @@ def test_batched_var_len_attention():
     # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
 
 
+def test_sliding_window():
+    q_shape = (1, 64, 16, 8)
+    k_shape = v_shape = q_shape
+    window_size = 8
+    causal = "BottomRight"
+
+    mod = get_relax_attention_module(
+        q_shape,
+        k_shape,
+        v_shape,
+        dtype="float16",
+        causal_mask=causal,
+        window_size=window_size,
+    )
+
+    q, k, v, _, ref = get_numpy_attention_ref(
+        1, 64, 64, 16, 8, 8, "none", "none", causal, "float16", 
window_size=window_size
+    )
+
+    out = get_result_with_relax_cutlass_offload(mod, q, k, v, 
num_final_bindings=3)
+
+    tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+    ############# xformer reference for verification #############
+
+    # attn_bias = BlockDiagonalCausalMask.from_seqlens([64])
+
+    # if window_size > 0:
+    #     attn_bias = attn_bias.make_local_attention(window_size)
+
+    # query = torch.from_numpy(q).to("cuda")
+    # key = torch.from_numpy(k).to("cuda")
+    # value = torch.from_numpy(v).to("cuda")
+
+    # ref = xops.memory_efficient_attention_forward(
+    #     query, key, value, attn_bias=attn_bias,
+    # ).cpu().numpy()
+
+    # tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to