ganler commented on code in PR #14501:
URL: https://github.com/apache/tvm/pull/14501#discussion_r1159037200
##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr,
Optional<Map<Var, Expr>> bindings_o
TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
+/*!
+ * \brief Records, for each variable binding `v = value`, which variables are
+ *        referenced inside `value` (def-use information for the matcher).
+ */
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  /*! \brief Every variable referenced inside some binding value, deduplicated, in first-seen order. */
+  std::vector<const VarNode*> vars;
+  /*! \brief def -> users: for each referenced var, the binding vars whose values use it (deduplicated). */
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  /*!
+   * \brief caller -> callee table: for each binding var, the vars referenced by its value.
+   * \note Unlike def2use this is NOT deduplicated; a var used twice appears twice.
+   */
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  /*! \brief Var of the binding whose value is being visited; nullptr outside any binding. */
+  const VarNode* cur_user_ = nullptr;  // must start null: VisitExpr_ tests it before any binding
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // Save/restore instead of resetting to nullptr so that a binding nested inside
+    // another binding's value does not clear the enclosing user for the rest of it.
+    const VarNode* outer_user = cur_user_;
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = outer_user;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    // References occurring outside any binding value are not use-def edges.
+    if (cur_user_ == nullptr) return;
+
+    // Append `var` to `vec` only if it is not already present (keeps insertion order).
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  // Dataflow vars take part in the same use-def relation as ordinary vars.
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
struct PNode {
const DFPatternNode* ptr;
- const VarNode* matched = nullptr;
std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
};
struct RNode {
const VarNode* ptr;
- const DFPatternNode* matched = nullptr;
std::vector<RNode*> children;
std::vector<RNode*> parents;
};
-/**
- * \brief This method try to match a real node and a pattern node along with
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
- PNode* p, RNode* r, DFPatternMatcher* m,
- const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
- const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
- if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched
before.
- if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return
std::nullopt;
-
- UndoItems undo;
-
- const auto commit = [&undo](PNode* p, RNode* r) {
- // match with each other.
- // TODO(ganler, masahi): Why commit on the same p-r pair happens more than
once?
- if (p->ptr == r->matched) {
- ICHECK_EQ(p->matched, r->ptr);
- return;
- }
- p->matched = r->ptr;
- r->matched = p->ptr;
- undo.emplace_back(p, r);
- };
-
- const auto quit = [&undo] {
- for (auto& [p_node, r_node] : undo) {
- p_node->matched = nullptr;
- r_node->matched = nullptr;
- }
- return std::nullopt;
- };
+struct MatchState {
+ void add(const PNode* p, const RNode* r) {
+ match_p_r[p] = r;
+ match_r_p[r] = p;
+ }
- const auto try_match_update_undo = [&](PNode* p, RNode* r) {
- if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
- undo.insert(undo.end(), undo_more->begin(), undo_more->end());
- return true;
+ void add(const MatchState& other) {
+ for (const auto& [p, r] : other.match_p_r) {
Review Comment:
Here I assume the incoming key-value pairs won't conflict with those already in
`*this`. But if the intent is to override the values of existing keys, then use
[`insert_or_assign`](https://en.cppreference.com/w/cpp/container/unordered_map/insert_or_assign)
instead of `insert`.
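**Update:** there is a neat C++17 API, `merge`, that can do the job directly, with two caveats. First, it takes a non-const source and moves nodes out of it, so it cannot be called on `const MatchState& other` as is. Second, it never overwrites keys already present in `*this`; such elements stay behind in the source.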
https://en.cppreference.com/w/cpp/container/unordered_map/merge
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]