This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 8b5d64af71 [Unity][BYOC] Check leaked intermediate variables in
cutlass patterns (#14350)
8b5d64af71 is described below
commit 8b5d64af711b9f2adbbad4a12f5984ed18f4d6a8
Author: Lite Ye <[email protected]>
AuthorDate: Tue Mar 21 04:16:01 2023 -0400
[Unity][BYOC] Check leaked intermediate variables in cutlass patterns
(#14350)
Check leaked intermediate variables in cutlass patterns
---
include/tvm/relax/transform.h | 22 ++++++++++--
python/tvm/relax/backend/contrib/cutlass.py | 30 +++++++++++++++-
python/tvm/relax/transform/transform.py | 15 ++++++--
src/relax/transform/fuse_ops.cc | 24 +++++++++----
tests/python/relax/test_codegen_cutlass.py | 55 +++++++++++++++++++++++++++++
5 files changed, 134 insertions(+), 12 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a21f76b0b..a3d0d4a0e9 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -275,6 +275,11 @@ class FusionPattern : public ObjectRef {
*/
class PatternCheckContextNode : public Object {
public:
+ /*!
+ * \brief The expression that's matched with the FusionPattern::pattern.
+ */
+ Expr matched_expr;
+
/*!
* \brief A map which contains all expressions matched by the sub patterns in
* FusionPattern::annotation_patterns.
@@ -282,17 +287,27 @@ class PatternCheckContextNode : public Object {
Map<String, Expr> annotated_expr;
/*!
- * \brief A map mapping variable definitions to a set of uses.
+ * \brief Map from variable to its value. It contains variables from
bindings that
+ * are being fused by FuseOpsByPattern.
+ */
+ Map<Var, Expr> matched_bindings;
+
+ /*!
+ * \brief A map mapping variable definitions to a set of uses. It has all
variables
+ * used in the function.
*/
Map<Var, Array<Var>> var_usages;
/*!
- * \brief Map from value to its bound variable.
+ * \brief Map from value to its bound variable. It doesn't have variables
after the
+ * matched expression.
*/
Map<Expr, Var> value_to_bound_var;
void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("matched_expr", &matched_expr);
v->Visit("annotated_expr", &annotated_expr);
+ v->Visit("matched_bindings", &matched_bindings);
v->Visit("var_usages", &var_usages);
v->Visit("value_to_bound_var", &value_to_bound_var);
}
@@ -303,7 +318,8 @@ class PatternCheckContextNode : public Object {
class PatternCheckContext : public ObjectRef {
public:
- PatternCheckContext(Map<String, Expr> annotated_expr, Map<Var, Array<Var>>
var_usages,
+ PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
+ Map<Var, Expr> matched_bindings, Map<Var, Array<Var>>
var_usages,
Map<Expr, Var> value_to_bound_var);
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
diff --git a/python/tvm/relax/backend/contrib/cutlass.py
b/python/tvm/relax/backend/contrib/cutlass.py
index 4d539928cf..c03c913d63 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -21,7 +21,7 @@ from typing import Mapping, Optional, Sequence, Tuple
import tvm
from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import ShapeExpr, Var, transform
+from tvm.relax import DataflowVar, ShapeExpr, Var, transform
from tvm.relax.transform import PatternCheckContext
from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -52,6 +52,28 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
)
+def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
+ """
+ Check whether intermediate variables in the region to be fused are used
outside
+ the fused region.
+ """
+ defined_vars = set(context.matched_bindings.keys())
+ output_var = context.value_to_bound_var[context.matched_expr]
+ intermediate_vars = {v for v in context.matched_bindings if v !=
output_var}
+
+ if any(not isinstance(v, DataflowVar) for v in intermediate_vars):
+        # If an intermediate variable is not a DataflowVar, it can be accessed
and potentially
+        # used outside the DataflowBlock.
+ return True
+
+ # Check whether all users of an intermediate variable are inside the fused
region.
+ for var in intermediate_vars:
+ if any(var_user not in defined_vars for var_user in
context.var_usages[var]):
+ return True
+
+ return False
+
+
def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var,
Sequence[Var]]):
if from_var == to_var:
return True
@@ -72,6 +94,9 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages:
Mapping[Var, Sequenc
def _check_conv2d(context: PatternCheckContext) -> bool:
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
+ if _has_leaking_intermediate_variables(context):
+ return False
+
conv2d_call = context.annotated_expr["root"]
data_layout = conv2d_call.attrs.data_layout
kernel_layout = conv2d_call.attrs.kernel_layout
@@ -101,6 +126,9 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
def _check_matmul(context: PatternCheckContext) -> bool:
"""Check if the given matmul workload can be offloaded to CUTLASS."""
+ if _has_leaking_intermediate_variables(context):
+ return False
+
lhs = context.annotated_expr["lhs"]
rhs = context.annotated_expr["rhs"]
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 95f81f7e6c..e8e3d73113 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -277,18 +277,29 @@ class PatternCheckContext(Object):
Parameters
----------
+ matched_expr: Expr
+ The expression that's matched with the FusionPattern.pattern.
+
annotated_expr: Mapping[str, Expr]
A map which contains all expressions matched by the sub patterns in
FusionPattern.annotation_patterns.
+ matched_bindings: Mapping[Var, Expr]
+        Map from variable to its value. It contains variables from bindings
that are
+        being fused by FuseOpsByPattern.
+
var_usages: Mapping[Var, Sequence[Var]]
- A map mapping variable definitions to a set of uses.
+ A map mapping variable definitions to a set of uses. It has all
variables
+ used in the function.
value_to_bound_var: Mapping[Expr, Var]
- Map from value to its bound variable.
+ Map from value to its bound variable. It doesn't have variables after
the
+ matched expression.
"""
+ matched_expr: Expr
annotated_expr: Mapping[str, Expr]
+ matched_bindings: Mapping[Var, Expr]
var_usages: Mapping[Var, Sequence[Var]]
value_to_bound_var: Mapping[Expr, Var]
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 76f53eebc5..15bcf3513c 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -905,6 +905,7 @@ class PatternBasedPartitioner : ExprVisitor {
public:
using Group = GraphPartitioner::Group;
using GroupMap = OperatorFusor::GroupMap;
+ using PatternCheckContext = transform::PatternCheckContext;
using ExprVisitor::VisitExpr_;
using FCheckMatch = runtime::TypedPackedFunc<bool(const
transform::PatternCheckContext&)>;
@@ -944,9 +945,7 @@ class PatternBasedPartitioner : ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const CallNode* call)
final {
VisitVarDef(binding->var);
if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call),
bindings_)) {
- if (check_ != nullptr &&
-
!check_(transform::PatternCheckContext(GetAnnotatedExpr(matches_opt.value()),
- current_block_use_def_,
value_to_bound_var_))) {
+ if (check_ != nullptr && !check_(CreatePatternCheckContext(call,
matches_opt.value()))) {
return;
}
// If a match is found, put all matching expressions into the same group.
@@ -990,14 +989,24 @@ class PatternBasedPartitioner : ExprVisitor {
return group_map_[bound_var.get()]->FindRoot();
}
- Map<String, Expr> GetAnnotatedExpr(const Map<DFPattern, Expr>
matched_result) {
+ PatternCheckContext CreatePatternCheckContext(const CallNode* call,
+ const Map<DFPattern, Expr>&
matched_result) {
Map<String, Expr> annotated_expr;
for (const auto& it : annotation_pat_) {
if (matched_result.count(it.second)) {
annotated_expr.Set(it.first, matched_result[it.second]);
}
}
- return annotated_expr;
+
+ Map<Var, Expr> matched_bindings;
+ for (const auto& [pat, match] : matched_result) {
+ if (pat->IsInstance<CallPatternNode>()) {
+ matched_bindings.Set(value_to_bound_var_[match], match);
+ }
+ }
+
+ return PatternCheckContext(GetRef<Call>(call), annotated_expr,
matched_bindings,
+ current_block_use_def_, value_to_bound_var_);
}
String pat_name_;
@@ -1123,11 +1132,14 @@ TVM_REGISTER_GLOBAL("relax.transform.FusionPattern")
return FusionPattern(name, pattern, annotation_patterns, check);
});
-PatternCheckContext::PatternCheckContext(Map<String, Expr> annotated_expr,
+PatternCheckContext::PatternCheckContext(Expr matched_expr, Map<String, Expr>
annotated_expr,
+ Map<Var, Expr> matched_bindings,
Map<Var, Array<Var>> var_usages,
Map<Expr, Var> value_to_bound_var) {
ObjectPtr<PatternCheckContextNode> n =
make_object<PatternCheckContextNode>();
+ n->matched_expr = std::move(matched_expr);
n->annotated_expr = std::move(annotated_expr);
+ n->matched_bindings = std::move(matched_bindings);
n->var_usages = std::move(var_usages);
n->value_to_bound_var = std::move(value_to_bound_var);
data_ = std::move(n);
diff --git a/tests/python/relax/test_codegen_cutlass.py
b/tests/python/relax/test_codegen_cutlass.py
index 0bae6801ca..5ea9a9d040 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -492,6 +492,61 @@ def test_cutlass_partition_matmul_blocked(x_shape,
y_shape, transpose_y, dtype):
assert len(mod.functions) == 1
+def test_cutlass_partition_matmul_tuple_return_blocked():
+ @tvm.script.ir_module
+ class TransposedMatmul:
+ @R.function
+ def main(
+ x: R.Tensor((4, 4), "float32"),
+ y: R.Tensor((4, 4), "float32"),
+ ):
+ with R.dataflow():
+ lv1 = R.permute_dims(y)
+ # Because lv1 is used by both lv2 and out, it should stay out
of
+ # the fused function. Otherwise the fused function will return
+ # tuple output, which isn't possible in cutlass, e.g.
+ # @R.function
+ # def fused_relax_permute_dims_relax_matmul(...):
+ # R.func_attr({"Composite": "cutlass.matmul_transposed",
"Primitive": 1})
+ # with R.dataflow():
+ # gv: R.Tensor((4, 4), dtype="float32") =
R.permute_dims(y, axes=None)
+ # gv1: R.Tensor((4, 4), dtype="float32") = R.matmul(x,
gv, out_dtype="void")
+ # R.output(gv, gv1)
+ # return (gv, gv1) # Cannot get `gv` if dispatch to
cutlass kernel.
+ lv2 = R.matmul(x, lv1)
+ out = R.matmul(lv1, lv2)
+ R.output(out)
+
+ return out
+
+ mod = partition_for_cutlass(TransposedMatmul, annotate_codegen=False)
+ for f_var in mod.functions:
+ func = mod[f_var]
+ if func.attrs and "Composite" in func.attrs:
+ # verify that the function is not fused as transposed matmul
+ assert func.attrs["Composite"] == "cutlass.matmul"
+
+
+def test_cutlass_partition_matmul_cyclic_dependency_blocked():
+ @tvm.script.ir_module
+ class Module:
+ @R.function
+ def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128),
"float16")):
+ with R.dataflow():
+ # Because lv1 depends on lv, this block should be fused as
matmul instead of matmul_bias.
+ lv = R.matmul(x, w)
+ lv1 = R.power(lv, R.const(2.0, "float16"))
+ lv2 = R.add(lv, lv1)
+ R.output(lv2)
+ return lv2
+
+ mod = partition_for_cutlass(Module, annotate_codegen=False)
+ for f_var in mod.functions:
+ func = mod[f_var]
+ if func.attrs and "Composite" in func.attrs:
+ assert func.attrs["Composite"] == "cutlass.matmul"
+
+
@pytest.fixture(params=["float16", "float32"])
def attention_dtype(request):
return request.param