This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 0695cfe4e2 [Unity] Remove non-deterministic behavior from graph pattern matching (#14417)
0695cfe4e2 is described below

commit 0695cfe4e2036195f6df418bab861f05f88eb93c
Author: masahi <[email protected]>
AuthorDate: Fri Mar 31 02:57:30 2023 +0900

    [Unity] Remove non-deterministic behavior from graph pattern matching (#14417)
    
    * Remove all non-determinism from graph matching
    
    * add test
    
    * typo
    
    * try fixing compile error for gcc
    
    * more style update
    
    * suppress compile warning
---
 include/tvm/relax/dataflow_pattern.h        |  24 +++++--
 src/relax/ir/dataflow_matcher.cc            | 104 +++++++++++++++-------------
 tests/python/relax/test_dataflow_pattern.py |  45 ++++++++++++
 3 files changed, 119 insertions(+), 54 deletions(-)

diff --git a/include/tvm/relax/dataflow_pattern.h 
b/include/tvm/relax/dataflow_pattern.h
index 37640750a8..144a7f45bf 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -191,7 +191,10 @@ class PatternContextNode : public Object {
     kMustNot, /*!< All nodes except outputs only have internal depedencies in 
the matched graph. */
   } allow_extern_use = kMay;
   // src node -> <dst node, constraint type> constraints.
-  std::map<DFPattern, std::map<DFPattern, std::vector<PairCons>>> constraints;
+  // Dst nodes are kept in a vector to keep them ordered.
+  std::map<DFPattern, std::vector<std::pair<DFPattern, 
std::vector<PairCons>>>> constraints;
+  // Keep a separate vector of patterns to process constraints in a fixed 
order.
+  std::vector<DFPattern> src_ordered;
 
   static constexpr const char* _type_key = "relax.dpl.PatternContext";
   TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
@@ -224,9 +227,22 @@ class PatternContext : public ObjectRef {
    * \param cons The constraint type. \sa PairCons
    */
   void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
-    auto& vec = (*this)->constraints[producer][consumer];
-    ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) << 
"Constraint already exists";
-    vec.push_back(cons);
+    auto& pairs = (*this)->constraints[producer];
+    auto it = std::find_if(pairs.begin(), pairs.end(),
+                           [consumer](auto p) { return p.first == consumer; });
+    if (it == pairs.end()) {
+      pairs.emplace_back(consumer, std::vector{cons});
+    } else {
+      auto& vec = it->second;
+      ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
+          << "Constraint already exists";
+      vec.push_back(cons);
+    }
+
+    auto& patterns = (*this)->src_ordered;
+    if (std::find(patterns.begin(), patterns.end(), producer) == 
patterns.end()) {
+      patterns.push_back(producer);
+    }
   }
 
   /*! \brief Get the pass context object on the top of the stack */
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c6d705b5b4..d1b7a1d62c 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -541,15 +541,21 @@ struct RNode {
  * \brief This method try to match a real node and a pattern node along with 
its neighbors.
  */
 static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
-                      const std::map<const VarNode*, std::set<const 
VarNode*>>& def2use,
+                      const std::map<const VarNode*, std::vector<const 
VarNode*>>& def2use,
                       const std::map<const VarNode*, std::vector<const 
VarNode*>>& use2def) {
-  if (nullptr != p->matched && p->matched == r->ptr) return true;  // matched 
before.
+  if (p->matched != nullptr && p->matched == r->ptr) return true;  // matched 
before.
   if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return false;
 
   std::stack<std::pair<PNode*, RNode*>> undo_stack{};
 
   const auto commit = [&undo_stack](PNode* p, RNode* r) {
     // match with each other.
+    // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
+    if (p->ptr == r->matched) {
+      ICHECK_EQ(p->matched, r->ptr);
+      return;
+    }
+    ICHECK(r->matched == nullptr);
     p->matched = r->ptr;
     r->matched = p->ptr;
     undo_stack.emplace(p, r);
@@ -568,31 +574,26 @@ static bool try_match(PNode* p, RNode* r, 
DFPatternMatcher* m,
   commit(p, r);
 
   // match parent patterns.
-  for (auto& pparent_pairs : p->parents) {
-    PNode* pparent = pparent_pairs.first;
-    const std::vector<PairCons>& constraints = pparent_pairs.second;
-
+  for (auto& [pparent, constraints] : p->parents) {
     bool any_cons_sat = false;
     for (auto& rparent : r->parents) {
       // skip if mismatch.
       if (rparent->matched && rparent->matched != pparent->ptr) continue;
 
       const auto& uses = def2use.at(rparent->ptr);
-      // skip if `rparent` is not used by `r`.
-      if (uses.cend() == uses.find(r->ptr)) continue;
 
       // check edge constraints.
       bool cons_sat = true;
       for (const auto& cons : constraints) {
-        if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
+        if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
           cons_sat = false;
           break;
         }
 
-        if (-1 != cons.index) {
+        if (cons.index != -1) {
           const auto& callees = use2def.at(r->ptr);
-          if (static_cast<size_t>(cons.index) >= callees.size() ||
-              rparent->ptr != callees[cons.index]) {
+          if (callees.size() <= static_cast<size_t>(cons.index) ||
+              callees[cons.index] != rparent->ptr) {
             cons_sat = false;
             break;
           }
@@ -612,27 +613,24 @@ static bool try_match(PNode* p, RNode* r, 
DFPatternMatcher* m,
   }
 
   // forward matching;
-  for (auto& pchild_pairs : p->children) {
-    PNode* pchild = pchild_pairs.first;
-    const std::vector<PairCons>& constraints = pchild_pairs.second;
+  for (auto& [pchild, constraints] : p->children) {
     bool any_cons_sat = false;
     for (auto& rchild : r->children) {
       if (rchild->matched && rchild->matched != pchild->ptr) continue;
 
       const auto& uses = def2use.at(r->ptr);
-      if (uses.cend() == uses.find(rchild->ptr)) continue;
 
       // check edge constraints.
       bool all_cons_pass = true;
       for (const auto& cons : constraints) {
-        if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
+        if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
           all_cons_pass = false;
           break;
         }
 
-        if (-1 != cons.index) {
+        if (cons.index != -1) {
           const auto& callees = use2def.at(rchild->ptr);
-          if (static_cast<size_t>(cons.index) >= callees.size() || r->ptr != 
callees[cons.index]) {
+          if (callees.size() <= static_cast<size_t>(cons.index) || 
callees[cons.index] != r->ptr) {
             all_cons_pass = false;
             break;
           }
@@ -648,13 +646,13 @@ static bool try_match(PNode* p, RNode* r, 
DFPatternMatcher* m,
     }
     if (!pchild->matched || !any_cons_sat) return quit();
   }
-
   return true;
 }
 
 class MatcherUseDefAnalysis : public relax::ExprVisitor {
  public:
-  std::map<const VarNode*, std::set<const VarNode*>> def2use;
+  std::vector<const VarNode*> vars;
+  std::map<const VarNode*, std::vector<const VarNode*>> def2use;
   // caller -> callee table.
   std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
 
@@ -671,7 +669,15 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
   void VisitExpr_(const VarNode* op) override {
     if (nullptr == cur_user_) return;
 
-    def2use[op].insert(cur_user_);
+    auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode* 
var) {
+      if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+        vec.push_back(var);
+      }
+    };
+
+    check_and_push(def2use[op], cur_user_);
+    check_and_push(vars, op);
+
     caller2callees[cur_user_].push_back(op);
   }
 
@@ -682,6 +688,10 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
 
 Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& 
dfb,
                                Optional<Var> start_hint, bool 
must_include_hint) {
+  if (ctx->src_ordered.size() == 0) {
+    return {};
+  }
+
   Map<DFPattern, Var> ret;
   // TODO(@ganler): Handle non-may external use.
   ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is 
supported yet.";
@@ -691,7 +701,6 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
   const auto var2val = AnalyzeVar2Value(dfb);
   DFPatternMatcher matcher(var2val);
 
-  // std::map<const VarNode*, std::set<const VarNode*>>
   MatcherUseDefAnalysis ud_analysis;
   ud_analysis.VisitBindingBlock_(dfb.get());
   const auto& def2use = ud_analysis.def2use;
@@ -701,9 +710,8 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
   std::unordered_map<const VarNode*, RNode> var2node;
   var2node.reserve(dfb->bindings.size());
 
-  for (const auto& du : def2use) {
-    const VarNode* cur_var = du.first;
-    const std::set<const VarNode*>& uses = du.second;
+  for (const VarNode* cur_var : ud_analysis.vars) {
+    const auto& uses = def2use.at(cur_var);
     RNode& cur_node = var2node[cur_var];
     cur_node.ptr = cur_var;
     for (const VarNode* use : uses) {
@@ -717,29 +725,24 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
   std::unordered_map<const DFPatternNode*, PNode> pattern2node;
   pattern2node.reserve(ctx->constraints.size());
 
-  for (const auto& def2use_pattern : ctx->constraints) {
-    const DFPatternNode* def_pattern = def2use_pattern.first.get();
-    const std::map<DFPattern, std::vector<PairCons>>& uses = 
def2use_pattern.second;
-    PNode& def_node = pattern2node[def_pattern];
-    def_node.ptr = def_pattern;
+  for (const auto& [def_pattern, uses] : ctx->constraints) {
+    PNode& def_node = pattern2node[def_pattern.get()];
+    def_node.ptr = def_pattern.get();
     def_node.children.reserve(uses.size());
-    for (const auto& use : uses) {
-      const auto& cons = use.second;
-      const DFPatternNode* use_pattern = use.first.get();
-      PNode& use_node = pattern2node[use_pattern];
-      use_node.ptr = use_pattern;
+    for (const auto& [use_pattern, cons] : uses) {
+      PNode& use_node = pattern2node[use_pattern.get()];
+      use_node.ptr = use_pattern.get();
       use_node.parents.emplace_back(&def_node, std::ref(cons));
       def_node.children.emplace_back(&use_node, std::ref(cons));
     }
   }
 
-  if (start_hint.defined()) {
-    Var v = start_hint.value();
-    auto rnode_ptr = var2node.find(v.get());
-    for (auto& ppair : pattern2node) {
-      if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use, 
caller2callees)) {
-        for (auto ppair : pattern2node)
-          ret.Set(GetRef<DFPattern>(ppair.first), 
GetRef<Var>(ppair.second.matched));
+  if (start_hint) {
+    auto rnode_ptr = var2node.at(start_hint.value().get());
+    for (auto& p_node : pattern2node) {
+      if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, 
caller2callees)) {
+        for (const auto& [df_pattern, pattern_node] : pattern2node)
+          ret.Set(GetRef<DFPattern>(df_pattern), 
GetRef<Var>(pattern_node.matched));
         return ret;
       }
     }
@@ -747,14 +750,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
     if (must_include_hint) return ret;
   }
 
-  PNode* pnode_start = &pattern2node.begin()->second;
+  PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()];
 
-  if (!pnode_start->matched) {
-    for (auto& rpair : var2node) {
-      if (start_hint.defined() && start_hint.value().get() == rpair.first) 
continue;
-      if (try_match(pnode_start, &rpair.second, &matcher, def2use, 
caller2callees)) {
-        for (auto ppair : pattern2node)
-          ret.Set(GetRef<DFPattern>(ppair.first), 
GetRef<Var>(ppair.second.matched));
+  if (!pnode_start.matched) {
+    for (const auto& var : ud_analysis.vars) {
+      if (start_hint.defined() && start_hint.value().get() == var) continue;
+      RNode& r_node = var2node[var];
+      if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) 
{
+        for (const auto& [df_pattern, pattern_node] : pattern2node)
+          ret.Set(GetRef<DFPattern>(df_pattern), 
GetRef<Var>(pattern_node.matched));
 
         return ret;
       }
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index a40faf3bcb..9679e14fff 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1006,5 +1006,50 @@ def test_rewrite_attention():
     tvm.ir.assert_structural_equal(rewritten, expected)
 
 
+def test_attention_qkv():
+    @tvm.script.ir_module
+    class QKV_proj:
+        @R.function
+        def main(
+            x: R.Tensor((2, 1024, 640), "float32"),
+            w0: R.Tensor((640, 640), "float32"),
+            w1: R.Tensor((640, 640), "float32"),
+            w2: R.Tensor((640, 640), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x, w0)
+                lv1 = R.matmul(x, w1)
+                lv2 = R.matmul(x, w2)
+                out = (lv0, lv1, lv2)
+                R.output(out)
+            return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        # TODO(masahi): Automate addition of used_by constraints during is_op
+        inp_pat.used_by(matmul1, 0)
+        inp_pat.used_by(matmul2, 0)
+        inp_pat.used_by(matmul3, 0)
+
+        Q_weight_pat.only_used_by(matmul1, 1)
+        K_weight_pat.only_used_by(matmul2, 1)
+        V_weight_pat.only_used_by(matmul3, 1)
+
+        dfb = QKV_proj["main"].body.blocks[0]
+        out = ctx.match_dfb(dfb)
+
+        assert out[Q_weight_pat].name_hint == "w0"
+        assert out[K_weight_pat].name_hint == "w1"
+        assert out[V_weight_pat].name_hint == "w2"
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to