masahi commented on code in PR #14446:
URL: https://github.com/apache/tvm/pull/14446#discussion_r1154859227
##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -807,25 +818,112 @@ class PatternRewriter : ExprMutator {
Expr VisitExpr_(const CallNode* call_node) final {
auto call = ExprMutator::VisitExpr_(call_node);
- if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+ if (!pattern_) {
+ return call;
+ } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call,
bindings_)) {
auto rewriten_expr = rewriter_func_(call, matches_opt.value());
memo_[call_node] = rewriten_expr;
return rewriten_expr;
}
return call;
}
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+ if (!ctx_) {
+ return ExprMutator::VisitBindingBlock_(block_node);
+ }
+ return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+ }
+
private:
- DFPattern pattern_;
+ void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
+ std::unordered_set<const VarNode*>* emitted_vars) {
+ std::unordered_set<const VarNode*> unemitted_vars;
+ PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
+ if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
+ unemitted_vars.insert(v);
+ }
+ });
+
+ if (unemitted_vars.empty()) {
+ return;
+ }
+
+ size_t num_unemitted = unemitted_vars.size();
+ for (size_t i = 0; i < pending_bindings.size(); ++i) {
+ const auto& binding = pending_bindings[i];
+ if (auto var_bind = binding.as<VarBindingNode>();
+ var_bind && unemitted_vars.count(var_bind->var.get())) {
+ // var_bind->value may also depend on other unemitted vars in this range
+ Array<Binding> prev_bindings(pending_bindings.begin(),
pending_bindings.begin() + i);
+ EmitUsedVars(var_bind->value, prev_bindings, emitted_vars);
+ this->VisitBinding(binding);
+ emitted_vars->insert(var_bind->var.get());
+ if (--num_unemitted == 0) {
+ return;
+ }
+ }
+ }
+ }
+
+ // Repeat until all matchable subsets of bindings are rewritten.
+ BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
Review Comment:
We need to apply the rewrite repeatedly because, for example, the same QKV
projection pattern appears a number of times in a single DataflowBlock (DFB).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]