This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ec868ebc58 [Unity] Add callback to FuseOpsByPattern to check match result is accepted (#14109)
ec868ebc58 is described below
commit ec868ebc58175431250c5f56af10749b77347f51
Author: Wuwei Lin <[email protected]>
AuthorDate: Mon Feb 27 14:04:56 2023 -0800
[Unity] Add callback to FuseOpsByPattern to check match result is accepted (#14109)
* [Unity] Add callback to FuseOpsByPattern to check match result is accepted
* add callnode to callback args
* update pattern registry
* fix
---
include/tvm/relax/transform.h | 5 +++-
python/tvm/relax/backend/pattern_registry.py | 16 +++++++++----
python/tvm/relax/transform/transform.py | 25 +++++++++++++++++---
src/relax/backend/pattern_registry.h | 7 ++++++
src/relax/transform/fuse_ops.cc | 27 +++++++++++++++-------
.../relax/test_transform_fuse_ops_by_pattern.py | 10 ++++++++
6 files changed, 74 insertions(+), 16 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 3c02871f6c..907f0bf8cd 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -223,6 +223,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
* of a fused function after successful matching.
* \param patterns The patterns to detect. The order of the patterns determines the order
* of priority in which they are matched. Higher-priority patterns should come earlier in the list.
+ * \param checks One callback per pattern, with type (Map<DFPattern, Expr>, Expr) -> bool. Each
+ * takes a match result and the matched expression and returns whether the match is accepted.
* \param annotate_codegen If true, wrap each created composite function with another function,
* whose body consists only of a call to the composite function, and annotate the outer function
* with kCodegen and kGlobalSymbol attributes. The kCodegen attribute is set as the prefix of the
@@ -232,7 +234,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
* \return The Pass.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<runtime::String>& pattern_names,
- const tvm::Array<DFPattern>& patterns, bool
annotate_codegen = false);
+ const tvm::Array<DFPattern>& patterns,
+ const tvm::Array<PackedFunc>& checks, bool
annotate_codegen = false);
/*!
* \brief Group one or multiple composite functions created by FuseOpsByPattern into a new
diff --git a/python/tvm/relax/backend/pattern_registry.py b/python/tvm/relax/backend/pattern_registry.py
index 0016de0a50..0cccab7842 100644
--- a/python/tvm/relax/backend/pattern_registry.py
+++ b/python/tvm/relax/backend/pattern_registry.py
@@ -17,13 +17,14 @@
"""Pattern registry for BYOC backends"""
-from typing import List, Mapping, Optional, Tuple, Union
+from typing import Callable, List, Mapping, Optional, Tuple, Union
import tvm
from tvm.relax.dpl import DFPattern
from tvm.runtime import Object
from . import _ffi_api
+from ..expr import Expr
@tvm._ffi.register_object("relax.backend.PatternRegistryEntry")
@@ -60,7 +61,10 @@ class PatternRegistryEntry(Object):
Pattern = Union[
PatternRegistryEntry,
Tuple[str, DFPattern],
- Tuple[str, Tuple[DFPattern, Mapping[str, DFPattern]]],
+ Tuple[
+ str,
+ Tuple[DFPattern, Mapping[str, DFPattern], Callable[[Mapping[DFPattern,
Expr], Expr], bool]],
+ ],
]
@@ -82,9 +86,13 @@ def register_patterns(patterns: List[Pattern]):
elif isinstance(item, tuple):
name, pattern_or_tuple = item
if isinstance(pattern_or_tuple, tuple):
- pattern, arg_patterns = pattern_or_tuple
+ if len(pattern_or_tuple) == 2:
+ pattern, arg_patterns = pattern_or_tuple
+ check = lambda *_: True
+ else:
+ pattern, arg_patterns, check = pattern_or_tuple
else:
- pattern, arg_patterns = pattern_or_tuple, {}
+ pattern, arg_patterns, check = pattern_or_tuple, {}, lambda
*_: True
entries.append(PatternRegistryEntry(name, pattern, arg_patterns))
else:
raise TypeError(f"Cannot register type {type(pattern)} as pattern")
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 263195a100..97daae4941 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -292,7 +292,11 @@ def FuseOpsByPattern(
Parameters
----------
- patterns : List[Tuple[str, DFPattern]]
+ patterns : List[Union[Tuple[str, DFPattern], Tuple[str, DFPattern, Callable]]]
+ A list of tuples of (name, pattern) or (name, pattern, predicate) to be matched.
+ The predicate has type (Map<DFPattern, Expr>, Expr) -> bool. It takes a match result and
+ the matched expression, and returns whether the match is accepted (default: always True).
+
The patterns to detect. The order of the patterns determines the order of priority in which
they are matched. Higher-priority patterns should come earlier in the list.
The string is the name of the corresponding pattern. It becomes the value of the kComposite
@@ -313,8 +317,23 @@ def FuseOpsByPattern(
The registered pass for pattern-based fusion.
"""
- pattern_names, df_patterns = zip(*patterns)
- return _ffi_api.FuseOpsByPattern(pattern_names, df_patterns,
annotate_codegen) # type: ignore
+ pattern_names = []
+ df_patterns = []
+ checks = []
+ for tup in patterns:
+ if len(tup) == 2:
+ pattern_names.append(tup[0])
+ df_patterns.append(tup[1])
+ checks.append(lambda *_: True)
+ elif len(tup) == 3:
+ pattern_names.append(tup[0])
+ df_patterns.append(tup[1])
+ checks.append(tup[2])
+ else:
+ raise ValueError("Invalid pattern: {}".format(tup))
+ return _ffi_api.FuseOpsByPattern(
+ pattern_names, df_patterns, checks, annotate_codegen
+ ) # type: ignore
def MergeCompositeFunctions() -> tvm.ir.transform.Pass:
diff --git a/src/relax/backend/pattern_registry.h b/src/relax/backend/pattern_registry.h
index 2e199a2bb1..1fdb69319f 100644
--- a/src/relax/backend/pattern_registry.h
+++ b/src/relax/backend/pattern_registry.h
@@ -42,6 +42,8 @@ namespace backend {
*/
class PatternRegistryEntryNode : public Object {
public:
+ using FCheckMatch = runtime::TypedPackedFunc<bool(const Map<DFPattern,
Expr>&, const Expr&)>;
+
/*!
* \brief The name of pattern. Usually it starts with the name of backend, like
* 'cutlass.matmul'.
@@ -59,6 +61,11 @@ class PatternRegistryEntryNode : public Object {
*/
Map<String, DFPattern> arg_patterns;
+ /*!
+ * \brief The function to check whether the match result is accepted.
+ */
+ FCheckMatch check;
+
void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
v->Visit("pattern", &pattern);
diff --git a/src/relax/transform/fuse_ops.cc b/src/relax/transform/fuse_ops.cc
index c5042d0191..60b2c77e49 100644
--- a/src/relax/transform/fuse_ops.cc
+++ b/src/relax/transform/fuse_ops.cc
@@ -40,6 +40,7 @@
#include "../../relay/analysis/graph_partitioner.h"
#include "../../support/arena.h"
+#include "../backend/pattern_registry.h"
namespace tvm {
namespace relax {
@@ -899,15 +900,18 @@ class PatternBasedPartitioner : ExprVisitor {
using Group = GraphPartitioner::Group;
using GroupMap = OperatorFusor::GroupMap;
using ExprVisitor::VisitExpr_;
+ using FCheckMatch = backend::PatternRegistryEntryNode::FCheckMatch;
- static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr,
support::Arena* arena) {
- PatternBasedPartitioner part(pattern_name, pattern, arena);
+ static GroupMap Run(String pattern_name, DFPattern pattern, FCheckMatch
check, Expr expr,
+ support::Arena* arena) {
+ PatternBasedPartitioner part(pattern_name, pattern, check, arena);
part.VisitExpr(expr);
return part.group_map_;
}
- PatternBasedPartitioner(String pattern_name, DFPattern pattern,
support::Arena* arena)
- : pat_name_(pattern_name), pat_(pattern), arena_(arena) {}
+ PatternBasedPartitioner(String pattern_name, DFPattern pattern, FCheckMatch
check,
+ support::Arena* arena)
+ : pat_name_(pattern_name), pat_(pattern), check_(check), arena_(arena) {}
void VisitVarDef(const Var& var) final { group_map_[var.get()] =
arena_->make<Group>(); }
@@ -922,6 +926,9 @@ class PatternBasedPartitioner : ExprVisitor {
void VisitBinding_(const VarBindingNode* binding, const CallNode* call)
final {
VisitVarDef(binding->var);
if (auto matches_opt = ExtractMatchedExpr(pat_, GetRef<Call>(call),
bindings_)) {
+ if (!check_(matches_opt.value(), GetRef<Call>(call))) {
+ return;
+ }
// If a match is found, put all matching expressions into the same group.
// OperatorFusor also requires that the bound variable be in the same group as the RHS value.
// Since is_op(...) based pattern only matches against call nodes on the right hand side,
@@ -965,6 +972,7 @@ class PatternBasedPartitioner : ExprVisitor {
String pat_name_;
DFPattern pat_;
+ FCheckMatch check_;
support::Arena* arena_;
Map<Var, Expr> bindings_;
Map<Expr, Var> value_to_bound_var_;
@@ -1042,7 +1050,8 @@ class CompositeFunctionAnnotator : public ExprMutator {
};
IRModule FuseOpsByPattern(const tvm::Array<String>& pattern_names,
- const tvm::Array<DFPattern>& patterns, IRModule mod,
+ const tvm::Array<DFPattern>& patterns,
+ const tvm::Array<runtime::PackedFunc>& checks,
IRModule mod,
bool annotate_codegen) {
support::Arena arena;
for (size_t i = 0; i < pattern_names.size(); ++i) {
@@ -1051,7 +1060,8 @@ IRModule FuseOpsByPattern(const tvm::Array<String>&
pattern_names,
if (entry.second->IsInstance<tir::PrimFuncNode>()) {
continue;
}
- auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i],
entry.second, &arena);
+ auto map = PatternBasedPartitioner::Run(pattern_names[i], patterns[i],
checks[i],
+ entry.second, &arena);
group_map.insert(map.begin(), map.end());
}
mod = MakeGroupedFunctions(mod, group_map, /*lift_constants*/ false);
@@ -1080,10 +1090,11 @@ Pass FuseOps(int fuse_opt_level) {
TVM_REGISTER_GLOBAL("relax.transform.FuseOps").set_body_typed(FuseOps);
Pass FuseOpsByPattern(const tvm::Array<String>& pattern_names,
- const tvm::Array<DFPattern>& patterns, bool
annotate_codegen) {
+ const tvm::Array<DFPattern>& patterns,
+ const tvm::Array<runtime::PackedFunc>& checks, bool
annotate_codegen) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = //
[=](IRModule m, PassContext pc) {
- return relax::FuseOpsByPattern(pattern_names, patterns, m,
annotate_codegen);
+ return relax::FuseOpsByPattern(pattern_names, patterns, checks, m,
annotate_codegen);
};
return CreateModulePass(/*pass_function=*/pass_func, //
/*opt_level=*/0, //
diff --git a/tests/python/relax/test_transform_fuse_ops_by_pattern.py b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
index 21f952096b..a9e76feb6c 100644
--- a/tests/python/relax/test_transform_fuse_ops_by_pattern.py
+++ b/tests/python/relax/test_transform_fuse_ops_by_pattern.py
@@ -579,5 +579,15 @@ def test_unused():
check(Conv2dReLU, [("cutlass.conv2d", pat)], Conv2dReLU_partitioned)
+def test_check_pattern():
+ pat = make_fused_bias_activation_pattern("relax.nn.conv2d",
with_bias=False, activation=None)
+
+ def pred(match, expr):
+ assert isinstance(expr, relax.expr.Call) and expr.op.name ==
"relax.nn.conv2d"
+ return expr.struct_info.dtype == "float32"
+
+ check(Conv2dx2, [("cutlass.conv2d", pat, pred)], Conv2dx2) # expect no
partitioning
+
+
if __name__ == "__main__":
pytest.main([__file__])