This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 016b512ad4 [Relax] Refactor PatternRewriter into separate Block/Expr
mutators (#16730)
016b512ad4 is described below
commit 016b512ad4950cba32eaf81be0cfe3c0321851f7
Author: Eric Lunderberg <[email protected]>
AuthorDate: Tue Mar 26 08:43:36 2024 -0500
[Relax] Refactor PatternRewriter into separate Block/Expr mutators (#16730)
Prior to this commit, the `PatternRewriter` mutator handled pattern
rewriting at either the expression level (`rewrite_call`) or the
dataflow block level (`rewrite_bindings`). These two functionalities
had different external APIs, defined different member variables, and
visited different IR nodes. In effect, it contained two entirely
independent implementations that merely happened to share a single
class.
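This commit refactors the single `PatternRewriter` mutator into
separate `BlockPatternRewriter` and `ExprPatternRewriter` mutators.
`BlockPatternRewriter` now backs `rewrite_bindings` and
`ExprPatternRewriter` backs `rewrite_call`. The rewriter callbacks
are now typed as `TypedPackedFunc` instead of an untyped `PackedFunc`;
this includes the public `RewriteBindings` declaration in
`dataflow_matcher.h`.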
---
include/tvm/relax/dataflow_matcher.h | 4 +-
src/relax/ir/dataflow_matcher.cc | 238 ++++++++++++++++++++---------------
2 files changed, 140 insertions(+), 102 deletions(-)
diff --git a/include/tvm/relax/dataflow_matcher.h
b/include/tvm/relax/dataflow_matcher.h
index bbc8e9382e..8f2024f264 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -67,7 +67,9 @@ TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const
PatternContext& ctx,
* \param f The function to rewrite
* \return The rewritten or the input function, depending on the pattern
matching result.
*/
-TVM_DLL Function RewriteBindings(const PatternContext& ctx, PackedFunc
rewriter, Function f);
+TVM_DLL Function RewriteBindings(
+ const PatternContext& ctx,
+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
rewriter, Function f);
/**
* \brief Rewrite a function with the given pattern and the rewriter function.
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index a14d43f6d3..531971d3db 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -973,102 +973,33 @@ TVM_REGISTER_GLOBAL("relax.dpl.match_dfb")
});
/*!
- * \brief Apply pattern matching to each call node and dataflow block, and
replace matching ones
+ * \brief Apply pattern matching to each dataflow block, replacing matches
* with the output of a user-provided rewriter function.
*/
-class PatternRewriter : ExprMutator {
+class BlockPatternRewriter : ExprMutator {
public:
using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;
- PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
- const std::unordered_set<const VarNode*>& params)
- : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
-
- PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
- const std::unordered_set<const VarNode*>& params)
- : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
+ BlockPatternRewriter(
+ const PatternContext& ctx,
+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
rewriter_func)
+ : ctx_(ctx), rewriter_func_(rewriter_func) {}
template <typename PatternType>
- static Function Run(PatternType pat, PackedFunc rewriter_func, Function f) {
- std::unordered_set<const VarNode*> params;
- for (const auto& p : f->params) {
- params.insert(p.get());
- }
- PatternRewriter rewriter(pat, rewriter_func, params);
- return Downcast<Function>(RemoveAllUnused(rewriter.VisitExpr(f)));
- }
-
- Expr VisitExpr_(const SeqExprNode* seq) override {
- if (ctx_) {
- return ExprMutator::VisitExpr_(seq);
- }
-
- auto cache = bindings_;
- SeqExpr prev = GetRef<SeqExpr>(seq);
-
- StructuralEqual struct_equal;
-
- while (true) {
- SeqExpr next =
Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
- if (struct_equal(prev, next)) {
- return std::move(next);
- }
-
- // Canonicalization may result in two previously-different
- // expressions being recognized as identical. Elimination of
- // common subexpressions may result in trival var-to-var
- // bindings that can be canonicalized. Therefore, iterate the
- // simplification steps until converged.
- while (true) {
- auto start_of_loop = next;
- next = Downcast<SeqExpr>(CanonicalizeBindings(next));
- next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
- next = Downcast<SeqExpr>(RemoveAllUnused(next));
- if (struct_equal(start_of_loop, next)) {
- break;
- }
- }
-
- if (struct_equal(prev, next)) {
- return std::move(next);
- }
-
- // Reset all knowledge of bindings that were collected from
- // this DataflowBlock. The collected bindings are only after
- // the point where they were collected, and we are repeating
- // the mutation of this DataflowBlock.
- bindings_ = cache;
- prev = next;
- }
+ static Function Run(
+ PatternType pat,
+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
rewriter_func,
+ Function func) {
+ BlockPatternRewriter rewriter(pat, rewriter_func);
+
+ func = Downcast<Function>(rewriter(func));
+ func = Downcast<Function>(RemoveAllUnused(func));
+ return func;
}
BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node)
override {
- if (ctx_) {
- return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
- } else {
- return ExprMutator::VisitBindingBlock_(block_node);
- }
- }
-
- void VisitBinding_(const VarBindingNode* binding) override {
- auto expr = VisitExpr(binding->value);
- bindings_.Set(binding->var, expr);
- ReEmitBinding(binding, expr);
- }
-
- Expr VisitExpr(const Expr& expr) override {
- auto node = ExprMutator::VisitExpr(expr);
-
- if (pattern_) {
- if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), node,
bindings_)) {
- Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
- if (!rewritten_expr.same_as(node)) {
- return builder_->Normalize(rewritten_expr);
- }
- }
- }
- return node;
+ return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
}
private:
@@ -1106,7 +1037,7 @@ class PatternRewriter : ExprMutator {
BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
auto df_block = Downcast<DataflowBlock>(block);
Map<Var, Expr> bindings = AnalyzeVar2Value(df_block);
- if (auto matches = MatchGraph(ctx_.value(), df_block, bindings)) {
+ if (auto matches = MatchGraph(ctx_, df_block, bindings)) {
builder_->BeginDataflowBlock();
Map<Var, Expr> replacements = rewriter_func_(matches.value(), bindings);
@@ -1140,34 +1071,139 @@ class PatternRewriter : ExprMutator {
return block;
}
- /*! \brief The pattern for rewriting call nodes */
- Optional<DFPattern> pattern_;
/*! \brief The pattern constraint contexts for rewriting dataflow blocks */
- Optional<PatternContext> ctx_;
+ PatternContext ctx_;
/*!
* \brief The user-provided rewriter function. Its signature and semantics
are:
- * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the
matched
- * call node and the map of patterns and matched expressions, it should
return a new call node
- * to replace the original one or the original matched call node as is.
- * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr> for dataflow
block rewriting.
- * Given the map of patterns and corresponding variables (bound variables
or parameters),
- * it should return a map that specifies new values for matched bound
variables. It can refer
+ *
+ * - (Map<DFPattern, Var>, Map<Var, Expr>) -> Map<Var, Expr>
+ *
+ * Given the map of patterns and corresponding variables (bound
+ * variables or parameters), it should return a map that
+ * specifies new values for matched bound variables. It can refer
* to the passed bindings to create the replacement expressions.
*/
- PackedFunc rewriter_func_;
- std::unordered_set<const VarNode*> params_;
+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
rewriter_func_;
+};
+
+/*!
+ * \brief Apply pattern matching to each expression, replacing
+ * matches with the output of a user-provided rewriter function.
+ */
+class ExprPatternRewriter : ExprMutator {
+ public:
+ using ExprMutator::VisitBindingBlock_;
+ using ExprMutator::VisitExpr_;
+
+ ExprPatternRewriter(DFPattern pat,
+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>
rewriter_func)
+ : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+ template <typename PatternType>
+ static Function Run(PatternType pat,
+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>
rewriter_func,
+ Function func) {
+ ExprPatternRewriter rewriter(pat, rewriter_func);
+ func = Downcast<Function>(rewriter(func));
+ func = Downcast<Function>(RemoveAllUnused(func));
+ return func;
+ }
+
+ Expr VisitExpr_(const SeqExprNode* seq) override {
+ auto cache = bindings_;
+ SeqExpr prev = GetRef<SeqExpr>(seq);
+
+ StructuralEqual struct_equal;
+
+ while (true) {
+ SeqExpr next =
Downcast<SeqExpr>(builder_->Normalize(ExprMutator::VisitExpr_(prev.get())));
+ if (struct_equal(prev, next)) {
+ return std::move(next);
+ }
+
+ // Canonicalization may result in two previously-different
+ // expressions being recognized as identical. Elimination of
+ // common subexpressions may result in trival var-to-var
+ // bindings that can be canonicalized. Therefore, iterate the
+ // simplification steps until converged.
+ while (true) {
+ auto start_of_loop = next;
+ next = Downcast<SeqExpr>(CanonicalizeBindings(next));
+ next = Downcast<SeqExpr>(EliminateCommonSubexpr(next));
+ next = Downcast<SeqExpr>(RemoveAllUnused(next));
+ if (struct_equal(start_of_loop, next)) {
+ break;
+ }
+ }
+
+ if (struct_equal(prev, next)) {
+ return std::move(next);
+ }
+
+ // Reset all knowledge of bindings that were collected from
+ // this SeqExpr. The collected bindings are only after
+ // the point where they were collected, and we are repeating
+ // the mutation of this SeqExpr.
+ bindings_ = cache;
+ prev = next;
+ }
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) override {
+ auto expr = VisitExpr(binding->value);
+ bindings_.Set(binding->var, expr);
+ ReEmitBinding(binding, expr);
+ }
+
+ Expr VisitExpr(const Expr& expr) override {
+ auto node = ExprMutator::VisitExpr(expr);
+
+ if (auto matches_opt = ExtractMatchedExpr(pattern_, node, bindings_)) {
+ Expr rewritten_expr = rewriter_func_(node, matches_opt.value());
+ if (!rewritten_expr.same_as(node)) {
+ return builder_->Normalize(rewritten_expr);
+ }
+ }
+
+ return node;
+ }
+
+ private:
+ /*! \brief The pattern for rewriting call nodes */
+ DFPattern pattern_;
+ /*!
+ * \brief The user-provided rewriter function. Its signature and semantics
are:
+ *
+ * - (Call, Map<DFPattern, Expr>) -> Call
+ *
+ * Given the matched call node and the map of patterns and
+ * matched expressions, it should return a new call node to
+ * replace the original one or the original matched call node as
+ * is.
+ */
+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)> rewriter_func_;
+
+ /*! \brief The known variable bindings
+ *
+ * The variable bindings whose value is known. This must be tracked
+ * separately from the block builder, so that it can be reset after
+ * each iteration of the mutate-until-converged loop applied to
+ * `SeqExpr`.
+ */
Map<Var, Expr> bindings_;
};
-Function RewriteBindings(const PatternContext& ctx, PackedFunc rewriter,
Function f) {
- return PatternRewriter::Run(ctx, rewriter, f);
+Function RewriteBindings(
+ const PatternContext& ctx,
+ TypedPackedFunc<Map<Var, Expr>(Map<DFPattern, Var>, Map<Var, Expr>)>
rewriter, Function func) {
+ return BlockPatternRewriter::Run(ctx, rewriter, func);
}
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings").set_body_typed(RewriteBindings);
Function RewriteCall(const DFPattern& pat,
- TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>
rewriter, Function f) {
- return PatternRewriter::Run(pat, rewriter, f);
+ TypedPackedFunc<Expr(Expr, Map<DFPattern, Expr>)>
rewriter, Function func) {
+ return ExprPatternRewriter::Run(pat, rewriter, func);
}
TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call").set_body_typed(RewriteCall);