This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f19e6835fe [Unity][BYOC] Add check for stacked attention patterns
(#14664)
f19e6835fe is described below
commit f19e6835fea7fe20a750099b9ebfefb6e111e6b0
Author: Yaxing Cai <[email protected]>
AuthorDate: Tue Apr 18 22:07:59 2023 -0700
[Unity][BYOC] Add check for stacked attention patterns (#14664)
* [Unity][BYOC] Add check for stacked attention patterns
This PR is a follow-up to #14608 and #14649. In this PR, we add checks
for the fused stacked attention patterns, so that fusion is only enabled
when `stacked_qkv` has `ndim=3` and the `split`/`strided_slice` uses `axis=2`.
* check the order of strided_slice
---
python/tvm/relax/backend/contrib/cutlass.py | 31 +++++++++++++++++++++++++++++
python/tvm/relax/backend/patterns.py | 10 ++++++----
src/relax/transform/fuse_ops.cc | 2 +-
3 files changed, 38 insertions(+), 5 deletions(-)
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 06edd9febf..0c2f38e300 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -232,6 +232,33 @@ def residual_block_patterns():
return patterns
+def _check_stacked_attention(context: PatternCheckContext) -> bool:
+ """Check if the given stacked attention workload can be offloaded to
CUTLASS."""
+ if _has_leaking_intermediate_variables(context):
+ return False
+ if not context.annotated_expr["stacked_qkv"].struct_info.ndim == 3:
+ return False
+ if "split" in context.annotated_expr:
+ split_op = context.annotated_expr["split"]
+ if not split_op.attrs.axis == 2:
+ return False
+ else:
+ last_end = 0
+ for name in ["query", "key", "value"]:
+ assert f"strided_slice_{name}" in context.annotated_expr
+ strided_slice_op = context.annotated_expr[f"strided_slice_{name}"]
+ if list(strided_slice_op.attrs.axes) != [2]:
+ return False
+ if list(strided_slice_op.attrs.begin) != [last_end]:
+ return False
+ if not len(strided_slice_op.attrs.end) == 1:
+ return False
+ last_end = strided_slice_op.attrs.end[0]
+ if list(strided_slice_op.attrs.strides) != [1]:
+ return False
+ return True
+
+
def attention_patterns():
"""
Returns a list of all attention patterns in cutlass BYOC backend.
@@ -248,18 +275,22 @@ def attention_patterns():
(
"cutlass.stacked_attention",
*make_stacked_attention_pattern(start_op="split"),
+ _check_stacked_attention,
),
(
"cutlass.stacked_attention",
*make_stacked_attention_pattern(start_op="split", with_bias=True),
+ _check_stacked_attention,
),
(
"cutlass.stacked_attention",
*make_stacked_attention_pattern(start_op="strided_slice"),
+ _check_stacked_attention,
),
(
"cutlass.stacked_attention",
*make_stacked_attention_pattern(start_op="strided_slice",
with_bias=True),
+ _check_stacked_attention,
),
]
diff --git a/python/tvm/relax/backend/patterns.py
b/python/tvm/relax/backend/patterns.py
index 6197fe44ca..7119c6c4b0 100644
--- a/python/tvm/relax/backend/patterns.py
+++ b/python/tvm/relax/backend/patterns.py
@@ -220,15 +220,16 @@ def make_stacked_attention_pattern(start_op: str,
with_bias: bool = False):
check function and codegen.
"""
stacked_qkv = wildcard()
+ ops = {}
if start_op == "split":
- qkv_tuple = is_op("relax.split")(stacked_qkv)
+ ops["split"] = qkv_tuple = is_op("relax.split")(stacked_qkv)
query_raw = is_tuple_get_item(qkv_tuple, 0)
key_raw = is_tuple_get_item(qkv_tuple, 1)
value_raw = is_tuple_get_item(qkv_tuple, 2)
elif start_op == "strided_slice":
- query_raw = is_op("relax.strided_slice")(stacked_qkv)
- key_raw = is_op("relax.strided_slice")(stacked_qkv)
- value_raw = is_op("relax.strided_slice")(stacked_qkv)
+ ops["strided_slice_query"] = query_raw =
is_op("relax.strided_slice")(stacked_qkv)
+ ops["strided_slice_key"] = key_raw =
is_op("relax.strided_slice")(stacked_qkv)
+ ops["strided_slice_value"] = value_raw =
is_op("relax.strided_slice")(stacked_qkv)
else:
raise NotImplementedError()
query_reshape_list = wildcard()
@@ -242,6 +243,7 @@ def make_stacked_attention_pattern(start_op: str,
with_bias: bool = False):
"query_reshape_list": query_reshape_list,
"key_reshape_list": key_reshape_list,
"value_reshape_list": value_reshape_list,
+ **ops,
}
if with_bias:
bias = wildcard()
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index adce61f4b8..c9c36bfcd8 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -1045,7 +1045,7 @@ class PatternBasedPartitioner : ExprVisitor {
Map<Var, Expr> matched_bindings;
for (const auto& [pat, match] : matched_result) {
- if (pat->IsInstance<CallPatternNode>()) {
+ if (pat->IsInstance<CallPatternNode>() ||
pat->IsInstance<TupleGetItemPatternNode>()) {
matched_bindings.Set(value_to_bound_var_[match], match);
}
}