This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e889def [PatternLang] Add a relay LetPattern (#7332)
e889def is described below
commit e889defc623f4e76589926fb89c58b5f5b5e66c8
Author: Matthew Brookhart <[email protected]>
AuthorDate: Sat Jan 23 01:37:21 2021 -0700
[PatternLang] Add a relay LetPattern (#7332)
* Add a relay LetPattern
* fix If copy
Co-authored-by: Cody Yu <[email protected]>
* fix If copy
Co-authored-by: Cody Yu <[email protected]>
Co-authored-by: Cody Yu <[email protected]>
---
docs/langref/relay_pattern.rst | 29 ++++++++++++++++++
include/tvm/relay/dataflow_pattern.h | 36 ++++++++++++++++++++++
include/tvm/relay/dataflow_pattern_functor.h | 11 ++++---
python/tvm/relay/dataflow_pattern/__init__.py | 44 +++++++++++++++++++++++++++
src/relay/ir/dataflow_matcher.cc | 11 ++++++-
src/relay/ir/dataflow_pattern.cc | 22 ++++++++++++++
src/relay/ir/dataflow_pattern_functor.cc | 6 ++++
src/relay/ir/indexed_graph.cc | 6 ++++
tests/python/relay/test_dataflow_pattern.py | 39 ++++++++++++++++++++++++
9 files changed, 199 insertions(+), 5 deletions(-)
diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 992954c..d77a519 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -246,6 +246,24 @@ are matched:
assert pat.match(relay.expr.If(cond, x, y))
+
+A Relay ``Let`` expression can be matched if its variable, value, and body are all
+matched:
+
+.. code-block:: python
+
+ def test_match_let():
+ x = is_var("x")
+ y = is_var("y")
+ let_var = is_var("let")
+ pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+ x = relay.var("x")
+ y = relay.var("y")
+ lv = relay.var("let")
+ cond = x < y
+ assert pat.match(relay.expr.Let(lv, cond, lv))
+
Matching Diamonds and Post-Dominator Graphs
*******************************************
@@ -310,6 +328,7 @@ The high level design is to introduce a language of
patterns for now we propose
| is_tuple()
| is_tuple_get_item(pattern, index = None)
| is_if(cond, tru, fls)
+ | is_let(var, value, body)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
| FunctionPattern(params, body)
@@ -367,6 +386,16 @@ Function Pattern
Match a Function with a body and parameters
+If Pattern
+**********
+
+Match an If with condition, true branch, and false branch
+
+Let Pattern
+***********
+
+Match a Let with a variable, value, and body
+
Applications
============
diff --git a/include/tvm/relay/dataflow_pattern.h
b/include/tvm/relay/dataflow_pattern.h
index 1b0c0ac..1e6cecf 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -222,6 +222,42 @@ class FunctionPattern : public DFPattern {
TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
};
+/*! \brief A pattern that matches a Let binding of a sub-network. */
+class LetPatternNode : public DFPatternNode {
+ public:
+ /*! \brief The pattern matched against the variable being bound */
+ DFPattern var;
+ /*! \brief The pattern matched against the value bound to var */
+ DFPattern value;
+ /*! \brief The pattern matched against the body of the let binding */
+ DFPattern body;
+
+ void VisitAttrs(tvm::AttrVisitor* v) {
+ v->Visit("var", &var);
+ v->Visit("value", &value);
+ v->Visit("body", &body);
+ }
+
+ static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
+ TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief A pattern that matches a Let binding of a local var.
+ */
+class LetPattern : public DFPattern {
+ public:
+ /*!
+ * \brief The constructor
+ * \param var The pattern for the variable that is bound to.
+ * \param value The pattern for the value used to bind to the variable.
+ * \param body The pattern for the body of the let binding.
+ */
+ TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
+};
+
/*! \brief Tuple of multiple Exprs */
class TuplePattern;
/*! \brief Tuple container */
diff --git a/include/tvm/relay/dataflow_pattern_functor.h
b/include/tvm/relay/dataflow_pattern_functor.h
index bff9e23..490cdc5 100644
--- a/include/tvm/relay/dataflow_pattern_functor.h
+++ b/include/tvm/relay/dataflow_pattern_functor.h
@@ -84,18 +84,19 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
virtual R VisitDFPattern_(const AltPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const CallPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const IfPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
+ virtual R VisitDFPattern_(const LetPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
Args... args) DFPATTERN_FUNCTOR_DEFAULT;
- virtual R VisitDFPattern_(const IfPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const TypePatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const VarPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
- virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args)
DFPATTERN_FUNCTOR_DEFAULT;
virtual R VisitDFPatternDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
@@ -115,9 +116,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
+ RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
+ RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
- RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
@@ -143,10 +145,11 @@ class DFPatternVisitor : public
DFPatternFunctor<void(const DFPattern&)> {
void VisitDFPattern_(const DominatorPatternNode* op) override;
void VisitDFPattern_(const ExprPatternNode* op) override;
void VisitDFPattern_(const FunctionPatternNode* op) override;
+ void VisitDFPattern_(const IfPatternNode* op) override;
+ void VisitDFPattern_(const LetPatternNode* op) override;
void VisitDFPattern_(const ShapePatternNode* op) override;
void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
void VisitDFPattern_(const TuplePatternNode* op) override;
- void VisitDFPattern_(const IfPatternNode* op) override;
void VisitDFPattern_(const TypePatternNode* op) override;
void VisitDFPattern_(const VarPatternNode* op) override;
void VisitDFPattern_(const WildcardPatternNode* op) override;
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py
b/python/tvm/relay/dataflow_pattern/__init__.py
index 6f764e1..d4a8481 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch):
return IfPattern(cond, true_branch, false_branch)
+def is_let(var, value, body):
+ """
+ Syntactic sugar for creating a LetPattern.
+
+ Parameters
+ ----------
+ var: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the variable of Let.
+
+ value: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the value of Let.
+
+ body: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the body where the binding is in effect.
+
+ Returns
+ -------
+ result: tvm.relay.dataflow_pattern.DFPattern
+ The resulting LetPattern.
+ """
+ return LetPattern(var, value, body)
+
+
def wildcard() -> "DFPattern":
"""
Syntatic sugar for creating a WildcardPattern.
@@ -580,6 +603,27 @@ class IfPattern(DFPattern):
@register_df_node
+class LetPattern(DFPattern):
+ """A pattern matching a Relay Let.
+
+ Parameters
+ ----------
+ var: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the variable of Let.
+
+ value: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the value of Let.
+
+ body: tvm.relay.dataflow_pattern.DFPattern
+ The pattern describing the body where the binding is in effect.
+
+ """
+
+ def __init__(self, var: "DFPattern", value: "DFPattern", body:
"DFPattern"):
+ self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body)
+
+
+@register_df_node
class TuplePattern(DFPattern):
"""A patern matching a Relay Tuple.
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 459694b..0d94813 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const
DFPattern&, const Ex
bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr)
override;
bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr)
override;
+ bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
+ bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr)
override;
bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
- bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr)
override;
@@ -423,6 +424,14 @@ bool DFPatternMatcher::VisitDFPattern_(const
IfPatternNode* op, const Expr& expr
return false;
}
+bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr&
expr) {
+ if (const auto* let_node = expr.as<LetNode>()) {  // match var, value, body pairwise; && stops at the first mismatch
+ return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value,
let_node->value) &&
+ VisitDFPattern(op->body, let_node->body);
+ }
+ return false;  // any non-Let expression fails to match a LetPattern
+}
+
Expr InferType(const Expr& expr) {
auto mod = IRModule::FromExpr(expr);
mod = transform::InferType()(mod);
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 1e268fb..4c3b82c 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
p->stream << "FunctionPatternNode(" << node->params << ", " <<
node->body << ")";
});
+LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {
+ ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();  // allocate the node, then move each sub-pattern in
+ n->var = std::move(var);
+ n->value = std::move(value);
+ n->body = std::move(body);
+ data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(LetPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")  // FFI constructor; presumably what Python's ffi.LetPattern resolves to
+ .set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
+ return LetPattern(var, value, body);
+ });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)  // repr prints as LetPatternNode(var, value, body)
+ .set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const LetPatternNode*>(ref.get());
+ p->stream << "LetPatternNode(" << node->var << ", " << node->value << ",
" << node->body
+ << ")";
+ });
+
IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern
false_branch) {
ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
n->cond = std::move(cond);
diff --git a/src/relay/ir/dataflow_pattern_functor.cc
b/src/relay/ir/dataflow_pattern_functor.cc
index 25b2473..828e867 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode*
op) {
VisitDFPattern(op->false_branch);
}
+void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) {
+ VisitDFPattern(op->var);  // recurse into var, then value, then body
+ VisitDFPattern(op->value);
+ VisitDFPattern(op->body);
+}
+
void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) {
VisitDFPattern(op->pattern); }
void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 9ee5c9c..0f81c23 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -288,6 +288,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const
DFPattern& pattern) {
VisitDFPattern(op->false_branch,
graph_.node_map_[GetRef<DFPattern>(op)]);
}
+ void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override {  // visit each sub-pattern with this Let's graph node as its parent, as done for If
+ VisitDFPattern(op->var, graph_.node_map_[GetRef<DFPattern>(op)]);
+ VisitDFPattern(op->value, graph_.node_map_[GetRef<DFPattern>(op)]);
+ VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
+ }
+
void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
}
diff --git a/tests/python/relay/test_dataflow_pattern.py
b/tests/python/relay/test_dataflow_pattern.py
index 934ebf4..e7b367b 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -138,6 +138,18 @@ def test_IfPattern():
assert isinstance(pat.false_branch, VarPattern)
+def test_LetPattern():
+ x = is_var("x")
+ y = is_var("y")
+ let_var = is_var("let")
+ pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+ assert isinstance(pat, LetPattern)  # is_let builds a LetPattern node
+ assert isinstance(pat.var, VarPattern)  # each field keeps the sub-pattern type it was built from
+ assert isinstance(pat.value, CallPattern)
+ assert isinstance(pat.body, VarPattern)
+
+
## MATCHER TESTS
@@ -233,6 +245,33 @@ def test_no_match_if():
assert not pat.match(relay.expr.If(x < y, y, x))
+def test_match_let():
+ x = is_var("x")
+ y = is_var("y")
+ let_var = is_var("let")
+ pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+ x = relay.var("x")  # rebind x/y/lv to concrete Relay vars whose names match the patterns
+ y = relay.var("y")
+ lv = relay.var("let")
+ cond = x < y
+ assert pat.match(relay.expr.Let(lv, cond, lv))  # var, value, and body all match
+
+
+def test_no_match_let():
+ x = is_var("x")
+ y = is_var("y")
+ let_var = is_var("let")
+ pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+ x = relay.var("x")
+ y = relay.var("y")
+ lv = relay.var("let")
+
+ assert not pat.match(relay.expr.Let(lv, x > y, lv))  # value is x > y, not the required less(x, y)
+ assert not pat.match(relay.expr.Let(lv, x < y, lv * x))  # body is lv * x, not the bound var
+
+
def test_match_option():
x = relay.var("x")
w = relay.var("w")