yelite commented on code in PR #14109:
URL: https://github.com/apache/tvm/pull/14109#discussion_r1116497037
##########
src/relax/transform/fuse_ops.cc:
##########
@@ -908,23 +908,27 @@ class PatternBasedPartitioner : ExprVisitor {
using GroupMap = OperatorFusor::GroupMap;
using ExprVisitor::VisitExpr_;
- static GroupMap Run(String pattern_name, DFPattern pattern, Expr expr, support::Arena* arena) {
- PatternBasedPartitioner part(pattern_name, pattern, AnalyzeVar2Value(expr));
+ static GroupMap Run(String pattern_name, DFPattern pattern, runtime::PackedFunc check, Expr expr,
+ support::Arena* arena) {
+ PatternBasedPartitioner part(pattern_name, pattern, check, AnalyzeVar2Value(expr));
// Initialize each expr to have its own group
PostOrderVisit(
expr, [arena, &part](const Expr& e) { part.group_map_[e.get()] = arena->make<Group>(); });
part.VisitExpr(expr);
return part.group_map_;
}
- PatternBasedPartitioner(String pattern_name, DFPattern pattern, const Map<Var, Expr>& bindings)
+ PatternBasedPartitioner(String pattern_name, DFPattern pattern, runtime::PackedFunc check,
Review Comment:
Would it be better to use stronger typing here, replacing `PackedFunc` with
`TypedPackedFunc`, so that it is clear what callers are expected to pass?
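To make the suggestion concrete, here is a minimal standalone sketch. It does not use the TVM API: `std::function` stands in for `PackedFunc` and `TypedPackedFunc`, and the `MatchResult` type and the `bool(const MatchResult&)` check signature are hypothetical. With the untyped form, the callback receives type-erased arguments, so both sides agree on what is passed only by convention and mistakes surface at run time; with the typed form, the expected signature is part of the parameter type and a mismatched callable is rejected at compile time.

```cpp
#include <any>
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for a match result: pattern-node label -> matched expression text.
using MatchResult = std::map<std::string, std::string>;

// Untyped callback, loosely analogous to PackedFunc: arguments and result are
// type-erased, so the type says nothing about what the callee expects.
using UntypedCheck = std::function<std::any(const std::vector<std::any>&)>;

// Typed callback, loosely analogous to TypedPackedFunc<bool(MatchResult)>:
// the expected argument and return types are part of the declaration.
using TypedCheck = std::function<bool(const MatchResult&)>;

bool RunUntyped(const UntypedCheck& check, const MatchResult& m) {
  // Caller and callee agree on argument order and types only by convention;
  // a mismatch surfaces at run time as std::bad_any_cast.
  return std::any_cast<bool>(check(std::vector<std::any>{std::any(m)}));
}

bool RunTyped(const TypedCheck& check, const MatchResult& m) {
  // Passing a callable with an incompatible signature fails to compile.
  return check(m);
}

// The same hypothetical check ("an rhs was matched") written against both interfaces.
const UntypedCheck kUntypedHasRhs = [](const std::vector<std::any>& args) -> std::any {
  // The untyped callee has to validate its own arguments at run time.
  assert(args.size() == 1 && "expected exactly one argument");
  const MatchResult& m = std::any_cast<const MatchResult&>(args[0]);
  return std::any(m.count("rhs") == 1);
};

const TypedCheck kTypedHasRhs = [](const MatchResult& m) { return m.count("rhs") == 1; };
```

Both `RunTyped(kTypedHasRhs, m)` and `RunUntyped(kUntypedHasRhs, m)` return `true` for a match result that contains an `"rhs"` entry, but only the typed version states that contract in its signature.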
##########
include/tvm/relax/transform.h:
##########
@@ -223,6 +223,8 @@ TVM_DLL Pass FuseOps(int fuse_opt_level = -1);
* of a fused function after successful matching.
* \param patterns The patterns to detect. The order of the patterns determines the order
* of priority in which they are matched. Higher-priority patterns should come earlier in the list.
+ * \param checks The callback functions that takes a match result and returns a boolean value to
Review Comment:
Taking a match result would work for simple cases, but will it still work
for more complicated ones like multi-head attention (MHA)? I think it would be
difficult to retrieve useful information from the match result alone if the expr
contains multiple nodes of a similar nature. On the other hand, it seems that only
the call expr and the match result are available at the call site of the `checks`
callbacks.
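A toy model of this concern (not the TVM API): if a match result maps pattern nodes to the expressions they matched, an attention-like pattern contains two `matmul` nodes, and a check that scans the result by operator cannot tell them apart. The check below works only because it captures the handles of the specific pattern nodes it needs, which ties it to the code that builds the pattern. `PatternNode`, `MatchedExpr`, `PatternWithCheck`, `MakeAttentionPattern`, `CountByOp`, and the shapes are all hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical toy pattern node; two nodes with the same op differ only by identity.
struct PatternNode {
  std::string op;
};
using PatternRef = std::shared_ptr<const PatternNode>;

// Hypothetical matched expression: the op name and its output shape.
struct MatchedExpr {
  std::string op;
  std::vector<int64_t> shape;
};

// Match result keyed by pattern-node identity: pattern node -> matched expression.
using MatchResult = std::map<PatternRef, MatchedExpr>;
using Check = std::function<bool(const MatchResult&)>;

struct PatternWithCheck {
  std::vector<PatternRef> nodes;  // every node of the pattern
  Check check;
};

// Counts result entries whose pattern node has the given op.
int CountByOp(const MatchResult& m, const std::string& op) {
  int n = 0;
  for (const auto& entry : m) {
    if (entry.first->op == op) ++n;
  }
  return n;
}

PatternWithCheck MakeAttentionPattern() {
  PatternRef qk = std::make_shared<PatternNode>(PatternNode{"matmul"});  // Q x K^T
  PatternRef sm = std::make_shared<PatternNode>(PatternNode{"softmax"});
  PatternRef sv = std::make_shared<PatternNode>(PatternNode{"matmul"});  // softmax(...) x V
  // The check captures the two matmul handles. Looking them up by op name
  // would be ambiguous, because the result holds two "matmul" entries.
  Check check = [qk, sv](const MatchResult& m) {
    assert(m.count(qk) == 1 && m.count(sv) == 1 && "incomplete match result");
    const std::vector<int64_t>& scores = m.at(qk).shape;  // [seq_q, seq_k]
    const std::vector<int64_t>& out = m.at(sv).shape;     // [seq_q, head_dim]
    return scores.size() == 2 && out.size() == 2 && scores[0] == out[0];
  };
  return {{qk, sm, sv}, check};
}
```

With shapes `[128, 128]` for the scores matmul and `[128, 64]` for the output matmul, the check passes, while `CountByOp(m, "matmul")` finds two candidates: the ambiguity a check without captured handles would have to resolve from the match result alone.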
--
This is an automated message from the Apache Git Service.