This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e889def  [PatternLang] Add a relay LetPattern (#7332)
e889def is described below

commit e889defc623f4e76589926fb89c58b5f5b5e66c8
Author: Matthew Brookhart <mbrookh...@octoml.ai>
AuthorDate: Sat Jan 23 01:37:21 2021 -0700

    [PatternLang] Add a relay LetPattern (#7332)
    
    * Add a relay LetPattern
    
    * fix If copy
    
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
    
    * fix If copy
    
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
    
    Co-authored-by: Cody Yu <comaniac0...@gmail.com>
---
 docs/langref/relay_pattern.rst                | 29 ++++++++++++++++++
 include/tvm/relay/dataflow_pattern.h          | 36 ++++++++++++++++++++++
 include/tvm/relay/dataflow_pattern_functor.h  | 11 ++++---
 python/tvm/relay/dataflow_pattern/__init__.py | 44 +++++++++++++++++++++++++++
 src/relay/ir/dataflow_matcher.cc              | 11 ++++++-
 src/relay/ir/dataflow_pattern.cc              | 22 ++++++++++++++
 src/relay/ir/dataflow_pattern_functor.cc      |  6 ++++
 src/relay/ir/indexed_graph.cc                 |  6 ++++
 tests/python/relay/test_dataflow_pattern.py   | 39 ++++++++++++++++++++++++
 9 files changed, 199 insertions(+), 5 deletions(-)

diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 992954c..d77a519 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -246,6 +246,24 @@ are matched:
 
         assert pat.match(relay.expr.If(cond, x, y))
 
+
+A Relay ``Let`` expression can be matched if its variable, value, and body
+are all matched:
+
+.. code-block:: python
+
+  def test_match_let():
+      x = is_var("x")
+      y = is_var("y")
+      let_var = is_var("let")
+      pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+      x = relay.var("x")
+      y = relay.var("y")
+      lv = relay.var("let")
+      cond = x < y
+      assert pat.match(relay.expr.Let(lv, cond, lv))
+
 Matching Diamonds and Post-Dominator Graphs
 *******************************************
 
@@ -310,6 +328,7 @@ The high level design is to introduce a language of 
patterns for now we propose
             | is_tuple()
             | is_tuple_get_item(pattern, index = None)
             | is_if(cond, tru, fls)
+            | is_let(var, value, body)
             | pattern1 `|` pattern2
             | dominates(parent_pattern, path_pattern, child_pattern)
             | FunctionPattern(params, body)
@@ -367,6 +386,16 @@ Function Pattern
 
 Match a Function with a body and parameters
 
+If Pattern
+**********
+
+Match an If expression with a condition, a true branch, and a false branch
+
+Let Pattern
+***********
+
+Match a Let expression with a variable, a value, and a body
+
 Applications
 ============
 
diff --git a/include/tvm/relay/dataflow_pattern.h 
b/include/tvm/relay/dataflow_pattern.h
index 1b0c0ac..1e6cecf 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -222,6 +222,42 @@ class FunctionPattern : public DFPattern {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
 };
 
+/*! \brief A pattern that matches a Relay Let binding. */
+class LetPatternNode : public DFPatternNode {
+ public:
+  /*! \brief Pattern for the variable being bound */
+  DFPattern var;
+  /*! \brief Pattern for the value bound to var */
+  DFPattern value;
+  /*! \brief Pattern for the body in which the binding is in effect */
+  DFPattern body;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("var", &var);
+    v->Visit("value", &value);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = "relay.dataflow_pattern.LetPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(LetPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief Managed reference to LetPatternNode, a pattern matching a Relay Let
+ */
+class LetPattern : public DFPattern {
+ public:
+  /*!
+   * \brief The constructor
+   * \param var Pattern for the variable that is bound.
+   * \param value Pattern for the value bound to the variable.
+   * \param body Pattern for the body of the let binding.
+   */
+  TVM_DLL LetPattern(DFPattern var, DFPattern value, DFPattern body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(LetPattern, DFPattern, LetPatternNode);
+};
+
 /*! \brief Tuple of multiple Exprs */
 class TuplePattern;
 /*! \brief Tuple container */
diff --git a/include/tvm/relay/dataflow_pattern_functor.h 
b/include/tvm/relay/dataflow_pattern_functor.h
index bff9e23..490cdc5 100644
--- a/include/tvm/relay/dataflow_pattern_functor.h
+++ b/include/tvm/relay/dataflow_pattern_functor.h
@@ -84,18 +84,19 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const LetPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
                             Args... args) DFPATTERN_FUNCTOR_DEFAULT;
-  virtual R VisitDFPattern_(const IfPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TypePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const VarPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
-  virtual R VisitDFPattern_(const ConstantPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const WildcardPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPatternDefault_(const Object* op, Args...) {
     LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
@@ -115,9 +116,10 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(LetPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
-    RELAY_DFPATTERN_FUNCTOR_DISPATCH(IfPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode);
@@ -143,10 +145,11 @@ class DFPatternVisitor : public 
DFPatternFunctor<void(const DFPattern&)> {
   void VisitDFPattern_(const DominatorPatternNode* op) override;
   void VisitDFPattern_(const ExprPatternNode* op) override;
   void VisitDFPattern_(const FunctionPatternNode* op) override;
+  void VisitDFPattern_(const IfPatternNode* op) override;
+  void VisitDFPattern_(const LetPatternNode* op) override;
   void VisitDFPattern_(const ShapePatternNode* op) override;
   void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
   void VisitDFPattern_(const TuplePatternNode* op) override;
-  void VisitDFPattern_(const IfPatternNode* op) override;
   void VisitDFPattern_(const TypePatternNode* op) override;
   void VisitDFPattern_(const VarPatternNode* op) override;
   void VisitDFPattern_(const WildcardPatternNode* op) override;
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py 
b/python/tvm/relay/dataflow_pattern/__init__.py
index 6f764e1..d4a8481 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -337,6 +337,29 @@ def is_if(cond, true_branch, false_branch):
     return IfPattern(cond, true_branch, false_branch)
 
 
+def is_let(var, value, body):
+    """
+    Syntactic sugar for creating a LetPattern.
+
+    Parameters
+    ----------
+    var: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the variable of Let.
+
+    value: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the value of Let.
+
+    body: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the body where the binding is in effect.
+
+    Returns
+    -------
+    result: tvm.relay.dataflow_pattern.DFPattern
+        The resulting LetPattern.
+    """
+    return LetPattern(var, value, body)
+
+
 def wildcard() -> "DFPattern":
     """
     Syntatic sugar for creating a WildcardPattern.
@@ -580,6 +603,27 @@ class IfPattern(DFPattern):
 
 
 @register_df_node
+class LetPattern(DFPattern):
+    """A pattern matching a Relay Let.
+
+    Parameters
+    ----------
+    var: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the variable of Let.
+
+    value: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the value of Let.
+
+    body: tvm.relay.dataflow_pattern.DFPattern
+        The pattern describing the body where the binding is in effect.
+
+    """
+
+    def __init__(self, var: "DFPattern", value: "DFPattern", body: 
"DFPattern"):
+        self.__init_handle_by_constructor__(ffi.LetPattern, var, value, body)
+
+
+@register_df_node
 class TuplePattern(DFPattern):
     """A patern matching a Relay Tuple.
 
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 459694b..0d94813 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -55,10 +55,11 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
   bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) 
override;
+  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const LetPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
-  bool VisitDFPattern_(const IfPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TypePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const VarPatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) 
override;
@@ -423,6 +424,14 @@ bool DFPatternMatcher::VisitDFPattern_(const 
IfPatternNode* op, const Expr& expr
   return false;
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const LetPatternNode* op, const Expr& 
expr) {
+  if (const auto* let_node = expr.as<LetNode>()) {  // Only a Let expr can match.
+    return VisitDFPattern(op->var, let_node->var) && VisitDFPattern(op->value, 
let_node->value) &&
+           VisitDFPattern(op->body, let_node->body);  // All three must match.
+  }
+  return false;  // Not a Let expression.
+}
+
 Expr InferType(const Expr& expr) {
   auto mod = IRModule::FromExpr(expr);
   mod = transform::InferType()(mod);
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 1e268fb..4c3b82c 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -112,6 +112,28 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "FunctionPatternNode(" << node->params << ", " << 
node->body << ")";
     });
 
+LetPattern::LetPattern(DFPattern var, DFPattern value, DFPattern body) {  // Owns sub-patterns.
+  ObjectPtr<LetPatternNode> n = make_object<LetPatternNode>();
+  n->var = std::move(var);
+  n->value = std::move(value);
+  n->body = std::move(body);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(LetPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.LetPattern")  // FFI constructor entry point.
+    .set_body_typed([](DFPattern var, DFPattern value, DFPattern body) {
+      return LetPattern(var, value, body);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)  // Prints LetPatternNode(var, value, body).
+    .set_dispatch<LetPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
+      auto* node = static_cast<const LetPatternNode*>(ref.get());
+      p->stream << "LetPatternNode(" << node->var << ", " << node->value << ", 
" << node->body
+                << ")";
+    });
+
 IfPattern::IfPattern(DFPattern cond, DFPattern true_branch, DFPattern 
false_branch) {
   ObjectPtr<IfPatternNode> n = make_object<IfPatternNode>();
   n->cond = std::move(cond);
diff --git a/src/relay/ir/dataflow_pattern_functor.cc 
b/src/relay/ir/dataflow_pattern_functor.cc
index 25b2473..828e867 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -87,6 +87,12 @@ void DFPatternVisitor::VisitDFPattern_(const IfPatternNode* 
op) {
   VisitDFPattern(op->false_branch);
 }
 
+void DFPatternVisitor::VisitDFPattern_(const LetPatternNode* op) {  // var, value, then body.
+  VisitDFPattern(op->var);
+  VisitDFPattern(op->value);
+  VisitDFPattern(op->body);
+}
+
 void DFPatternVisitor::VisitDFPattern_(const TypePatternNode* op) { 
VisitDFPattern(op->pattern); }
 
 void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 9ee5c9c..0f81c23 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -288,6 +288,12 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const 
DFPattern& pattern) {
       VisitDFPattern(op->false_branch, 
graph_.node_map_[GetRef<DFPattern>(op)]);
     }
 
+    void VisitDFPattern_(const LetPatternNode* op, NodePtr parent) override {
+      VisitDFPattern(op->var, graph_.node_map_[GetRef<DFPattern>(op)]);
+      VisitDFPattern(op->value, graph_.node_map_[GetRef<DFPattern>(op)]);
+      VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }  // Each sub-pattern is visited with this Let node passed as its parent.
+
     void VisitDFPattern_(const TypePatternNode* op, NodePtr parent) override {
       VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
     }
diff --git a/tests/python/relay/test_dataflow_pattern.py 
b/tests/python/relay/test_dataflow_pattern.py
index 934ebf4..e7b367b 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -138,6 +138,18 @@ def test_IfPattern():
     assert isinstance(pat.false_branch, VarPattern)
 
 
+def test_LetPattern():
+    x = is_var("x")
+    y = is_var("y")
+    let_var = is_var("let")
+    pat = is_let(let_var, is_op("less")(x, y), let_var)  # same pattern for var and body
+
+    assert isinstance(pat, LetPattern)
+    assert isinstance(pat.var, VarPattern)
+    assert isinstance(pat.value, CallPattern)
+    assert isinstance(pat.body, VarPattern)
+
+
 ## MATCHER TESTS
 
 
@@ -233,6 +245,33 @@ def test_no_match_if():
     assert not pat.match(relay.expr.If(x < y, y, x))
 
 
+def test_match_let():
+    x = is_var("x")
+    y = is_var("y")
+    let_var = is_var("let")
+    pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+    x = relay.var("x")
+    y = relay.var("y")
+    lv = relay.var("let")
+    cond = x < y
+    assert pat.match(relay.expr.Let(lv, cond, lv))  # let lv = x < y in lv
+
+
+def test_no_match_let():
+    x = is_var("x")
+    y = is_var("y")
+    let_var = is_var("let")
+    pat = is_let(let_var, is_op("less")(x, y), let_var)
+
+    x = relay.var("x")
+    y = relay.var("y")
+    lv = relay.var("let")
+
+    assert not pat.match(relay.expr.Let(lv, x > y, lv))  # value is x > y, not x < y
+    assert not pat.match(relay.expr.Let(lv, x < y, lv * x))  # body is lv * x, not the bound var
+
+
 def test_match_option():
     x = relay.var("x")
     w = relay.var("w")

Reply via email to