This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new a70768ea2d [Unity][Graph matching] Clean up undo stack for parent and child nodes properly (#14440)
a70768ea2d is described below

commit a70768ea2dab8919386378f41f2468e43bf6ca9c
Author: masahi <[email protected]>
AuthorDate: Sat Apr 1 03:56:19 2023 +0900

    [Unity][Graph matching] Clean up undo stack for parent and child nodes properly (#14440)
    
    * Clean up undo stack for parent and child nodes properly
    
    * Update src/relax/analysis/udchain.cc
    
    Co-authored-by: Jiawei Liu <[email protected]>
    
    * minor change
    
    * stack -> vector
    
    * remove stack header
    
    * fix accidentally removed statement from recent commit
    
    ---------
    
    Co-authored-by: Jiawei Liu <[email protected]>
---
 include/tvm/relax/dataflow_matcher.h        | 10 ++---
 src/relax/analysis/udchain.cc               | 26 ++++++------
 src/relax/ir/dataflow_matcher.cc            | 63 ++++++++++++++++-------------
 tests/python/relax/test_dataflow_pattern.py | 33 +++++++++++++++
 4 files changed, 87 insertions(+), 45 deletions(-)

diff --git a/include/tvm/relax/dataflow_matcher.h 
b/include/tvm/relax/dataflow_matcher.h
index 498f77a3f7..e4268be882 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,12 +58,12 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
  * \param dfb The function to match.
  * \param start_hint The starting point expression to match to distinguish 
multiple matches.
  * \param must_include_hint If start_hint is given, the return pattern must 
include start_hint.
- * \return tvm::runtime::Map<DFPattern, Var>
+ * \return Matched patterns and corresponding bound variables
  */
-TVM_DLL tvm::runtime::Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
-                                                     const DataflowBlock& dfb,
-                                                     Optional<Var> start_hint 
= NullOpt,
-                                                     bool must_include_hint = 
false);
+TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
+                                                 const DataflowBlock& dfb,
+                                                 Optional<Var> start_hint = 
NullOpt,
+                                                 bool must_include_hint = 
false);
 
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 77e52408a7..1c49fd581f 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -65,36 +65,36 @@ class UDChain : public relax::ExprVisitor {
 std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>> 
FunctionUseDef(
     const Function& fn) {
   UDChain udchain;
-  udchain.VisitExpr_(fn.get());
+  udchain.VisitExpr(fn);
 
   Map<Var, Array<Var>> user_map;
   Array<Var> fn_outs;
 
-  for (const auto& kv : udchain.to_users) {
+  for (const auto& [var, users] : udchain.to_users) {
     Array<Var> uses{};
-    uses.reserve(kv.second.size());
-    for (const auto& v : kv.second) {
-      if (nullptr == v &&
-          fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(), 
GetRef<Var>(kv.first))) {
-        fn_outs.push_back(GetRef<Var>(kv.first));
+    uses.reserve(users.size());
+    for (const auto& v : users) {
+      if (v == nullptr &&
+          std::find(fn_outs.begin(), fn_outs.end(), GetRef<Var>(var)) == 
fn_outs.end()) {
+        fn_outs.push_back(GetRef<Var>(var));
       } else {
         uses.push_back(GetRef<Var>(v));
       }
     }
-    user_map.Set(GetRef<Var>(kv.first), std::move(uses));
+    user_map.Set(GetRef<Var>(var), std::move(uses));
   }
   return std::make_pair(std::move(user_map), std::move(fn_outs));
 }
 
 runtime::Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {
   UDChain udchain;
-  udchain.VisitBindingBlock_(dfb.get());
+  udchain.VisitBindingBlock(dfb);
   runtime::Map<Var, Array<Var>> ret;
-  for (const auto& kv : udchain.to_users) {
+  for (const auto& [var, users] : udchain.to_users) {
     Array<Var> uses{};
-    uses.reserve(kv.second.size());
-    for (const auto& v : kv.second) uses.push_back(GetRef<Var>(v));
-    ret.Set(GetRef<Var>(kv.first), std::move(uses));
+    uses.reserve(users.size());
+    for (const auto& v : users) uses.push_back(GetRef<Var>(v));
+    ret.Set(GetRef<Var>(var), std::move(uses));
   }
   return ret;
 }
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index d1b7a1d62c..88381d6e26 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -33,7 +33,7 @@
 #include <array>
 #include <cstddef>
 #include <limits>
-#include <stack>
+#include <optional>
 #include <type_traits>
 #include <unordered_map>
 #include <unordered_set>
@@ -540,33 +540,40 @@ struct RNode {
 /**
  * \brief This method try to match a real node and a pattern node along with 
its neighbors.
  */
-static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
-                      const std::map<const VarNode*, std::vector<const 
VarNode*>>& def2use,
-                      const std::map<const VarNode*, std::vector<const 
VarNode*>>& use2def) {
-  if (p->matched != nullptr && p->matched == r->ptr) return true;  // matched 
before.
-  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return false;
+using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
+static std::optional<UndoItems> try_match(
+    PNode* p, RNode* r, DFPatternMatcher* m,
+    const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
+    const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
+  if (p->matched != nullptr && p->matched == r->ptr) return {};  // matched 
before.
+  if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return 
std::nullopt;
 
-  std::stack<std::pair<PNode*, RNode*>> undo_stack{};
+  UndoItems undo;
 
-  const auto commit = [&undo_stack](PNode* p, RNode* r) {
+  const auto commit = [&undo](PNode* p, RNode* r) {
     // match with each other.
     // TODO(ganler, masahi): Why commit on the same p-r pair happens more than 
once?
     if (p->ptr == r->matched) {
       ICHECK_EQ(p->matched, r->ptr);
       return;
     }
-    ICHECK(r->matched == nullptr);
     p->matched = r->ptr;
     r->matched = p->ptr;
-    undo_stack.emplace(p, r);
+    undo.emplace_back(p, r);
   };
 
-  const auto quit = [&undo_stack] {
-    while (!undo_stack.empty()) {
-      auto& top = undo_stack.top();
-      top.first->matched = nullptr;
-      top.second->matched = nullptr;
-      undo_stack.pop();
+  const auto quit = [&undo] {
+    for (auto& [p_node, r_node] : undo) {
+      p_node->matched = nullptr;
+      r_node->matched = nullptr;
+    }
+    return std::nullopt;
+  };
+
+  const auto try_match_update_undo = [&](PNode* p, RNode* r) {
+    if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
+      undo.insert(undo.end(), undo_more->begin(), undo_more->end());
+      return true;
     }
     return false;
   };
@@ -604,7 +611,7 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* 
m,
 
       // try all parent R nodes that are not matched yet.
       // as long as ppattern can match one node.
-      if (!pparent->matched && try_match(pparent, rparent, m, def2use, 
use2def)) {
+      if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
         commit(pparent, rparent);
         break;
       }
@@ -639,14 +646,14 @@ static bool try_match(PNode* p, RNode* r, 
DFPatternMatcher* m,
       if (!all_cons_pass) continue;
       any_cons_sat = true;
 
-      if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) {
+      if (!pchild->matched && try_match_update_undo(pchild, rchild)) {
         commit(pchild, rchild);
         break;
       }
     }
     if (!pchild->matched || !any_cons_sat) return quit();
   }
-  return true;
+  return undo;
 }
 
 class MatcherUseDefAnalysis : public relax::ExprVisitor {
@@ -686,13 +693,12 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
   }
 };
 
-Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& 
dfb,
-                               Optional<Var> start_hint, bool 
must_include_hint) {
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const 
DataflowBlock& dfb,
+                                         Optional<Var> start_hint, bool 
must_include_hint) {
   if (ctx->src_ordered.size() == 0) {
-    return {};
+    return NullOpt;
   }
 
-  Map<DFPattern, Var> ret;
   // TODO(@ganler): Handle non-may external use.
   ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is 
supported yet.";
   ICHECK(!must_include_hint || start_hint.defined())
@@ -737,12 +743,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
     }
   }
 
+  Map<DFPattern, Var> ret;
+
   if (start_hint) {
     auto rnode_ptr = var2node.at(start_hint.value().get());
     for (auto& p_node : pattern2node) {
       if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use, 
caller2callees)) {
-        for (const auto& [df_pattern, pattern_node] : pattern2node)
+        for (const auto& [df_pattern, pattern_node] : pattern2node) {
           ret.Set(GetRef<DFPattern>(df_pattern), 
GetRef<Var>(pattern_node.matched));
+        }
         return ret;
       }
     }
@@ -757,15 +766,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, 
const DataflowBlock& d
       if (start_hint.defined() && start_hint.value().get() == var) continue;
       RNode& r_node = var2node[var];
       if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees)) 
{
-        for (const auto& [df_pattern, pattern_node] : pattern2node)
+        for (const auto& [df_pattern, pattern_node] : pattern2node) {
           ret.Set(GetRef<DFPattern>(df_pattern), 
GetRef<Var>(pattern_node.matched));
-
+        }
         return ret;
       }
     }
   }
 
