masahi commented on code in PR #14446:
URL: https://github.com/apache/tvm/pull/14446#discussion_r1154868977
##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -807,25 +818,112 @@ class PatternRewriter : ExprMutator {
Expr VisitExpr_(const CallNode* call_node) final {
auto call = ExprMutator::VisitExpr_(call_node);
- if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+ if (!pattern_) {
+ return call;
+ } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call, bindings_)) {
auto rewriten_expr = rewriter_func_(call, matches_opt.value());
memo_[call_node] = rewriten_expr;
return rewriten_expr;
}
return call;
}
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+ if (!ctx_) {
+ return ExprMutator::VisitBindingBlock_(block_node);
+ }
+ return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+ }
+
private:
- DFPattern pattern_;
+ void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
+ std::unordered_set<const VarNode*>* emitted_vars) {
+ std::unordered_set<const VarNode*> unemitted_vars;
+ PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
+ if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
+ unemitted_vars.insert(v);
+ }
+ });
+
+ if (unemitted_vars.empty()) {
+ return;
+ }
+
+ size_t num_unemitted = unemitted_vars.size();
+ for (size_t i = 0; i < pending_bindings.size(); ++i) {
+ const auto& binding = pending_bindings[i];
+ if (auto var_bind = binding.as<VarBindingNode>();
+ var_bind && unemitted_vars.count(var_bind->var.get())) {
+ // var_bind->value may also depend on other unemitted vars in this range
+ Array<Binding> prev_bindings(pending_bindings.begin(), pending_bindings.begin() + i);
+ EmitUsedVars(var_bind->value, prev_bindings, emitted_vars);
Review Comment:
I want to get rid of this recursive call and make sure we traverse
`pending_bindings` only once. The issue is that when `PostOrderVisit` encounters
a variable, it does not descend into the expression bound to that variable. For
example, given the contrived input bindings below,
```
with R.dataflow():
w0_t = R.permute_dims(w0, axes=None)
lv0 = R.matmul(x1, w0_t)
w1_t = R.permute_dims(w1, axes=None)
w1_t_t = R.permute_dims(w1_t, axes=None)
lv1 = R.matmul(x1, w1_t_t)
w2_t = R.permute_dims(w2, axes=None)
lv2 = R.matmul(x1, w2_t)
```
we need to emit all `permute_dims` bindings before emitting the `concat` and the
combined matmul, since the `concat` depends on all the weights, some of which are
defined after the first matmul. When `PostOrderVisit` is applied to
`R.matmul(x1, w1_t_t)`, `w1_t` is not visited. So even though `w1_t` must be
emitted before `w1_t_t`, `w1_t` is not put into the initial `unemitted_vars` set.
I think we can use `AnalyzeVar2Value` on the input function to get the bindings,
and recursively traverse the bound expression whenever we encounter a new
unemitted var. But I find that a bit complicated for a job as simple as this, so
I'm looking for a simpler solution. For now I'm keeping this recursive solution,
which is inefficient but extremely simple.
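The following is a minimal sketch of how the `AnalyzeVar2Value`-based idea could look. It is written in plain C++ without TVM so that it stays self-contained, and it is not the PR's code. The sketch rests on several assumptions:

- Each binding is modeled by a hypothetical `ToyBinding` that holds the bound variable and the variables its value references directly. This is what a shallow `PostOrderVisit` reports, so `lv1` uses only `x1` and `w1_t_t`.
- The var-to-binding map is built from the pending bindings, standing in for a whole-function analysis such as `AnalyzeVar2Value`.
- `EmitUsedVarsSinglePass` is a hypothetical name.

A worklist follows each unemitted variable to its bound value, so `w1_t` is discovered through `w1_t_t`. A single ordered pass over the pending bindings then emits what is needed, without copying prefixes of the list on every step.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a Relax VarBinding: the bound variable plus the
// variables its value references directly (what a shallow PostOrderVisit sees).
struct ToyBinding {
  std::string var;
  std::vector<std::string> uses;
};

// Hypothetical single-pass variant of EmitUsedVars (not the PR's code).
// Returns, in original order, every pending binding that `roots` depend on
// transitively, and records those variables in `emitted`.
std::vector<std::string> EmitUsedVarsSinglePass(
    const std::vector<std::string>& roots,
    const std::vector<ToyBinding>& pending,
    std::unordered_set<std::string>* emitted) {
  assert(emitted != nullptr);

  // Var -> binding map; stands in for a whole-function AnalyzeVar2Value result.
  std::unordered_map<std::string, const ToyBinding*> var2binding;
  for (const auto& b : pending) var2binding[b.var] = &b;

  // Worklist: follow each unemitted var to its bound value, so w1_t is found
  // through w1_t_t even though it never appears in the root expression.
  std::unordered_set<std::string> needed;
  std::vector<std::string> worklist(roots.begin(), roots.end());
  while (!worklist.empty()) {
    std::string v = worklist.back();
    worklist.pop_back();
    if (emitted->count(v) || needed.count(v)) continue;
    auto it = var2binding.find(v);
    if (it == var2binding.end()) continue;  // e.g. a function parameter: nothing to emit
    needed.insert(v);
    for (const auto& u : it->second->uses) worklist.push_back(u);
  }

  // Single ordered pass over the pending bindings: emitting in the original
  // order keeps every definition ahead of its uses.
  std::vector<std::string> order;
  for (const auto& b : pending) {
    if (needed.count(b.var)) {
      order.push_back(b.var);
      emitted->insert(b.var);
    }
  }
  return order;
}
```

For the example bindings above, take `w0_t`, `w1_t_t` and `w2_t` as roots (the weights the `concat` needs). The sketch emits `w0_t`, `w1_t`, `w1_t_t`, `w2_t` in that order and leaves the matmuls pending. Here the map costs one extra linear pass. With a precomputed whole-function map, only the emission loop would touch `pending`.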