ganler commented on code in PR #14501:
URL: https://github.com/apache/tvm/pull/14501#discussion_r1159035242


##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {

Review Comment:
   May consider making it more cache-friendly:
   
   ```c++
   match_p_r.insert(other.match_p_r.cbegin(), other.match_p_r.cend());
   match_r_p.insert(other.match_r_p.cbegin(), other.match_r_p.cend());
   ```



##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {
+      add(p, r);
     }
-    return false;
-  };
+  }
 
-  commit(p, r);
+  const VarNode* matched(const PNode* p) const {
+    if (!match_p_r.count(p)) return nullptr;
+    return match_p_r.at(p)->ptr;

Review Comment:
   Minor perf: get things done with only one look-up (feel free to ignore it 
if you believe it hurts readability):
   
   ```suggestion
       if (auto it = match_p_r.find(p); it != match_p_r.end()) return it->second->ptr;
       return nullptr;
   ```
   



##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {

Review Comment:
   Here I assume the incoming key-value pairs won't conflict with those in 
`*this`. But if the intent is to override values of existing keys, then use 
[`insert_or_assign`](https://en.cppreference.com/w/cpp/container/unordered_map/insert_or_assign)
 over `insert`.



##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {
+      add(p, r);
     }
-    return false;
-  };
+  }
 
-  commit(p, r);
+  const VarNode* matched(const PNode* p) const {
+    if (!match_p_r.count(p)) return nullptr;
+    return match_p_r.at(p)->ptr;
+  }
 
-  // match parent patterns.
-  for (auto& [pparent, constraints] : p->parents) {
-    bool any_cons_sat = false;
-    for (auto& rparent : r->parents) {
-      // skip if mismatch.
-      if (rparent->matched && rparent->matched != pparent->ptr) continue;
+  const DFPatternNode* matched(const RNode* r) const {
+    if (!match_r_p.count(r)) return nullptr;
+    return match_r_p.at(r)->ptr;
+  }
 
-      const auto& uses = def2use.at(rparent->ptr);
+  const VarNode* matched(const PNode& p) const { return matched(&p); }
+  const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
 
-      // check edge constraints.
-      bool cons_sat = true;
-      for (const auto& cons : constraints) {
-        if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
-          cons_sat = false;
-          break;
-        }
-
-        if (cons.index != -1) {
-          const auto& callees = use2def.at(r->ptr);
-          if (callees.size() <= static_cast<size_t>(cons.index) ||
-              callees[cons.index] != rparent->ptr) {
-            cons_sat = false;
-            break;
-          }
-        }
-      }
-      if (!cons_sat) continue;
-      any_cons_sat = true;
+ private:
+  std::unordered_map<const PNode*, const RNode*> match_p_r;
+  std::unordered_map<const RNode*, const PNode*> match_r_p;
+};
 