-  return ret;
+  return NullOpt;
 }
 
 TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
diff --git a/tests/python/relax/test_dataflow_pattern.py 
b/tests/python/relax/test_dataflow_pattern.py
index 76bce47f7f..f18244096e 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1042,5 +1042,38 @@ def test_attention_qkv():
         assert out[V_weight_pat].name_hint == "w2"
 
 
+def test_attention_fake_qkv():
+    @tvm.script.ir_module
+    class QKV_proj:
+        @R.function
+        def main(
+            x1: R.Tensor((2, 1024, 640), "float32"),
+            x2: R.Tensor((2, 1024, 640), "float32"),
+            w0: R.Tensor((640, 640), "float32"),
+            w1: R.Tensor((640, 640), "float32"),
+            w2: R.Tensor((640, 640), "float32"),
+        ) -> R.Tensor:
+            with R.dataflow():
+                lv0 = R.matmul(x1, w0)
+                lv1 = R.matmul(x2, w1)
+                lv2 = R.matmul(x2, w2)
+                out = (lv0, lv1, lv2)
+                R.output(out)
+            return out
+
+    with PatternContext() as ctx:
+        inp_pat = wildcard()
+        Q_weight_pat = wildcard()
+        K_weight_pat = wildcard()
+        V_weight_pat = wildcard()
+
+        matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+        matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+        matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+        dfb = QKV_proj["main"].body.blocks[0]
+        assert ctx.match_dfb(dfb) is None
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to