masahi commented on code in PR #14446:
URL: https://github.com/apache/tvm/pull/14446#discussion_r1154858483
##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const Datafl
TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
/*!
- * \brief Apply pattern matching to each call node and replace matching ones with the output of
- * a user-provided rewriter function.
+ * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
+ * with the output of a user-provided rewriter function.
*/
class PatternRewriter : ExprMutator {
public:
+ using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;
- PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
- : pattern_(pat), rewriter_func_(rewriter_func) {}
+ PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
+ const std::unordered_set<const VarNode*>& params)
+ : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
- static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
- PatternRewriter rewriter(pat, rewriter_func);
+ PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
+ const std::unordered_set<const VarNode*>& params)
+ : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
+
+ template <typename PatternType>
+ static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
Review Comment:
Depending on the type of the pattern passed in (`DFPattern` or `PatternContext`),
this performs either call node rewriting or dataflow block rewriting. It never
does both in a single pass, as the two constructors make clear.
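
To illustrate the point, here is a minimal standalone sketch of the dispatch. It does not use TVM: `DFPattern` and `PatternContext` below are hypothetical stand-in structs, the `Mode()` helper is invented for illustration, and storing the two members as `std::optional` is an assumption about how the rewriter tells the modes apart. The real `Run` also takes the rewriter function and the `Function` to rewrite.

```cpp
#include <optional>
#include <string>
#include <utility>

// Hypothetical stand-ins for the real TVM types; only the type distinction matters here.
struct DFPattern { std::string name; };
struct PatternContext { std::string name; };

class PatternRewriter {
 public:
  // Each constructor fills exactly one of the two optional members.
  explicit PatternRewriter(DFPattern pat) : pattern_(std::move(pat)) {}
  explicit PatternRewriter(PatternContext ctx) : ctx_(std::move(ctx)) {}

  // Overload resolution on PatternType selects one constructor,
  // so a single Run call is in exactly one rewriting mode.
  template <typename PatternType>
  static std::string Run(PatternType pat) {
    PatternRewriter rewriter(std::move(pat));
    return rewriter.Mode();
  }

  // Illustrative helper (not in the PR): reports which mode is active.
  std::string Mode() const {
    if (ctx_) return "dataflow block rewriting";
    if (pattern_) return "call node rewriting";
    return "none";
  }

 private:
  std::optional<DFPattern> pattern_;
  std::optional<PatternContext> ctx_;
};
```

In this sketch, calling `PatternRewriter::Run(DFPattern{"p"})` selects call node rewriting and `PatternRewriter::Run(PatternContext{"c"})` selects dataflow block rewriting; no call sets both members.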