This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 8b5d64af71 [Unity][BYOC] Check leaked intermediate variables in cutlass patterns (#14350)
8b5d64af71 is described below

commit 8b5d64af711b9f2adbbad4a12f5984ed18f4d6a8
Author: Lite Ye <[email protected]>
AuthorDate: Tue Mar 21 04:16:01 2023 -0400

    [Unity][BYOC] Check leaked intermediate variables in cutlass patterns (#14350)
    
    Check leaked intermediate variables in cutlass patterns
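    
    As an illustration (a minimal standalone sketch over plain Python
    containers, not the actual TVM API), the idea behind the check: an
    intermediate result of the matched region "leaks" when a binding outside
    the region also uses it, and such regions are rejected rather than fused.
    
        # Hypothetical helper mirroring _has_leaking_intermediate_variables.
        # `matched_bindings` maps each fused variable to its value,
        # `var_usages` maps every variable to its users, and `output_var`
        # is the variable bound to the matched root expression.
        def leaks_intermediate_vars(matched_bindings, var_usages, output_var):
            fused_vars = set(matched_bindings)
            intermediates = fused_vars - {output_var}
            # A leak forces the fused function to return a tuple, which the
            # cutlass integration cannot handle.
            return any(
                user not in fused_vars
                for var in intermediates
                for user in var_usages.get(var, [])
            )
    
        # lv1 feeds both lv2 and out, so fusing {lv1, lv2} with output lv2
        # would leak lv1 through its external user `out`.
        assert leaks_intermediate_vars(
            {"lv1": "permute_dims(y)", "lv2": "matmul(x, lv1)"},
            {"lv1": ["lv2", "out"], "lv2": ["out"]},
            output_var="lv2",
        )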
---
 include/tvm/relax/transform.h               | 22 ++++++++++--
 python/tvm/relax/backend/contrib/cutlass.py | 30 +++++++++++++++-
 python/tvm/relax/transform/transform.py     | 15 ++++++--
 src/relax/transform/fuse_ops.cc             | 24 +++++++++----
 tests/python/relax/test_codegen_cutlass.py  | 55 +++++++++++++++++++++++++++++
 5 files changed, 134 insertions(+), 12 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 5a21f76b0b..a3d0d4a0e9 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -275,6 +275,11 @@ class FusionPattern : public ObjectRef {
  */
 class PatternCheckContextNode : public Object {
  public:
+  /*!
+   * \brief The expression that's matched with the FusionPattern::pattern.
+   */
+  Expr matched_expr;
+
   /*!
    * \brief A map which contains all expressions matched by the sub patterns in
    * FusionPattern::annotation_patterns.
@@ -282,17 +287,27 @@ class PatternCheckContextNode : public Object {
   Map<String, Expr> annotated_expr;
 
   /*!
-   * \brief A map mapping variable definitions to a set of uses.
+   * \brief Map from variable to its value. It contains variables from bindings that
+   * are being fused by FuseOpsByPattern.
+   */
+  Map<Var, Expr> matched_bindings;
+
+  /*!
+   * \brief A map mapping variable definitions to a set of uses. It has all variables
+   * used in the function.
    */
   Map<Var, Array<Var>> var_usages;
 
   /*!
-   * \brief Map from value to its bound variable.
+   * \brief Map from value to its bound variable. It does not contain variables bound
+   * after the matched expression.
    */
   Map<Expr, Var> value_to_bound_var;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("matched_expr", &matched_expr);
     v->Visit("annotated_expr", &annotated_expr);
+    v->Visit("matched_bindings", &matched_bindings);
     v->Visit("var_usages", &var_usages);
     v->Visit("value_to_bound_var", &value_to_bound_var);
   }
@@ -303,7 +318,8 @@ class PatternCheckContextNode : public Object {
 
 class PatternCheckContext : public ObjectRef {
  public:
-  PatternCheckContext(Map<String, Expr> annotated_expr, Map<Var, Array<Var>> var_usages,
+  PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
+                      Map<Var, Expr> matched_bindings, Map<Var, Array<Var>> var_usages,
                       Map<Expr, Var> value_to_bound_var);
 
   TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(PatternCheckContext, ObjectRef,
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index 4d539928cf..c03c913d63 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -21,7 +21,7 @@ from typing import Mapping, Optional, Sequence, Tuple
 
 import tvm
 from tvm.contrib.cutlass.build import is_shape_valid_for_cutlass_matmul
-from tvm.relax import ShapeExpr, Var, transform
+from tvm.relax import DataflowVar, ShapeExpr, Var, transform
 from tvm.relax.transform import PatternCheckContext
 
 from ..pattern_registry import get_patterns_with_prefix, register_patterns
@@ -52,6 +52,28 @@ def _is_supported_dtype(lhs_dtype, rhs_dtype):
     )
 
 
+def _has_leaking_intermediate_variables(context: PatternCheckContext) -> bool:
+    """
+    Check whether intermediate variables in the region to be fused are used outside
+    the fused region.
+    """
+    defined_vars = set(context.matched_bindings.keys())
+    output_var = context.value_to_bound_var[context.matched_expr]
+    intermediate_vars = {v for v in context.matched_bindings if v != output_var}
+
+    if any(not isinstance(v, DataflowVar) for v in intermediate_vars):
+        # If an intermediate variable is not a DataflowVar, it can be accessed and potentially
+        # used outside the DataflowBlock.
+        return True
+
+    # Check whether all users of an intermediate variable are inside the fused region.
+    for var in intermediate_vars:
+        if any(var_user not in defined_vars for var_user in context.var_usages[var]):
+            return True
+
+    return False
+
+
 def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequence[Var]]):
     if from_var == to_var:
         return True
@@ -72,6 +94,9 @@ def _has_dependency(from_var: Var, to_var: Var, var_usages: Mapping[Var, Sequenc
 
 def _check_conv2d(context: PatternCheckContext) -> bool:
     """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    if _has_leaking_intermediate_variables(context):
+        return False
+
     conv2d_call = context.annotated_expr["root"]
     data_layout = conv2d_call.attrs.data_layout
     kernel_layout = conv2d_call.attrs.kernel_layout
@@ -101,6 +126,9 @@ def _check_conv2d(context: PatternCheckContext) -> bool:
 
 def _check_matmul(context: PatternCheckContext) -> bool:
     """Check if the given matmul workload can be offloaded to CUTLASS."""
+    if _has_leaking_intermediate_variables(context):
+        return False
+
     lhs = context.annotated_expr["lhs"]
     rhs = context.annotated_expr["rhs"]
 
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 95f81f7e6c..e8e3d73113 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -277,18 +277,29 @@ class PatternCheckContext(Object):
 
     Parameters
     ----------
+    matched_expr: Expr
+        The expression that's matched with the FusionPattern.pattern.
+
     annotated_expr: Mapping[str, Expr]
         A map which contains all expressions matched by the sub patterns in
         FusionPattern.annotation_patterns.
 
+    matched_bindings: Mapping[Var, Expr]
+        Map from variable to its value. It contains variables from bindings that are
+        being fused by FuseOpsByPattern.
+
     var_usages: Mapping[Var, Sequence[Var]]
-        A map mapping variable definitions to a set of uses.
+        A map mapping variable definitions to a set of uses. It has all variables
+        used in the function.
 
     value_to_bound_var: Mapping[Expr, Var]
-        Map from value to its bound variable.
+        Map from value to its bound variable. It does not contain variables bound
+        after the matched expression.
     """
 
+    matched_expr: Expr
     annotated_expr: Mapping[str, Expr]
+    matched_bindings: Mapping[Var, Expr]
     var_usages: Mapping[Var, Sequence[Var]]
     value_to_bound_var: Mapping[Expr, Var]
 
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index 76f53eebc5..15bcf3513c 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -905,6 +905,7 @@ class PatternBasedPartitioner : ExprVisitor {
  public:
   using Group = GraphPartitioner::Group;
   using GroupMap = OperatorFusor::GroupMap;
+  using PatternCheckContext = transform::PatternCheckContext;
   using ExprVisitor::VisitExpr_;
   using FCheckMatch = runtime::TypedPackedFunc<bool(const transform::PatternCheckContext&)>;
 
@@ -944,9 +945,7 @@ class PatternBasedPartitioner : ExprVisitor {
   void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
     VisitVarDef(binding->var);
     if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
-      if (check_ != nullptr &&
-          !check_(transform::PatternCheckContext(GetAnnotatedExpr(matches_opt.value()),
-                                                 current_block_use_def_, value_to_bound_var_))) {
+      if (check_ != nullptr && !check_(CreatePatternCheckContext(call, matches_opt.value()))) {
         return;
       }
       // If a match is found, put all matching expressions into the same group.
@@ -990,14 +989,24 @@ class PatternBasedPartitioner : ExprVisitor {
     return group_map_[bound_var.get()]->FindRoot();
   }
 
-  Map<String, Expr> GetAnnotatedExpr(const Map<DFPattern, Expr> matched_result) {
+  PatternCheckContext CreatePatternCheckContext(const CallNode* call,
+                                                const Map<DFPattern, Expr>& matched_result) {
     Map<String, Expr> annotated_expr;
     for (const auto& it : annotation_pat_) {
       if (matched_result.count(it.second)) {
         annotated_expr.Set(it.first, matched_result[it.second]);
       }
     }
-    return annotated_expr;
+
+    Map<Var, Expr> matched_bindings;
+    for (const auto& [pat, match] : matched_result) {
+      if (pat->IsInstance<CallPatternNode>()) {
+        matched_bindings.Set(value_to_bound_var_[match], match);
+      }
+    }
+
+    return PatternCheckContext(GetRef<Call>(call), annotated_expr, matched_bindings,
+                               current_block_use_def_, value_to_bound_var_);
   }
 
   String pat_name_;
@@ -1123,11 +1132,14 @@ TVM_REGISTER_GLOBAL("relax.transform.FusionPattern")
       return FusionPattern(name, pattern, annotation_patterns, check);
     });
 
-PatternCheckContext::PatternCheckContext(Map<String, Expr> annotated_expr,
+PatternCheckContext::PatternCheckContext(Expr matched_expr, Map<String, Expr> annotated_expr,
+                                         Map<Var, Expr> matched_bindings,
                                          Map<Var, Array<Var>> var_usages,
                                          Map<Expr, Var> value_to_bound_var) {
   ObjectPtr<PatternCheckContextNode> n = make_object<PatternCheckContextNode>();
+  n->matched_expr = std::move(matched_expr);
   n->annotated_expr = std::move(annotated_expr);
+  n->matched_bindings = std::move(matched_bindings);
   n->var_usages = std::move(var_usages);
   n->value_to_bound_var = std::move(value_to_bound_var);
   data_ = std::move(n);
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 0bae6801ca..5ea9a9d040 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -492,6 +492,61 @@ def test_cutlass_partition_matmul_blocked(x_shape, y_shape, transpose_y, dtype):
     assert len(mod.functions) == 1
 
 
+def test_cutlass_partition_matmul_tuple_return_blocked():
+    @tvm.script.ir_module
+    class TransposedMatmul:
+        @R.function
+        def main(
+            x: R.Tensor((4, 4), "float32"),
+            y: R.Tensor((4, 4), "float32"),
+        ):
+            with R.dataflow():
+                lv1 = R.permute_dims(y)
+                # Because lv1 is used by both lv2 and out, it should stay out of
+                # the fused function. Otherwise the fused function will return a
+                # tuple output, which isn't possible in cutlass, e.g.
+                # @R.function
+                # def fused_relax_permute_dims_relax_matmul(...):
+                #     R.func_attr({"Composite": "cutlass.matmul_transposed", "Primitive": 1})
+                #     with R.dataflow():
+                #         gv: R.Tensor((4, 4), dtype="float32") = R.permute_dims(y, axes=None)
+                #         gv1: R.Tensor((4, 4), dtype="float32") = R.matmul(x, gv, out_dtype="void")
+                #         R.output(gv, gv1)
+                #     return (gv, gv1)  # Cannot get `gv` if dispatched to a cutlass kernel.
+                lv2 = R.matmul(x, lv1)
+                out = R.matmul(lv1, lv2)
+                R.output(out)
+
+            return out
+
+    mod = partition_for_cutlass(TransposedMatmul, annotate_codegen=False)
+    for f_var in mod.functions:
+        func = mod[f_var]
+        if func.attrs and "Composite" in func.attrs:
+            # verify that the function is not fused as transposed matmul
+            assert func.attrs["Composite"] == "cutlass.matmul"
+
+
+def test_cutlass_partition_matmul_cyclic_dependency_blocked():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((128, 128), "float16"), w: R.Tensor((128, 128), "float16")):
+            with R.dataflow():
+                # Because lv1 depends on lv, this block should be fused as matmul instead of matmul_bias.
+                lv = R.matmul(x, w)
+                lv1 = R.power(lv, R.const(2.0, "float16"))
+                lv2 = R.add(lv, lv1)
+                R.output(lv2)
+            return lv2
+
+    mod = partition_for_cutlass(Module, annotate_codegen=False)
+    for f_var in mod.functions:
+        func = mod[f_var]
+        if func.attrs and "Composite" in func.attrs:
+            assert func.attrs["Composite"] == "cutlass.matmul"
+
+
 @pytest.fixture(params=["float16", "float32"])
 def attention_dtype(request):
     return request.param
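
For downstream backends, a hedged sketch of how a custom pattern check can
consume the new PatternCheckContext fields, mirroring the cutlass helper added
above (the check function name below is illustrative, not part of this commit):

    from tvm.relax import DataflowVar
    from tvm.relax.transform import PatternCheckContext


    def _check_my_pattern(context: PatternCheckContext) -> bool:
        # Variables bound inside the region that FuseOpsByPattern would fuse.
        fused_vars = set(context.matched_bindings.keys())
        # The variable bound to the matched root expression becomes the
        # fused function's output.
        output_var = context.value_to_bound_var[context.matched_expr]
        for var in fused_vars - {output_var}:
            # A non-DataflowVar is visible outside the dataflow block and may
            # escape the region to be fused.
            if not isinstance(var, DataflowVar):
                return False
            # Reject the match if any user of an intermediate variable lies
            # outside the region to be fused.
            if any(user not in fused_vars for user in context.var_usages[var]):
                return False
        return True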
