This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ec868ebc58 [Unity] Add callback to FuseOpsByPattern to check match result is accepted (#14109)
ec868ebc58 is described below

commit ec868ebc58175431250c5f56af10749b77347f51
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Feb 27 14:04:56 2023 -0800

    [Unity] Add callback to FuseOpsByPattern to check match result is accepted (#14109)
    
    * [Unity] Add callback to FuseOpsByPattern to check match result is accepted
    
    * add callnode to callback args
    
    * update pattern registry
    
    * fix
---
 include/tvm/relax/transform.h                      |  5 +++-
 python/tvm/relax/backend/pattern_registry.py       | 16 +++++++++----
 python/tvm/relax/transform/transform.py            | 25 +++++++++++++++++---
 src/relax/backend/pattern_registry.h               |  7 ++++++
 src/relax/transform/fuse_ops.cc                    | 27 +++++++++++++++-------
 .../relax/test_transform_fuse_ops_by_pattern.py    | 10 ++++++++
 6 files changed, 74 insertions(+), 16 deletions(-)
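
A minimal Python-facing sketch of what this change enables (the pattern, the
"cutlass.conv2d" name, and the float32-only predicate are illustrative, echoing
the new test at the end of this patch, not an official recipe):

    from tvm import relax
    from tvm.relax.dpl import is_op, wildcard

    # Match any conv2d call; arguments are left unconstrained.
    conv_pat = is_op("relax.nn.conv2d")(wildcard(), wildcard())

    def is_fp32(match, expr):
        # match: Map from each sub-pattern to the expression it matched.
        # expr:  the expression matched by the root pattern.
        return expr.struct_info.dtype == "float32"

    # A pattern tuple may now be (name, pattern) or (name, pattern, predicate).
    fuse = relax.transform.FuseOpsByPattern([("cutlass.conv2d", conv_pat, is_fp32)])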

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 3c02871f6c..907f0bf8cd 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -223,6 +223,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * of a fused function after successful matching.
 * \param patterns The patterns to detect. The order of the patterns determines the order
 * of priority in which they are matched. Higher-priority patterns should come earlier in the list.
+ * \param checks The callback functions with type (Map<DFPattern, Expr>, Expr) -> bool. Each takes
+ * a match result and returns a boolean value to indicate whether the match result is accepted.
 * \param annotate_codegen If true, wrap each created composite function with another function,
 * whose body consists only of a call to the composite function, and annotate the outer function
 * with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the
@@ -232,7 +234,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
  * \return The Pass.
  */
 TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
-                              const tvm::Array<DFPattern>& patterns, bool annotate_codegen = false);
+                              const tvm::Array<DFPattern>& patterns,
+                              const tvm::Array<PackedFunc>& checks, bool annotate_codegen = false);
 
 /*!
 * \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
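
Seen from Python, the documented check is just a two-argument callable returning
bool; a trivial always-accept sketch:

    def always_accept(match, expr):
        # (Map<DFPattern, Expr>, Expr) -> bool
        return True
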
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
index 0016de0a50..0cccab7842 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -17,13 +17,14 @@
 
 """Pattern registry for BYOC backends"""
 
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import Callable, List, Mapping, Optional, Tuple, Union
 
 import tvm
 from tvm.relax.dpl import DFPattern
 from tvm.runtime import Object
 
 from . import _ffi_api
+from ..expr import Expr
 
 
 @tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
@@ -60,7 +61,10 @@ class PatternRegistryEntry(Object):
 Pattern = Union[
     PatternRegistryEntry,
     Tuple[str, DFPattern],
-    Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]],
+    Tuple[
+        str,
+        Tuple[DFPattern, Mapping[str, DFPattern], Callable[[Mapping[DFPattern, Expr], Expr], bool]],
+    ],
 ]
 
 
@@ -82,9 +86,13 @@ def register_patterns(patterns: List[Pattern]):
         elif isinstance(item, tuple):
             name, pattern_or_tuple = item
             if isinstance(pattern_or_tuple, tuple):
-                pattern, arg_patterns = pattern_or_tuple
+                if len(pattern_or_tuple) == 2:
+                    pattern, arg_patterns = pattern_or_tuple
+                    check = lambda *_: True
+                else:
+                    pattern, arg_patterns, check = pattern_or_tuple
             else:
-                pattern, arg_patterns = pattern_or_tuple, {}
+                pattern, arg_patterns, check = pattern_or_tuple, {}, lambda *_: True
             entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
         else:
             raise TypeError(f"Cannot register type {type(pattern)} as pattern")
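
A hedged sketch of the extended registration form; the "example.*" names, the
argument patterns, and the float16 condition are invented for illustration:

    from tvm.relax.backend.pattern_registry import register_patterns
    from tvm.relax.dpl import is_op, wildcard

    lhs, rhs = wildcard(), wildcard()
    matmul_pat = is_op("relax.matmul")(lhs, rhs)

    def accept_fp16(match, expr):
        # Accept the match only when the matmul output dtype is float16.
        return expr.struct_info.dtype == "float16"

    register_patterns(
        [
            # Existing form: (name, (pattern, arg_patterns)).
            ("example.matmul", (matmul_pat, {"lhs": lhs, "rhs": rhs})),
            # New form: (name, (pattern, arg_patterns, check)).
            ("example.matmul_fp16", (matmul_pat, {"lhs": lhs, "rhs": rhs}, accept_fp16)),
        ]
    )
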
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 263195a100..97daae4941 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -292,7 +292,11 @@ def FuseOpsByPattern(
 
     Parameters
     ----------
-    patterns : List[Tuple[str, DFPattern]]
+    patterns : List[Union[Tuple[str, DFPattern], Tuple[str, DFPattern, Callable]]]
+        A list of tuples of (name, pattern) or (name, pattern, predicate) to be matched.
+        The predicate is a function with type (Map<DFPattern, Expr>, Expr) -> bool. It takes a
+        match result and returns a boolean value to indicate whether the match result is accepted.
+
        The patterns to detect. The order of the patterns determines the order of priority in which
        they are matched. Higher-priority patterns should come earlier in the list.
        The string is the name of the corresponding pattern. It becomes the value of the kComposite
@@ -313,8 +317,23 @@ def FuseOpsByPattern(
         The registered pass for pattern-based fusion.
 
     """
-    pattern_names, df_patterns = zip(*patterns)
-    return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns, annotate_codegen)  # type: ignore
+    pattern_names = []
+    df_patterns = []
+    checks = []
+    for tup in patterns:
+        if len(tup) == 2:
+            pattern_names.append(tup[0])
+            df_patterns.append(tup[1])
+            checks.append(lambda *_: True)
+        elif len(tup) == 3:
+            pattern_names.append(tup[0])
+            df_patterns.append(tup[1])
+            checks.append(tup[2])
+        else:
+            raise ValueError("Invalid pattern: {}".format(tup))
+    return _ffi_api.FuseOpsByPattern(
+        pattern_names, df_patterns, checks, annotate_codegen
+    )  # type: ignore
 
 
 def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
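
Because of the normalization above, two- and three-element tuples can be mixed
in a single call; a sketch with illustrative patterns and names:

    from tvm import relax
    from tvm.relax.dpl import is_op, wildcard

    conv_pat = is_op("relax.nn.conv2d")(wildcard(), wildcard())
    matmul_pat = is_op("relax.matmul")(wildcard(), wildcard())

    fuse = relax.transform.FuseOpsByPattern(
        [
            # Fused only when the predicate returns True.
            ("cutlass.conv2d", conv_pat, lambda match, expr: expr.struct_info.dtype == "float32"),
            # No predicate given: an always-true check is filled in.
            ("cutlass.matmul", matmul_pat),
        ],
        annotate_codegen=True,
    )
    # mod = fuse(mod)  # apply to a relax IRModule
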
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
index 2e199a2bb1..1fdb69319f 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -42,6 +42,8 @@ namespace backend {
  */
 class PatternRegistryEntryNode : public Object {
  public:
+  using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern, Expr>&, const Expr&)>;
+
   /*!
   * \brief The name of pattern. Usually it starts with the name of backend, like
    * 'cutlass.matmul'.
@@ -59,6 +61,11 @@ class PatternRegistryEntryNode : public Object {
    */
   Map<String, DFPattern> arg_patterns;
 
+  /*!
+   * \brief The function to check whether the match result is accepted.
+   */
+  FCheckMatch check;
+
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
     v->Visit("pattern", &pattern);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index c5042d0191..60b2c77e49 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,6 +40,7 @@
 
 #include "../../relay/analysis/graph_partitioner.h"
 #include "../../support/arena.h"
+#include "../backend/pattern_registry.h"
 
 namespace tvm {
 namespace relax {
@@ -899,15 +900,18 @@ class PatternBasedPartitioner : ExprVisitor {
   using Group = GraphPartitioner::Group;
   using GroupMap = OperatorFusor::GroupMap;
   using ExprVisitor::VisitExpr_;
+  using FCheckMatch = backend::PatternRegistryEntryNode::FCheckMatch;
 
-  static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) {
-    PatternBasedPartitioner part(pattern_name, pattern, arena);
+  static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch check, Expr expr,
+                      support::Arena* arena) {
+    PatternBasedPartitioner part(pattern_name, pattern, check, arena);
     part.VisitExpr(expr);
     return part.group_map_;
   }
 
-  PatternBasedPartitioner(String pattern_name, DFPattern pattern, support::Arena* arena)
-      : pat_name_(pattern_name), pat_(pattern), arena_(arena) {}
+  PatternBasedPartitioner(String pattern_name, DFPattern pattern, FCheckMatch check,
+                          support::Arena* arena)
+      : pat_name_(pattern_name), pat_(pattern), check_(check), arena_(arena) {}
 
  void VisitVarDef(const Var& var) final { group_map_[var.get()] = arena_->make<Group>(); }
 
@@ -922,6 +926,9 @@ class PatternBasedPartitioner : ExprVisitor {
  void VisitBinding_(const VarBindingNode* binding, const CallNode* call) final {
     VisitVarDef(binding->var);
    if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call), bindings_)) {
+      if (!check_(matches_opt.value(), GetRef<Call>(call))) {
+        return;
+      }
       // If a match is found, put all matching expressions into the same group.
      // OperatorFusor also requires that the bound variable be in the same group as the RHS value.
      // Since is_op(...) based pattern only matches against call nodes on the right hand side,
@@ -965,6 +972,7 @@ class PatternBasedPartitioner : ExprVisitor {
 
   String pat_name_;
   DFPattern pat_;
+  FCheckMatch check_;
   support::Arena* arena_;
   Map<Var, Expr> bindings_;
   Map<Expr, Var> value_to_bound_var_;
@@ -1042,7 +1050,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
 };
 
 IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
-                          const tvm::Array<DFPattern>& patterns, IRModule mod,
+                          const tvm::Array<DFPattern>& patterns,
+                          const tvm::Array<runtime::PackedFunc>& checks, IRModule mod,
                           bool annotate_codegen) {
   support::Arena arena;
   for (size_t i = 0; i < pattern_names.size(); ++i) {
@@ -1051,7 +1060,8 @@ IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
       if (entry.second->IsInstance<tir::PrimFuncNode>()) {
         continue;
       }
-      auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], entry.second, &arena);
+      auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i], checks[i],
+                                              entry.second, &arena);
       group_map.insert(map.begin(), map.end());
     }
     mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
@@ -1080,10 +1090,11 @@ Pass FuseOps(int fuse_opt_level) {
 TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
 
 Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
-                      const tvm::Array<DFPattern>& patterns, bool annotate_codegen) {
+                      const tvm::Array<DFPattern>& patterns,
+                      const tvm::Array<runtime::PackedFunc>& checks, bool annotate_codegen) {
   runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =  //
       [=](IRModule m, PassContext pc) {
-        return relax::FuseOpsByPattern(pattern_names, patterns, m, annotate_codegen);
+        return relax::FuseOpsByPattern(pattern_names, patterns, checks, m, annotate_codegen);
       };
   return CreateModulePass(/*pass_function=*/pass_func,       //
                           /*opt_level=*/0,                   //
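
The early return in VisitBinding_ above is what the new test at the end of this
patch exercises: when the check rejects a match, the matched expressions stay in
their own groups and no composite function is created. Conceptually (a sketch
reusing the Conv2dx2 module and pat pattern defined in the test file below):

    never = lambda match, expr: False
    mod = relax.transform.FuseOpsByPattern([("cutlass.conv2d", pat, never)])(Conv2dx2)
    tvm.ir.assert_structural_equal(mod, Conv2dx2)  # nothing was partitioned
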
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 21f952096b..a9e76feb6c 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -579,5 +579,15 @@ def test_unused():
     check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned)
 
 
+def test_check_pattern():
+    pat = make_fused_bias_activation_pattern("relax.nn.conv2d", with_bias=False, activation=None)
+
+    def pred(match, expr):
+        assert isinstance(expr, relax.expr.Call) and expr.op.name == "relax.nn.conv2d"
+        return expr.struct_info.dtype == "float32"
+
+    check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2)  # expect no partitioning
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
