This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7e68b4d [PatternMatcher] Support matching tuples, call nodes, and
functions with variable numbers of inputs (#7754)
7e68b4d is described below
commit 7e68b4d413dfc0ec0042634a40d94e9bab59dbde
Author: Matthew Brookhart <[email protected]>
AuthorDate: Sat Apr 3 03:12:13 2021 -0600
[PatternMatcher] Support matching tuples, call nodes, and functions with
variable numbers of inputs (#7754)
* Allow TuplePattern to have null fields and match any tuple
* support matching functions and call nodes with variable numbers of
parameters
* remove development code that was commented out
* add docs for fuzzy matching
---
docs/langref/relay_pattern.rst | 16 ++++
python/tvm/relay/dataflow_pattern/__init__.py | 5 +-
src/relay/ir/dataflow_matcher.cc | 107 ++++++++++++++++++--------
src/relay/ir/dataflow_pattern_functor.cc | 18 +++--
src/relay/ir/indexed_graph.cc | 18 +++--
tests/python/relay/test_dataflow_pattern.py | 80 +++++++++++++++++--
6 files changed, 192 insertions(+), 52 deletions(-)
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index d77a519..efb9804 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -307,6 +307,22 @@ The final example is matching diamonds with a
post-dominator relationship. We em
assert diamond.match(out)
+Matching Fuzzy Patterns
+=======================
+
+The Dominator analysis above lets one match a subgraph of the Relay AST that
doesn't correspond to a set of pattern nodes exactly 1-to-1. There are a few
other places where we support such "fuzzy" matching.
+
+Tuples, Functions, and Call nodes with any number of inputs can be matched by
passing `None` as the argument value, e.g.::
+
+ tuple_pattern = is_tuple(None)
+ func_pattern = FunctionPattern(None, wildcard() + wildcard())
+ call_pattern = func_pattern(None)
+
+These patterns allow matching more generic classes of patterns by constraining
the use of the arguments rather than the number of arguments.
+
+Additionally, we support matching Functions with fuzzy bodies, i.e., a
function body that is underconstrained by the pattern. The pattern
`FunctionPattern([is_var(), is_var()], wildcard() + wildcard())` will match
`relay.Function([x, y], x + y)`, but it will also match `relay.Function([x, y],
x * x + y)`. In the second case, the pattern doesn't perfectly constrain the
body of the function, so the resulting match is fuzzy.
+
+
Pattern Language Design
=======================
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py
b/python/tvm/relay/dataflow_pattern/__init__.py
index d4a8481..b368f4e 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -47,7 +47,10 @@ class DFPattern(Node):
"""Base class of all Patterns."""
def __call__(self, *args):
- return CallPattern(self, list(args))
+ args = list(args)
+ if len(args) == 1 and args[0] is None:
+ args = None
+ return CallPattern(self, args)
def __or__(self, other):
return AltPattern(self, other)
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 43a6473..6ed24d5 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -242,6 +242,7 @@ bool DFPatternMatcher::VisitDFPattern_(const
CallPatternNode* op, const Expr& ex
}
return false;
};
+
// logic
auto watermark = matched_nodes_.size();
if (const auto* call_node = expr.as<CallNode>()) {
@@ -253,13 +254,15 @@ bool DFPatternMatcher::VisitDFPattern_(const
CallPatternNode* op, const Expr& ex
const Array<Expr> expr_args) {
bool matches = true;
size_t i = 0;
- if (pattern_args.size() == expr_args.size()) {
- while (matches && i < pattern_args.size()) {
- matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
- ++i;
+ if (pattern_args.defined()) {
+ if (pattern_args.size() == expr_args.size()) {
+ while (matches && i < pattern_args.size()) {
+ matches &= VisitDFPattern(pattern_args[i], expr_args[i]);
+ ++i;
+ }
+ } else {
+ matches = false;
}
- } else {
- matches = false;
}
if (!matches) {
ClearMap(watermark2);
@@ -381,14 +384,16 @@ bool DFPatternMatcher::VisitDFPattern_(const
FunctionPatternNode* op, const Expr
bool matches = false;
if (const auto* func = expr.as<FunctionNode>()) {
matches = true;
- size_t i = 0;
- if (op->params.size() == func->params.size()) {
- while (matches && i < op->params.size()) {
- matches &= VisitDFPattern(op->params[i], func->params[i]);
- ++i;
+ if (op->params.defined()) {
+ size_t i = 0;
+ if (op->params.size() == func->params.size()) {
+ while (matches && i < op->params.size()) {
+ matches &= VisitDFPattern(op->params[i], func->params[i]);
+ ++i;
+ }
+ } else {
+ matches = false;
}
- } else {
- matches = false;
}
if (matches) {
matches &= VisitDFPattern(op->body, func->body);
@@ -409,12 +414,16 @@ bool DFPatternMatcher::VisitDFPattern_(const
TupleGetItemPatternNode* op, const
bool DFPatternMatcher::VisitDFPattern_(const TuplePatternNode* op, const Expr&
expr) {
bool matches = false;
if (const auto* tuple_node = expr.as<TupleNode>()) {
- if (op->fields.size() == tuple_node->fields.size()) {
- matches = true;
- size_t i = 0;
- while (matches && i < op->fields.size()) {
- matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
- ++i;
+ matches = true;
+ if (op->fields.defined()) {
+ if (op->fields.size() == tuple_node->fields.size()) {
+ size_t i = 0;
+ while (matches && i < op->fields.size()) {
+ matches &= VisitDFPattern(op->fields[i], tuple_node->fields[i]);
+ ++i;
+ }
+ } else {
+ matches = false;
}
}
}
@@ -657,7 +666,6 @@ class PatternGrouper {
int var_number = 0;
auto node_map = matcher_->GetMemo();
-
// Get fuzzy patterns
std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
for (auto node : pattern_graph_.topological_order_) {
@@ -669,11 +677,13 @@ class PatternGrouper {
}
}
}
- // Don't treat Function params as input variables for partition
- if (auto op = node->ref_.as<FunctionPatternNode>()) {
- for (auto fuzzy_op : op->params) {
- for (auto match : node_map[fuzzy_op]) {
- fuzzy_matches.insert(match);
+ // Don't treat Function params or body as input variables for partition
+ if (node->ref_.as<FunctionPatternNode>()) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ auto graph = CreateIndexedGraph(match.as<FunctionNode>()->body);
+ for (auto node : graph.topological_order_) {
+ fuzzy_matches.insert(node->ref_);
}
}
}
@@ -686,22 +696,46 @@ class PatternGrouper {
std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual> inputs;
Array<Var> params;
+
for (auto node : pattern_graph_.topological_order_) {
- if (node->inputs_.size() == 0) {
+ auto make_input = [&](const Expr& input) {
+ if (fuzzy_matches.count(input) == 0 && input.as<OpNode>() == nullptr &&
+ input.as<FunctionNode>() == nullptr && !EmbedConst(input,
node->ref_)) {
+ inputs[input] =
+ Var("FunctionVar_" + std::to_string(graph_number_) + "_" +
std::to_string(var_number),
+ NullValue<Type>());
+ group.args.push_back(input);
+ params.push_back(inputs[input]);
+ var_number++;
+ }
+ };
+ auto tuple = node->ref_.as<TuplePatternNode>();
+ auto call = node->ref_.as<CallPatternNode>();
+ if (tuple && !tuple->fields.defined()) {
if (node_map.count(node->ref_)) {
auto matches = node_map[node->ref_];
for (auto match : matches) {
- if (fuzzy_matches.count(match) == 0 && match.as<OpNode>() ==
nullptr &&
- match.as<FunctionNode>() == nullptr && !EmbedConst(match,
node->ref_)) {
- inputs[match] = Var(
- "FunctionVar_" + std::to_string(graph_number_) + "_" +
std::to_string(var_number),
- NullValue<Type>());
- group.args.push_back(match);
- params.push_back(inputs[match]);
- var_number++;
+ for (auto input : match.as<TupleNode>()->fields) {
+ make_input(input);
}
}
}
+ } else if (call && !call->args.defined()) {
+ if (node_map.count(node->ref_)) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ for (auto input : match.as<CallNode>()->args) {
+ make_input(input);
+ }
+ }
+ }
+ } else if (node->inputs_.size() == 0) {
+ if (node_map.count(node->ref_)) {
+ auto matches = node_map[node->ref_];
+ for (auto match : matches) {
+ make_input(match);
+ }
+ }
}
}
@@ -898,6 +932,11 @@ class PatternPartitioner : protected MixedModeMutator {
public:
Expr Partition(const DFPattern& pattern, const Expr& pre, const Map<String,
ObjectRef>& attrs,
PackedFunc check) {
+ if (pattern.as<FunctionPatternNode>()) {
+ LOG(WARNING) << "Partioning a Function that isn't called doesn't make
sense, skipping"
+ << pattern;
+ return pre;
+ }
auto grouper = PatternGrouper();
groups_ = grouper.GroupMatches(pattern, pre);
gid_assignments_ = grouper.GetGIDAssignments();
diff --git a/src/relay/ir/dataflow_pattern_functor.cc
b/src/relay/ir/dataflow_pattern_functor.cc
index 828e867..290f72d 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -45,8 +45,10 @@ void DFPatternVisitor::VisitDFPattern_(const
AttrPatternNode* op) { VisitDFPatte
void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) {
VisitDFPattern(op->op);
- for (auto arg : op->args) {
- VisitDFPattern(arg);
+ if (op->args.defined()) {
+ for (auto arg : op->args) {
+ VisitDFPattern(arg);
+ }
}
}
@@ -63,8 +65,10 @@ void DFPatternVisitor::VisitDFPattern_(const
DominatorPatternNode* op) {
void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
- for (auto param : op->params) {
- VisitDFPattern(param);
+ if (op->params.defined()) {
+ for (auto param : op->params) {
+ VisitDFPattern(param);
+ }
}
VisitDFPattern(op->body);
}
@@ -76,8 +80,10 @@ void DFPatternVisitor::VisitDFPattern_(const
TupleGetItemPatternNode* op) {
}
void DFPatternVisitor::VisitDFPattern_(const TuplePatternNode* op) {
- for (auto field : op->fields) {
- VisitDFPattern(field);
+ if (op->fields.defined()) {
+ for (auto field : op->fields) {
+ VisitDFPattern(field);
+ }
}
}
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 36789e6..e4d9585 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -242,8 +242,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const
DFPattern& pattern) {
void VisitDFPattern_(const CallPatternNode* op, NodePtr parent) override {
VisitDFPattern(op->op, graph_.node_map_[GetRef<DFPattern>(op)]);
- for (auto arg : op->args) {
- VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
+ if (op->args.defined()) {
+ for (auto arg : op->args) {
+ VisitDFPattern(arg, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
}
}
@@ -262,8 +264,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const
DFPattern& pattern) {
void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent)
override {
- for (auto param : op->params) {
- VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
+ if (op->params.defined()) {
+ for (auto param : op->params) {
+ VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
}
VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
}
@@ -277,8 +281,10 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const
DFPattern& pattern) {
}
void VisitDFPattern_(const TuplePatternNode* op, NodePtr parent) override {
- for (auto field : op->fields) {
- VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
+ if (op->fields.defined()) {
+ for (auto field : op->fields) {
+ VisitDFPattern(field, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
}
}
diff --git a/tests/python/relay/test_dataflow_pattern.py
b/tests/python/relay/test_dataflow_pattern.py
index a8e4b65..8e2c74a 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -196,6 +196,11 @@ def test_match_call():
add_pattern = is_op("add")(wildcard(), wildcard())
assert add_pattern.match(x + y)
+ # Match call with any number of inputs
+ call_pattern = wildcard()(None)
+ assert call_pattern.match(relay.op.nn.relu(x))
+ assert call_pattern.match(relay.op.add(x, y))
+
def test_no_match_call():
x = relay.var("x")
@@ -212,6 +217,11 @@ def test_match_func():
func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
assert func_pattern.match(relay.Function([x, y], x + y))
+ # Match Function with any number of inputs
+ func_pattern = FunctionPattern(None, wildcard())
+ assert func_pattern.match(relay.Function([x], x))
+ assert func_pattern.match(relay.Function([x, y], x + y))
+
def test_no_match_func():
x = relay.var("x")
@@ -369,6 +379,13 @@ def test_match_tuple():
assert
tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y,
z)), 1))
assert
tuple_get_item_pattern.match(relay.expr.TupleGetItem(relay.expr.Tuple((x, y,
z)), 2))
+ # Match tuple with any inputs
+ tuple_pattern = is_tuple(None)
+ concat_pattern = is_op("concatenate")(tuple_pattern)
+ assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x,)),
axis=0))
+ assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y)),
axis=0))
+ assert concat_pattern.match(relay.op.concatenate(relay.expr.Tuple((x, y,
z)), axis=0))
+
def test_no_match_tuple():
x = relay.var("x")
@@ -1375,6 +1392,63 @@ def test_partition_overused():
assert pattern.partition(out) == out
+def test_partition_fuzzy_tuple():
+ x = relay.var("x")
+ y = relay.var("y")
+ z = x + y
+ tuple_pattern = is_tuple(None)
+ concat_pattern = is_op("concatenate")(tuple_pattern)
+
+ xp = relay.var("xp")
+ yp = relay.var("yp")
+ zp = relay.var("zp")
+
+ def create_func(args, body):
+ return relay.Function(args, body).with_attr("PartitionedFromPattern",
"Tuple_concatenate_")
+
+ def concat(*args):
+ return relay.op.concatenate(relay.expr.Tuple(args), axis=0)
+
+ one = concat_pattern.partition(concat(x))
+ assert tvm.ir.structural_equal(one, create_func([xp], concat(xp))(x))
+ two = concat_pattern.partition(concat(x, y))
+ assert tvm.ir.structural_equal(two, create_func([xp, yp], concat(xp,
yp))(x, y))
+ three = concat_pattern.partition(concat(x, y, z))
+ assert tvm.ir.structural_equal(three, create_func([xp, yp, zp], concat(xp,
yp, zp))(x, y, z))
+
+
+def test_partition_fuzzy_function_args():
+
+ func_pattern = FunctionPattern(None, wildcard() + wildcard())(None) +
wildcard()
+ x = relay.var("x")
+ y = relay.var("y")
+ z = relay.var("z")
+ b = relay.var("b")
+ xp = relay.var("xp")
+ yp = relay.var("yp")
+ zp = relay.var("zp")
+
+ def create_func(call):
+ N = len(call.op.params)
+ new_params = [relay.var(str(i)) for i in range(N + 1)]
+ label = "add_FunctionCall_add_"
+ if N == 3:
+ label = "add_" + label
+ return relay.Function(
+ new_params, relay.Call(call.op, (new_params[0:-1])) +
new_params[-1]
+ ).with_attr("PartitionedFromPattern", label)(*([x, y, z][0:N] + [b]))
+
+ f1 = relay.Function([xp], xp + xp)(x)
+ one = func_pattern.partition(f1 + b)
+ assert tvm.ir.structural_equal(one, create_func(f1))
+ f2 = relay.Function([xp, yp], xp + yp)(x, y)
+ two = func_pattern.partition(f2 + b)
+ assert tvm.ir.structural_equal(two, create_func(f2))
+ f3 = relay.Function([xp, yp, zp], xp + yp + zp)(x, y, z)
+ three = func_pattern.partition(f3 + b)
+ assert tvm.ir.structural_equal(three, create_func(f3))
+
+
def test_partition_check():
pattern = is_op("nn.relu")(is_op("nn.conv2d")(is_var("input"), wildcard()))
@@ -1529,10 +1603,6 @@ def test_rewrite_function_with_fuzzy_body():
assert tvm.ir.structural_equal(x + w, x + w)
[email protected](
- """TODO(mbrookhart): The current partitioner can't properly handle
- the partitioned inputs on the fuzzy body"""
-)
def test_partition_function_with_fuzzy_body():
"""
Allow Rewriting a function with a fuzzy body via dominator analysis
@@ -1560,7 +1630,7 @@ def test_partition_function_with_fuzzy_body():
w2 = relay.var("w2")
b2 = relay.var("b2")
func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
- "PartitionedFromPattern", "FunctionCall_add_"
+ "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_"
)
expr2 = func2(x, w, b) + b
assert tvm.ir.structural_equal(pattern.partition(expr), expr2)