This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new a70768ea2d [Unity][Graph matching] Clean up undo stack for parent and child nodes properly (#14440)
a70768ea2d is described below
commit a70768ea2dab8919386378f41f2468e43bf6ca9c
Author: masahi <[email protected]>
AuthorDate: Sat Apr 1 03:56:19 2023 +0900
[Unity][Graph matching] Clean up undo stack for parent and child nodes properly (#14440)
* Clean up undo stack for parent and child nodes properly
* Update src/relax/analysis/udchain.cc
Co-authored-by: Jiawei Liu <[email protected]>
* minor change
* stack -> vector
* remove stack header
* fix accidentally removed statement from recent commit
---------
Co-authored-by: Jiawei Liu <[email protected]>
---
include/tvm/relax/dataflow_matcher.h | 10 ++---
src/relax/analysis/udchain.cc | 26 ++++++------
src/relax/ir/dataflow_matcher.cc | 63 ++++++++++++++++-------------
tests/python/relax/test_dataflow_pattern.py | 33 +++++++++++++++
4 files changed, 87 insertions(+), 45 deletions(-)
diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index 498f77a3f7..e4268be882 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -58,12 +58,12 @@ Optional<Map<DFPattern, Expr>> ExtractMatchedExpr(
* \param dfb The function to match.
* \param start_hint The starting point expression to match to distinguish
multiple matches.
* \param must_include_hint If start_hint is given, the return pattern must
include start_hint.
- * \return tvm::runtime::Map<DFPattern, Var>
+ * \return Matched patterns and corresponding bound variables
*/
-TVM_DLL tvm::runtime::Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
- const DataflowBlock& dfb,
- Optional<Var> start_hint
= NullOpt,
- bool must_include_hint =
false);
+TVM_DLL Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx,
+ const DataflowBlock& dfb,
+ Optional<Var> start_hint =
NullOpt,
+ bool must_include_hint =
false);
} // namespace relax
} // namespace tvm
diff --git a/src/relax/analysis/udchain.cc b/src/relax/analysis/udchain.cc
index 77e52408a7..1c49fd581f 100644
--- a/src/relax/analysis/udchain.cc
+++ b/src/relax/analysis/udchain.cc
@@ -65,36 +65,36 @@ class UDChain : public relax::ExprVisitor {
std::pair<runtime::Map<Var, runtime::Array<Var>>, runtime::Array<Var>>
FunctionUseDef(
const Function& fn) {
UDChain udchain;
- udchain.VisitExpr_(fn.get());
+ udchain.VisitExpr(fn);
Map<Var, Array<Var>> user_map;
Array<Var> fn_outs;
- for (const auto& kv : udchain.to_users) {
+ for (const auto& [var, users] : udchain.to_users) {
Array<Var> uses{};
- uses.reserve(kv.second.size());
- for (const auto& v : kv.second) {
- if (nullptr == v &&
- fn_outs.end() == std::find(fn_outs.begin(), fn_outs.end(),
GetRef<Var>(kv.first))) {
- fn_outs.push_back(GetRef<Var>(kv.first));
+ uses.reserve(users.size());
+ for (const auto& v : users) {
+ if (v == nullptr &&
+ std::find(fn_outs.begin(), fn_outs.end(), GetRef<Var>(var)) ==
fn_outs.end()) {
+ fn_outs.push_back(GetRef<Var>(var));
} else {
uses.push_back(GetRef<Var>(v));
}
}
- user_map.Set(GetRef<Var>(kv.first), std::move(uses));
+ user_map.Set(GetRef<Var>(var), std::move(uses));
}
return std::make_pair(std::move(user_map), std::move(fn_outs));
}
runtime::Map<Var, Array<Var>> DataflowBlockUseDef(const DataflowBlock& dfb) {
UDChain udchain;
- udchain.VisitBindingBlock_(dfb.get());
+ udchain.VisitBindingBlock(dfb);
runtime::Map<Var, Array<Var>> ret;
- for (const auto& kv : udchain.to_users) {
+ for (const auto& [var, users] : udchain.to_users) {
Array<Var> uses{};
- uses.reserve(kv.second.size());
- for (const auto& v : kv.second) uses.push_back(GetRef<Var>(v));
- ret.Set(GetRef<Var>(kv.first), std::move(uses));
+ uses.reserve(users.size());
+ for (const auto& v : users) uses.push_back(GetRef<Var>(v));
+ ret.Set(GetRef<Var>(var), std::move(uses));
}
return ret;
}
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index d1b7a1d62c..88381d6e26 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -33,7 +33,7 @@
#include <array>
#include <cstddef>
#include <limits>
-#include <stack>
+#include <optional>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@@ -540,33 +540,40 @@ struct RNode {
/**
* \brief This method try to match a real node and a pattern node along with its neighbors.
*/
-static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
- const std::map<const VarNode*, std::vector<const
VarNode*>>& def2use,
- const std::map<const VarNode*, std::vector<const
VarNode*>>& use2def) {
- if (p->matched != nullptr && p->matched == r->ptr) return true; // matched before.
- if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return false;
+using UndoItems = std::vector<std::pair<PNode*, RNode*>>;
+static std::optional<UndoItems> try_match(
+ PNode* p, RNode* r, DFPatternMatcher* m,
+ const std::map<const VarNode*, std::vector<const VarNode*>>& def2use,
+ const std::map<const VarNode*, std::vector<const VarNode*>>& use2def) {
+ if (p->matched != nullptr && p->matched == r->ptr) return {}; // matched before.
+ if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return std::nullopt;
- std::stack<std::pair<PNode*, RNode*>> undo_stack{};
+ UndoItems undo;
- const auto commit = [&undo_stack](PNode* p, RNode* r) {
+ const auto commit = [&undo](PNode* p, RNode* r) {
// match with each other.
// TODO(ganler, masahi): Why commit on the same p-r pair happens more than once?
if (p->ptr == r->matched) {
ICHECK_EQ(p->matched, r->ptr);
return;
}
- ICHECK(r->matched == nullptr);
p->matched = r->ptr;
r->matched = p->ptr;
- undo_stack.emplace(p, r);
+ undo.emplace_back(p, r);
};
- const auto quit = [&undo_stack] {
- while (!undo_stack.empty()) {
- auto& top = undo_stack.top();
- top.first->matched = nullptr;
- top.second->matched = nullptr;
- undo_stack.pop();
+ const auto quit = [&undo] {
+ for (auto& [p_node, r_node] : undo) {
+ p_node->matched = nullptr;
+ r_node->matched = nullptr;
+ }
+ return std::nullopt;
+ };
+
+ const auto try_match_update_undo = [&](PNode* p, RNode* r) {
+ if (auto undo_more = try_match(p, r, m, def2use, use2def)) {
+ undo.insert(undo.end(), undo_more->begin(), undo_more->end());
+ return true;
}
return false;
};
@@ -604,7 +611,7 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
// try all parent R nodes that are not matched yet.
// as long as ppattern can match one node.
- if (!pparent->matched && try_match(pparent, rparent, m, def2use,
use2def)) {
+ if (!pparent->matched && try_match_update_undo(pparent, rparent)) {
commit(pparent, rparent);
break;
}
@@ -639,14 +646,14 @@ static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
if (!all_cons_pass) continue;
any_cons_sat = true;
- if (!pchild->matched && try_match(pchild, rchild, m, def2use, use2def)) {
+ if (!pchild->matched && try_match_update_undo(pchild, rchild)) {
commit(pchild, rchild);
break;
}
}
if (!pchild->matched || !any_cons_sat) return quit();
}
- return true;
+ return undo;
}
class MatcherUseDefAnalysis : public relax::ExprVisitor {
@@ -686,13 +693,12 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
}
};
-Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock&
dfb,
- Optional<Var> start_hint, bool
must_include_hint) {
+Optional<Map<DFPattern, Var>> MatchGraph(const PatternContext& ctx, const
DataflowBlock& dfb,
+ Optional<Var> start_hint, bool
must_include_hint) {
if (ctx->src_ordered.size() == 0) {
- return {};
+ return NullOpt;
}
- Map<DFPattern, Var> ret;
// TODO(@ganler): Handle non-may external use.
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is
supported yet.";
ICHECK(!must_include_hint || start_hint.defined())
@@ -737,12 +743,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
}
}
+ Map<DFPattern, Var> ret;
+
if (start_hint) {
auto rnode_ptr = var2node.at(start_hint.value().get());
for (auto& p_node : pattern2node) {
if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use,
caller2callees)) {
- for (const auto& [df_pattern, pattern_node] : pattern2node)
+ for (const auto& [df_pattern, pattern_node] : pattern2node) {
ret.Set(GetRef<DFPattern>(df_pattern),
GetRef<Var>(pattern_node.matched));
+ }
return ret;
}
}
@@ -757,15 +766,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
if (start_hint.defined() && start_hint.value().get() == var) continue;
RNode& r_node = var2node[var];
if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees))
{
- for (const auto& [df_pattern, pattern_node] : pattern2node)
+ for (const auto& [df_pattern, pattern_node] : pattern2node) {
ret.Set(GetRef<DFPattern>(df_pattern),
GetRef<Var>(pattern_node.matched));
-
+ }
return ret;
}
}
}
- return ret;
+ return NullOpt;
}
TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 76bce47f7f..f18244096e 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1042,5 +1042,38 @@ def test_attention_qkv():
assert out[V_weight_pat].name_hint == "w2"
+def test_attention_fake_qkv():
+ @tvm.script.ir_module
+ class QKV_proj:
+ @R.function
+ def main(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv1 = R.matmul(x2, w1)
+ lv2 = R.matmul(x2, w2)
+ out = (lv0, lv1, lv2)
+ R.output(out)
+ return out
+
+ with PatternContext() as ctx:
+ inp_pat = wildcard()
+ Q_weight_pat = wildcard()
+ K_weight_pat = wildcard()
+ V_weight_pat = wildcard()
+
+ matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+ matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+ matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+ dfb = QKV_proj["main"].body.blocks[0]
+ assert ctx.match_dfb(dfb) is None
+
+
if __name__ == "__main__":
tvm.testing.main()