This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f19e6835fe [Unity][BYOC] Add check for stacked attention patterns (#14664)
f19e6835fe is described below

commit f19e6835fea7fe20a750099b9ebfefb6e111e6b0
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Apr 18 22:07:59 2023 -0700

    [Unity][BYOC] Add check for stacked attention patterns (#14664)
    
    * [Unity][BYOC] Add check for stacked attention patterns
    
    This PR is a follow-up to #14608 and #14649. It adds checks for the fused
    stacked attention patterns, so that fusion is only enabled when `stacked_qkv`
    has `ndim=3` and the `split`/`strided_slice` is taken along `axis=2`.
    
    * check the order of strided_slice
---
 python/tvm/relax/backend/contrib/cutlass.py | 31 +++++++++++++++++++++++++++++
 python/tvm/relax/backend/patterns.py        | 10 ++++++----
 src/relax/transform/fuse_ops.cc             |  2 +-
 3 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/python/tvm/relax/backend/contrib/cutlass.py 
b/python/tvm/relax/backend/contrib/cutlass.py
index 06edd9febf..0c2f38e300 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -232,6 +232,33 @@ def residual_block_patterns():
     return patterns
 
 
+def _check_stacked_attention(context: PatternCheckContext) -> bool:
+    """Check if the given stacked attention workload can be offloaded to 
CUTLASS."""
+    if _has_leaking_intermediate_variables(context):
+        return False
+    if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
+        return False
+    if "split" in context.annotated_expr:
+        split_op = context.annotated_expr["split"]
+        if not split_op.attrs.axis == 2:
+            return False
+    else:
+        last_end = 0
+        for name in ["query", "key", "value"]:
+            assert f"strided_slice_{name}" in context.annotated_expr
+            strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
+            if list(strided_slice_op.attrs.axes) != [2]:
+                return False
+            if list(strided_slice_op.attrs.begin) != [last_end]:
+                return False
+            if not len(strided_slice_op.attrs.end) == 1:
+                return False
+            last_end = strided_slice_op.attrs.end[0]
+            if list(strided_slice_op.attrs.strides) != [1]:
+                return False
+    return True
+
+
 def attention_patterns():
     """
     Returns a list of all attention patterns in cutlass BYOC backend.
@@ -248,18 +275,22 @@ def attention_patterns():
         (
             "cutlass.stacked_attention",
             *make_stacked_attention_pattern(start_op="split"),
+            _check_stacked_attention,
         ),
         (
             "cutlass.stacked_attention",
             *make_stacked_attention_pattern(start_op="split", with_bias=True),
+            _check_stacked_attention,
         ),
         (
             "cutlass.stacked_attention",
             *make_stacked_attention_pattern(start_op="strided_slice"),
+            _check_stacked_attention,
         ),
         (
             "cutlass.stacked_attention",
             *make_stacked_attention_pattern(start_op="strided_slice", 
with_bias=True),
+            _check_stacked_attention,
         ),
     ]
 
diff --git a/python/tvm/relax/backend/patterns.py 
b/python/tvm/relax/backend/patterns.py
index 6197fe44ca..7119c6c4b0 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -220,15 +220,16 @@ def make_stacked_attention_pattern(start_op: str, 
with_bias: bool = False):
         check function and codegen.
     """
     stacked_qkv = wildcard()
+    ops = {}
     if start_op == "split":
-        qkv_tuple = is_op("relax.split")(stacked_qkv)
+        ops["split"] = qkv_tuple = is_op("relax.split")(stacked_qkv)
         query_raw = is_tuple_get_item(qkv_tuple, 0)
         key_raw = is_tuple_get_item(qkv_tuple, 1)
         value_raw = is_tuple_get_item(qkv_tuple, 2)
     elif start_op == "strided_slice":
-        query_raw = is_op("relax.strided_slice")(stacked_qkv)
-        key_raw = is_op("relax.strided_slice")(stacked_qkv)
-        value_raw = is_op("relax.strided_slice")(stacked_qkv)
+        ops["strided_slice_query"] = query_raw = 
is_op("relax.strided_slice")(stacked_qkv)
+        ops["strided_slice_key"] = key_raw = 
is_op("relax.strided_slice")(stacked_qkv)
+        ops["strided_slice_value"] = value_raw = 
is_op("relax.strided_slice")(stacked_qkv)
     else:
         raise NotImplementedError()
     query_reshape_list = wildcard()
@@ -242,6 +243,7 @@ def make_stacked_attention_pattern(start_op: str, 
with_bias: bool = False):
         "query_reshape_list": query_reshape_list,
         "key_reshape_list": key_reshape_list,
         "value_reshape_list": value_reshape_list,
+        **ops,
     }
     if with_bias:
         bias = wildcard()
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index adce61f4b8..c9c36bfcd8 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1045,7 +1045,7 @@ class PatternBasedPartitioner : ExprVisitor {
 
     Map<Var, Expr> matched_bindings;
     for (const auto& [pat, match] : matched_result) {
-      if (pat->IsInstance<CallPatternNode>()) {
+      if (pat->IsInstance<CallPatternNode>() || 
pat->IsInstance<TupleGetItemPatternNode>()) {
         matched_bindings.Set(value_to_bound_var_[match], match);
       }
     }

Reply via email to