This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 0695cfe4e2 [Unity] Remove non-deterministic behavior from graph
pattern matching (#14417)
0695cfe4e2 is described below
commit 0695cfe4e2036195f6df418bab861f05f88eb93c
Author: masahi <[email protected]>
AuthorDate: Fri Mar 31 02:57:30 2023 +0900
[Unity] Remove non-deterministic behavior from graph pattern matching
(#14417)
* Remove all non-determinism from graph matching
* add test
* typo
* try fixing compile error for gcc
* more style update
* suppress compile warning
---
include/tvm/relax/dataflow_pattern.h | 24 +++++--
src/relax/ir/dataflow_matcher.cc | 104 +++++++++++++++-------------
tests/python/relax/test_dataflow_pattern.py | 45 ++++++++++++
3 files changed, 119 insertions(+), 54 deletions(-)
diff --git a/include/tvm/relax/dataflow_pattern.h
b/include/tvm/relax/dataflow_pattern.h
index 37640750a8..144a7f45bf 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -191,7 +191,10 @@ class PatternContextNode : public Object {
kMustNot, /*!< All nodes except outputs only have internal depedencies in
the matched graph. */
} allow_extern_use = kMay;
// src node -> <dst node, constraint type> constraints.
- std::map<DFPattern, std::map<DFPattern, std::vector<PairCons>>> constraints;
+ // Dst nodes are kept in a vector to keep them ordered.
+ std::map<DFPattern, std::vector<std::pair<DFPattern,
std::vector<PairCons>>>> constraints;
+ // Keep a separate vector of patterns to process constraints in a fixed
order.
+ std::vector<DFPattern> src_ordered;
static constexpr const char* _type_key = "relax.dpl.PatternContext";
TVM_DECLARE_FINAL_OBJECT_INFO(PatternContextNode, Object);
@@ -224,9 +227,22 @@ class PatternContext : public ObjectRef {
* \param cons The constraint type. \sa PairCons
*/
void add_constraint(DFPattern producer, DFPattern consumer, PairCons cons) {
- auto& vec = (*this)->constraints[producer][consumer];
- ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend()) <<
"Constraint already exists";
- vec.push_back(cons);
+ auto& pairs = (*this)->constraints[producer];
+ auto it = std::find_if(pairs.begin(), pairs.end(),
+ [consumer](auto p) { return p.first == consumer; });
+ if (it == pairs.end()) {
+ pairs.emplace_back(consumer, std::vector{cons});
+ } else {
+ auto& vec = it->second;
+ ICHECK(std::find(vec.cbegin(), vec.cend(), cons) == vec.cend())
+ << "Constraint already exists";
+ vec.push_back(cons);
+ }
+
+ auto& patterns = (*this)->src_ordered;
+ if (std::find(patterns.begin(), patterns.end(), producer) ==
patterns.end()) {
+ patterns.push_back(producer);
+ }
}
/*! \brief Get the pass context object on the top of the stack */
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index c6d705b5b4..d1b7a1d62c 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -541,15 +541,21 @@ struct RNode {
* \brief This method try to match a real node and a pattern node along with
its neighbors.
*/
static bool try_match(PNode* p, RNode* r, DFPatternMatcher* m,
- const std::map<const VarNode*, std::set<const
VarNode*>>& def2use,
+ const std::map<const VarNode*, std::vector<const
VarNode*>>& def2use,
const std::map<const VarNode*, std::vector<const
VarNode*>>& use2def) {
- if (nullptr != p->matched && p->matched == r->ptr) return true; // matched
before.
+ if (p->matched != nullptr && p->matched == r->ptr) return true; // matched
before.
if (!m->Match(GetRef<DFPattern>(p->ptr), GetRef<Var>(r->ptr))) return false;
std::stack<std::pair<PNode*, RNode*>> undo_stack{};
const auto commit = [&undo_stack](PNode* p, RNode* r) {
// match with each other.
+ // TODO(ganler, masahi): Why commit on the same p-r pair happens more than
once?
+ if (p->ptr == r->matched) {
+ ICHECK_EQ(p->matched, r->ptr);
+ return;
+ }
+ ICHECK(r->matched == nullptr);
p->matched = r->ptr;
r->matched = p->ptr;
undo_stack.emplace(p, r);
@@ -568,31 +574,26 @@ static bool try_match(PNode* p, RNode* r,
DFPatternMatcher* m,
commit(p, r);
// match parent patterns.
- for (auto& pparent_pairs : p->parents) {
- PNode* pparent = pparent_pairs.first;
- const std::vector<PairCons>& constraints = pparent_pairs.second;
-
+ for (auto& [pparent, constraints] : p->parents) {
bool any_cons_sat = false;
for (auto& rparent : r->parents) {
// skip if mismatch.
if (rparent->matched && rparent->matched != pparent->ptr) continue;
const auto& uses = def2use.at(rparent->ptr);
- // skip if `rparent` is not used by `r`.
- if (uses.cend() == uses.find(r->ptr)) continue;
// check edge constraints.
bool cons_sat = true;
for (const auto& cons : constraints) {
- if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
+ if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
cons_sat = false;
break;
}
- if (-1 != cons.index) {
+ if (cons.index != -1) {
const auto& callees = use2def.at(r->ptr);
- if (static_cast<size_t>(cons.index) >= callees.size() ||
- rparent->ptr != callees[cons.index]) {
+ if (callees.size() <= static_cast<size_t>(cons.index) ||
+ callees[cons.index] != rparent->ptr) {
cons_sat = false;
break;
}
@@ -612,27 +613,24 @@ static bool try_match(PNode* p, RNode* r,
DFPatternMatcher* m,
}
// forward matching;
- for (auto& pchild_pairs : p->children) {
- PNode* pchild = pchild_pairs.first;
- const std::vector<PairCons>& constraints = pchild_pairs.second;
+ for (auto& [pchild, constraints] : p->children) {
bool any_cons_sat = false;
for (auto& rchild : r->children) {
if (rchild->matched && rchild->matched != pchild->ptr) continue;
const auto& uses = def2use.at(r->ptr);
- if (uses.cend() == uses.find(rchild->ptr)) continue;
// check edge constraints.
bool all_cons_pass = true;
for (const auto& cons : constraints) {
- if (PairCons::kOnlyUsedBy == cons.type && uses.size() != 1) {
+ if (cons.type == PairCons::kOnlyUsedBy && uses.size() != 1) {
all_cons_pass = false;
break;
}
- if (-1 != cons.index) {
+ if (cons.index != -1) {
const auto& callees = use2def.at(rchild->ptr);
- if (static_cast<size_t>(cons.index) >= callees.size() || r->ptr !=
callees[cons.index]) {
+ if (callees.size() <= static_cast<size_t>(cons.index) ||
callees[cons.index] != r->ptr) {
all_cons_pass = false;
break;
}
@@ -648,13 +646,13 @@ static bool try_match(PNode* p, RNode* r,
DFPatternMatcher* m,
}
if (!pchild->matched || !any_cons_sat) return quit();
}
-
return true;
}
class MatcherUseDefAnalysis : public relax::ExprVisitor {
public:
- std::map<const VarNode*, std::set<const VarNode*>> def2use;
+ std::vector<const VarNode*> vars;
+ std::map<const VarNode*, std::vector<const VarNode*>> def2use;
// caller -> callee table.
std::map<const VarNode*, std::vector<const VarNode*>> caller2callees;
@@ -671,7 +669,15 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
void VisitExpr_(const VarNode* op) override {
if (nullptr == cur_user_) return;
- def2use[op].insert(cur_user_);
+ auto check_and_push = [](std::vector<const VarNode*>& vec, const VarNode*
var) {
+ if (std::find(vec.begin(), vec.end(), var) == vec.end()) {
+ vec.push_back(var);
+ }
+ };
+
+ check_and_push(def2use[op], cur_user_);
+ check_and_push(vars, op);
+
caller2callees[cur_user_].push_back(op);
}
@@ -682,6 +688,10 @@ class MatcherUseDefAnalysis : public relax::ExprVisitor {
Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock&
dfb,
Optional<Var> start_hint, bool
must_include_hint) {
+ if (ctx->src_ordered.size() == 0) {
+ return {};
+ }
+
Map<DFPattern, Var> ret;
// TODO(@ganler): Handle non-may external use.
ICHECK(ctx->allow_extern_use == PatternContextNode::kMay) << "Only kMay is
supported yet.";
@@ -691,7 +701,6 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
const DataflowBlock& d
const auto var2val = AnalyzeVar2Value(dfb);
DFPatternMatcher matcher(var2val);
- // std::map<const VarNode*, std::set<const VarNode*>>
MatcherUseDefAnalysis ud_analysis;
ud_analysis.VisitBindingBlock_(dfb.get());
const auto& def2use = ud_analysis.def2use;
@@ -701,9 +710,8 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
const DataflowBlock& d
std::unordered_map<const VarNode*, RNode> var2node;
var2node.reserve(dfb->bindings.size());
- for (const auto& du : def2use) {
- const VarNode* cur_var = du.first;
- const std::set<const VarNode*>& uses = du.second;
+ for (const VarNode* cur_var : ud_analysis.vars) {
+ const auto& uses = def2use.at(cur_var);
RNode& cur_node = var2node[cur_var];
cur_node.ptr = cur_var;
for (const VarNode* use : uses) {
@@ -717,29 +725,24 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
const DataflowBlock& d
std::unordered_map<const DFPatternNode*, PNode> pattern2node;
pattern2node.reserve(ctx->constraints.size());
- for (const auto& def2use_pattern : ctx->constraints) {
- const DFPatternNode* def_pattern = def2use_pattern.first.get();
- const std::map<DFPattern, std::vector<PairCons>>& uses =
def2use_pattern.second;
- PNode& def_node = pattern2node[def_pattern];
- def_node.ptr = def_pattern;
+ for (const auto& [def_pattern, uses] : ctx->constraints) {
+ PNode& def_node = pattern2node[def_pattern.get()];
+ def_node.ptr = def_pattern.get();
def_node.children.reserve(uses.size());
- for (const auto& use : uses) {
- const auto& cons = use.second;
- const DFPatternNode* use_pattern = use.first.get();
- PNode& use_node = pattern2node[use_pattern];
- use_node.ptr = use_pattern;
+ for (const auto& [use_pattern, cons] : uses) {
+ PNode& use_node = pattern2node[use_pattern.get()];
+ use_node.ptr = use_pattern.get();
use_node.parents.emplace_back(&def_node, std::ref(cons));
def_node.children.emplace_back(&use_node, std::ref(cons));
}
}
- if (start_hint.defined()) {
- Var v = start_hint.value();
- auto rnode_ptr = var2node.find(v.get());
- for (auto& ppair : pattern2node) {
- if (try_match(&ppair.second, &rnode_ptr->second, &matcher, def2use,
caller2callees)) {
- for (auto ppair : pattern2node)
- ret.Set(GetRef<DFPattern>(ppair.first),
GetRef<Var>(ppair.second.matched));
+ if (start_hint) {
+ auto rnode_ptr = var2node.at(start_hint.value().get());
+ for (auto& p_node : pattern2node) {
+ if (try_match(&p_node.second, &rnode_ptr, &matcher, def2use,
caller2callees)) {
+ for (const auto& [df_pattern, pattern_node] : pattern2node)
+ ret.Set(GetRef<DFPattern>(df_pattern),
GetRef<Var>(pattern_node.matched));
return ret;
}
}
@@ -747,14 +750,15 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
const DataflowBlock& d
if (must_include_hint) return ret;
}
- PNode* pnode_start = &pattern2node.begin()->second;
+ PNode& pnode_start = pattern2node[ctx->src_ordered[0].get()];
- if (!pnode_start->matched) {
- for (auto& rpair : var2node) {
- if (start_hint.defined() && start_hint.value().get() == rpair.first)
continue;
- if (try_match(pnode_start, &rpair.second, &matcher, def2use,
caller2callees)) {
- for (auto ppair : pattern2node)
- ret.Set(GetRef<DFPattern>(ppair.first),
GetRef<Var>(ppair.second.matched));
+ if (!pnode_start.matched) {
+ for (const auto& var : ud_analysis.vars) {
+ if (start_hint.defined() && start_hint.value().get() == var) continue;
+ RNode& r_node = var2node[var];
+ if (try_match(&pnode_start, &r_node, &matcher, def2use, caller2callees))
{
+ for (const auto& [df_pattern, pattern_node] : pattern2node)
+ ret.Set(GetRef<DFPattern>(df_pattern),
GetRef<Var>(pattern_node.matched));
return ret;
}
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index a40faf3bcb..9679e14fff 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1006,5 +1006,50 @@ def test_rewrite_attention():
tvm.ir.assert_structural_equal(rewritten, expected)
+def test_attention_qkv():
+ @tvm.script.ir_module
+ class QKV_proj:
+ @R.function
+ def main(
+ x: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x, w0)
+ lv1 = R.matmul(x, w1)
+ lv2 = R.matmul(x, w2)
+ out = (lv0, lv1, lv2)
+ R.output(out)
+ return out
+
+ with PatternContext() as ctx:
+ inp_pat = wildcard()
+ Q_weight_pat = wildcard()
+ K_weight_pat = wildcard()
+ V_weight_pat = wildcard()
+
+ matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+ matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+ matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+ # TODO(masahi): Automate addition of used_by constraints during is_op
+ inp_pat.used_by(matmul1, 0)
+ inp_pat.used_by(matmul2, 0)
+ inp_pat.used_by(matmul3, 0)
+
+ Q_weight_pat.only_used_by(matmul1, 1)
+ K_weight_pat.only_used_by(matmul2, 1)
+ V_weight_pat.only_used_by(matmul3, 1)
+
+ dfb = QKV_proj["main"].body.blocks[0]
+ out = ctx.match_dfb(dfb)
+
+ assert out[Q_weight_pat].name_hint == "w0"
+ assert out[K_weight_pat].name_hint == "w1"
+ assert out[V_weight_pat].name_hint == "w2"
+
+
if __name__ == "__main__":
tvm.testing.main()