-      // try all parent R nodes that are not matched yet.
-      // as long as ppattern can match one node.
-      if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
-        commit(pparent, rparent);
-        break;
+/**
+ * \brief This method try to match a real node and a pattern node along with 
its neighbors.
+ */
+static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r, 
DFPatternMatcher* m,
+                                          const MatcherUseDefAnalysis& 
ud_analysis) {
+  if (!m->Match(GetRef<DFPattern>(p.ptr), GetRef<Var>(r.ptr))) return 
std::nullopt;
+
+  MatchState result;
+
+  for (size_t i = 0; i < p.parents.size(); ++i) {
+    const auto p_node_parent = p.parents[i].first;
+    if (p_node_parent->ptr->IsInstance<WildcardPatternNode>()) {
+      ICHECK_EQ(p.parents.size(), r.parents.size());
+      if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) 
{
+        // A parent wildcard pattern is already matched to other variable.
+        return std::nullopt;

Review Comment:
   I guess you meant that if it is a wildcard, we don't have to obey the "one-one" 
mapping (originally suggested here 
https://github.com/tlc-pack/relax/issues/160#issue-1268184762 in the `One-One 
Match` section). If that's the case, I feel it might still make some sense to 
optionally allow one-one mapping even for wildcards? 



##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {
+      add(p, r);
     }
-    return false;
-  };
+  }
 
-  commit(p, r);
+  const VarNode* matched(const PNode* p) const {
+    if (!match_p_r.count(p)) return nullptr;
+    return match_p_r.at(p)->ptr;
+  }
 
-  // match parent patterns.
-  for (auto& [pparent, constraints] : p->parents) {
-    bool any_cons_sat = false;
-    for (auto& rparent : r->parents) {
-      // skip if mismatch.
-      if (rparent->matched && rparent->matched != pparent->ptr) continue;
+  const DFPatternNode* matched(const RNode* r) const {
+    if (!match_r_p.count(r)) return nullptr;
+    return match_r_p.at(r)->ptr;
+  }
 
-      const auto& uses = def2use.at(rparent->ptr);
+  const VarNode* matched(const PNode& p) const { return matched(&p); }
+  const DFPatternNode* matched(const RNode& r) const { return matched(&r); }
 
-      // check edge constraints.
-      bool cons_sat = true;
-      for (const auto& cons : constraints) {
-        if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
-          cons_sat = false;
-          break;
-        }
-
-        if (cons.index != -1) {
-          const auto& callees = use2def.at(r->ptr);
-          if (callees.size() <= static_cast<size_t>(cons.index) ||
-              callees[cons.index] != rparent->ptr) {
-            cons_sat = false;
-            break;
-          }
-        }
-      }
-      if (!cons_sat) continue;
-      any_cons_sat = true;
+ private:
+  std::unordered_map<const PNode*, const RNode*> match_p_r;
+  std::unordered_map<const RNode*, const PNode*> match_r_p;
+};
 
-      // try all parent R nodes that are not matched yet.
-      // as long as ppattern can match one node.
-      if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
-        commit(pparent, rparent);
-        break;
+/**
+ * \brief This method try to match a real node and a pattern node along with 
its neighbors.
+ */
+static std::optional<MatchState> TryMatch(const PNode& p, const RNode& r, 
DFPatternMatcher* m,
+                                          const MatcherUseDefAnalysis& 
ud_analysis) {
+  if (!m->Match(GetRef<DFPattern>(p.ptr), GetRef<Var>(r.ptr))) return 
std::nullopt;
+
+  MatchState result;
+
+  for (size_t i = 0; i < p.parents.size(); ++i) {
+    const auto p_node_parent = p.parents[i].first;
+    if (p_node_parent->ptr->IsInstance<WildcardPatternNode>()) {
+      ICHECK_EQ(p.parents.size(), r.parents.size());
+      if (auto v = result.matched(p_node_parent); v && v != r.parents[i]->ptr) 
{
+        // A parent wildcard pattern is already matched to other variable.
+        return std::nullopt;

Review Comment:
   Sorry if I missed anything here. But could you kindly help explain why we 
need to add special handling for nodes whose parents (aka arguments) are 
wildcards? 



##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -523,109 +523,117 @@ bool MatchExpr(DFPattern pattern, Expr expr, 
Optional<Map<Var, Expr>> bindings_o
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_expr").set_body_typed(MatchExpr);
 
+class MatcherUseDefAnalysis : public relax::ExprVisitor {
+ public:
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
+  // caller -> callee table.
+  std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
+
+  const VarNode* cur_user_;
+
+  void VisitBinding_(const VarBindingNode* binding) override {
+    // init
+    cur_user_ = binding->var.get();
+    this->VisitVarDef(binding->var);
+    this->VisitExpr(binding->value);
+    cur_user_ = nullptr;
+  }
+
+  void VisitExpr_(const VarNode* op) override {
+    if (nullptr == cur_user_) return;
+
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
+    caller2callees[cur_user_].push_back(op);
+  }
+
+  void VisitExpr_(const DataflowVarNode* op) override {
+    VisitExpr_(static_cast<const VarNode*>(op));
+  }
+};
+
 struct PNode {
   const DFPatternNode* ptr;
-  const VarNode* matched = nullptr;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> children;
   std::vector<std::pair<PNode*, const std::vector<PairCons>&>> parents;
 };
 
 struct RNode {
   const VarNode* ptr;
-  const DFPatternNode* matched = nullptr;
   std::vector<RNode*> children;
   std::vector<RNode*> parents;
 };
 
-/**
- * \brief This method try to match a real node and a pattern node along with 
its neighbors.
- */
-using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
-static std::optional<UndoItems> try_match(
-    PNode* p, RNode* r, DFPatternMatcher* m,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
-    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
-
-  UndoItems undo;
-
-  const auto commit = [&undo](PNode* p, RNode* r) {
-    // match with each other.
-    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
-    if (p->ptr == r->matched) {
-      ICHECK_EQ(p->matched, r->ptr);
-      return;
-    }
-    p->matched = r->ptr;
-    r->matched = p->ptr;
-    undo.emplace_back(p, r);
-  };
-
-  const auto quit = [&undo] {
-    for (auto& [p_node, r_node] : undo) {
-      p_node->matched = nullptr;
-      r_node->matched = nullptr;
-    }
-    return std::nullopt;
-  };
+struct MatchState {
+  void add(const PNode* p, const RNode* r) {
+    match_p_r[p] = r;
+    match_r_p[r] = p;
+  }
 
-  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
-    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
-      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
-      return true;
+  void add(const MatchState& other) {
+    for (const auto& [p, r] : other.match_p_r) {
+      add(p, r);
     }
-    return false;
-  };
+  }
 
-  commit(p, r);
+  const VarNode* matched(const PNode* p) const {
+    if (!match_p_r.count(p)) return nullptr;
+    return match_p_r.at(p)->ptr;
+  }
 
-  // match parent patterns.
-  for (auto& [pparent, constraints] : p->parents) {
-    bool any_cons_sat = false;
-    for (auto& rparent : r->parents) {
-      // skip if mismatch.
-      if (rparent->matched && rparent->matched != pparent->ptr) continue;
+  const DFPatternNode* matched(const RNode* r) const {
+    if (!match_r_p.count(r)) return nullptr;
+    return match_r_p.at(r)->ptr;

Review Comment:
   Same:
   
   ```suggestion
       if (auto it = match_r_p.find(r); it != match_r_p.end()) return it->second->ptr;
       return nullptr;
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to