This is an automated email from the ASF dual-hosted git repository.

comaniac pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a7503d  Add a FunctionPattern, remove unused attributes in 
CallPattern (#7151)
4a7503d is described below

commit 4a7503d9036ffcd3323959709c92a5e13816fd73
Author: Matthew Brookhart <mbrookh...@octoml.ai>
AuthorDate: Tue Dec 22 22:56:51 2020 -0700

    Add a FunctionPattern, remove unused attributes in CallPattern (#7151)
    
    * Add a FunctionPattern, remove unused attributes in CallPattern
    
    * update docs
---
 docs/langref/relay_pattern.rst                | 19 ++++++++
 include/tvm/relay/dataflow_pattern.h          | 69 +++++++++++++++++----------
 include/tvm/relay/dataflow_pattern_functor.h  |  3 ++
 python/tvm/relay/dataflow_pattern/__init__.py | 34 ++++++++-----
 src/relay/ir/dataflow_matcher.cc              | 42 ++++++++++++----
 src/relay/ir/dataflow_pattern.cc              | 30 ++++++++----
 src/relay/ir/dataflow_pattern_functor.cc      |  7 +++
 src/relay/ir/indexed_graph.cc                 |  7 +++
 src/relay/transforms/simplify_expr.cc         |  2 +-
 tests/python/relay/test_dataflow_pattern.py   | 61 +++++++++++++++++++++++
 10 files changed, 220 insertions(+), 54 deletions(-)

diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst
index 8b34b76..ff02e50 100644
--- a/docs/langref/relay_pattern.rst
+++ b/docs/langref/relay_pattern.rst
@@ -167,6 +167,19 @@ The next example is matching a pattern of batch_norm -> 
get(0) -> relu. Note tha
         out = relay.nn.relu(tuple_get_item_node)
         pat.match(out)
 
+If we have a pattern that crosses a function boundary, we might want to match 
the Function itself:
+
+
+.. code-block:: python
+
+  def test_match_func():
+      x = relay.var("x")
+      y = relay.var("y")
+      wc1 = wildcard()
+      wc2 = wildcard()
+      func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
+      assert func_pattern.match(relay.Function([x, y], x + y))
+
 The next example is matching a constant node regarding its values. This is 
useful to check
 if a specific parameter in a subgraph has been bound or not.
 
@@ -283,6 +296,7 @@ The high level design is to introduce a language of 
patterns for now we propose
             | is_tuple_get_item(pattern, index = None)
             | pattern1 `|` pattern2
             | dominates(parent_pattern, path_pattern, child_pattern)
+            | FunctionPattern(params, body)
 
 The above language then provides a matching interface with both can select 
sub-graphs as well as verify that the graph does match the pattern.
 
@@ -332,6 +346,11 @@ Domination
 
 Match child pattern, find a match for the parent pattern, insuring that the 
child ultimately dominates the parrent (i.e., no nodes outside the pattern use 
outputs of the parent), and that ever node betwen the child and the pattern 
matches the path pattern.
 
+Function Pattern
+****************
+
+Match a Function with a body and parameters.
+
 Applications
 ============
 
diff --git a/include/tvm/relay/dataflow_pattern.h 
b/include/tvm/relay/dataflow_pattern.h
index 11ac7e3..909a4fe 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -148,34 +148,9 @@ class CallPatternNode : public DFPatternNode {
   /*! \brief The arguments(inputs) of the call */
   tvm::Array<relay::DFPattern> args;
 
-  /*! \brief The additional attributes */
-  Attrs attrs;
-
-  /*!
-   * \brief The type arguments passed to polymorphic(template) function.
-   *
-   * This is the advance feature that is only used when the function is
-   * polymorphic. It is safe to be ignored in most cases. For example, in the
-   * following code, the type_args of addone call is [int].
-   *
-   * \code
-   *
-   * template<typename T>
-   * T addone(T a) { return a + 1; }
-   *
-   * void main() {
-   *   int x = addone<int>(10);
-   * }
-   *
-   * \endcode
-   */
-  tvm::Array<Type> type_args;
-
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("op", &op);
     v->Visit("args", &args);
-    v->Visit("attrs", &attrs);
-    v->Visit("type_args", &type_args);
   }
 
   static constexpr const char* _type_key = 
"relay.dataflow_pattern.CallPattern";
@@ -184,10 +159,52 @@ class CallPatternNode : public DFPatternNode {
 
 class CallPattern : public DFPattern {
  public:
-  TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, 
Array<Type> type_args);
+  TVM_DLL CallPattern(DFPattern op, Array<DFPattern> args);
   TVM_DEFINE_OBJECT_REF_METHODS(CallPattern, DFPattern, CallPatternNode);
 };
 
+/*!
+ * \brief Pattern that matches a Relay Function
+ * \sa Function
+ */
+class FunctionPatternNode : public DFPatternNode {
+ public:
+  /*! \brief Function parameters */
+  tvm::Array<DFPattern> params;
+  /*!
+   * \brief
+   * The pattern matched against the body expression of the function.
+   * The body pattern may reference the parameter patterns, as in
+   * FunctionPattern([wc1, wc2], wc1 + wc2).
+   */
+  DFPattern body;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("params", &params);
+    v->Visit("body", &body);
+  }
+
+  static constexpr const char* _type_key = 
"relay.dataflow_pattern.FunctionPattern";
+  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPatternNode, DFPatternNode);
+};
+
+/*!
+ * \brief Managed reference to FunctionPatternNode.
+ * \sa FunctionPatternNode
+ */
+class FunctionPattern : public DFPattern {
+ public:
+  /*!
+   * \brief Constructor
+   * \param params The parameters of the function.
+   * \param body The body of the function.
+   */
+  TVM_DLL FunctionPattern(tvm::Array<DFPattern> params, DFPattern body);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(FunctionPattern, DFPattern, 
FunctionPatternNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(FunctionPatternNode);
+};
+
 /*! \brief Tuple of multiple Exprs */
 class TuplePattern;
 /*! \brief Tuple container */
diff --git a/include/tvm/relay/dataflow_pattern_functor.h 
b/include/tvm/relay/dataflow_pattern_functor.h
index 364daac..f04977b 100644
--- a/include/tvm/relay/dataflow_pattern_functor.h
+++ b/include/tvm/relay/dataflow_pattern_functor.h
@@ -87,6 +87,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
   virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
+  virtual R VisitDFPattern_(const FunctionPatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) 
DFPATTERN_FUNCTOR_DEFAULT;
   virtual R VisitDFPattern_(const TupleGetItemPatternNode* op,
                             Args... args) DFPATTERN_FUNCTOR_DEFAULT;
@@ -112,6 +113,7 @@ class DFPatternFunctor<R(const DFPattern& n, Args...)> {
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode);
+    RELAY_DFPATTERN_FUNCTOR_DISPATCH(FunctionPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode);
     RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode);
@@ -138,6 +140,7 @@ class DFPatternVisitor : public DFPatternFunctor<void(const 
DFPattern&)> {
   void VisitDFPattern_(const DataTypePatternNode* op) override;
   void VisitDFPattern_(const DominatorPatternNode* op) override;
   void VisitDFPattern_(const ExprPatternNode* op) override;
+  void VisitDFPattern_(const FunctionPatternNode* op) override;
   void VisitDFPattern_(const ShapePatternNode* op) override;
   void VisitDFPattern_(const TupleGetItemPatternNode* op) override;
   void VisitDFPattern_(const TuplePatternNode* op) override;
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py 
b/python/tvm/relay/dataflow_pattern/__init__.py
index 7178bff..233c696 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -504,24 +504,36 @@ class CallPattern(DFPattern):
     args: List[realy.dataflow_pattern.DFPattern]
         The arguments to the call.
 
-    attrs: Optional[tvm.ir.attrs.Attrs]
-        Attributes to the call, can be None
-
-    type_args: Optional[List[tvm.ir.type.Type]]
-        The additional type arguments, this is only
-        used in advanced usecase of template functions.
     """
 
     def __init__(
         self,
         op: "DFPattern",
         args: List["DFPattern"],
-        attrs: Optional[tvm.ir.attrs.Attrs] = None,
-        type_args: Optional[List[tvm.ir.type.Type]] = None,
     ):
-        if not type_args:
-            type_args = []
-        self.__init_handle_by_constructor__(ffi.CallPattern, op, args, attrs, 
type_args)
+        self.__init_handle_by_constructor__(ffi.CallPattern, op, args)
+
+
+@register_df_node
+class FunctionPattern(DFPattern):
+    """A pattern matching a function node in Relay.
+
+    Parameters
+    ----------
+    params: List[relay.dataflow_pattern.DFPattern]
+        The parameters to the Function.
+
+    body: relay.dataflow_pattern.DFPattern
+        The body of the Function
+
+    """
+
+    def __init__(
+        self,
+        params: List["DFPattern"],
+        body: "DFPattern",
+    ):
+        self.__init_handle_by_constructor__(ffi.FunctionPattern, params, body)
 
 
 @register_df_node
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 44b8763..c5cc3dd 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -54,6 +54,7 @@ class DFPatternMatcher : public DFPatternFunctor<bool(const 
DFPattern&, const Ex
   bool VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const DominatorPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const ExprPatternNode* op, const Expr& expr) override;
+  bool VisitDFPattern_(const FunctionPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) override;
   bool VisitDFPattern_(const TupleGetItemPatternNode* op, const Expr& expr) 
override;
   bool VisitDFPattern_(const TuplePatternNode* op, const Expr& expr) override;
@@ -264,10 +265,8 @@ bool DFPatternMatcher::VisitDFPattern_(const 
CallPatternNode* op, const Expr& ex
                is_expr_op(call_node->args[1], "divide"))) {
             bool out = false;
             for (size_t arg_id = 0; arg_id < 2; ++arg_id) {
-              auto div = CallPattern(op->op, {arg_node->args[arg_id], 
op->args[1]}, op->attrs,
-                                     op->type_args);
-              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 
1) % 2], div},
-                                     arg_node->attrs, arg_node->type_args);
+              auto div = CallPattern(op->op, {arg_node->args[arg_id], 
op->args[1]});
+              auto mul = CallPattern(arg_node->op, {arg_node->args[(arg_id + 
1) % 2], div});
               out = VisitDFPattern(mul, expr);
               if (out) {
                 return true;
@@ -286,10 +285,8 @@ bool DFPatternMatcher::VisitDFPattern_(const 
CallPatternNode* op, const Expr& ex
             if (is_pattern_op(arg_node, "divide") && is_expr_op(expr, 
"divide") &&
                 (is_expr_op(call_node->args[0], "multiply") ||
                  is_expr_op(call_node->args[1], "multiply"))) {
-              auto mul = CallPattern(op->op, {arg_node->args[0], 
op->args[(arg_id + 1) % 2]},
-                                     op->attrs, op->type_args);
-              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]}, 
arg_node->attrs,
-                                     arg_node->type_args);
+              auto mul = CallPattern(op->op, {arg_node->args[0], 
op->args[(arg_id + 1) % 2]});
+              auto div = CallPattern(arg_node->op, {mul, arg_node->args[1]});
               return VisitDFPattern(div, expr);
             }
           }
@@ -356,6 +353,26 @@ bool DFPatternMatcher::VisitDFPattern_(const 
ExprPatternNode* op, const Expr& ex
   return StructuralEqual()(op->expr, expr);
 }
 
+bool DFPatternMatcher::VisitDFPattern_(const FunctionPatternNode* op, const 
Expr& expr) {
+  bool matches = false;
+  if (const auto* func = expr.as<FunctionNode>()) {
+    matches = true;
+    size_t i = 0;
+    if (op->params.size() == func->params.size()) {
+      while (matches && i < op->params.size()) {
+        matches &= VisitDFPattern(op->params[i], func->params[i]);
+        ++i;
+      }
+    } else {
+      matches = false;
+    }
+    if (matches) {
+      matches &= VisitDFPattern(op->body, func->body);
+    }
+  }
+  return matches;
+}
+
 bool DFPatternMatcher::VisitDFPattern_(const TupleGetItemPatternNode* op, 
const Expr& expr) {
   bool matches = false;
   if (const auto* tuple_get_item_node = expr.as<TupleGetItemNode>()) {
@@ -601,6 +618,7 @@ class PatternGrouper {
     // Get fuzzy patterns
     std::unordered_set<Expr, ObjectPtrHash, ObjectPtrEqual> fuzzy_matches;
     for (auto node : pattern_graph_.topological_order_) {
+      // Don't treat fuzzy Dominator patterns as input variables for partition
       if (auto op = node->ref_.as<DominatorPatternNode>()) {
         for (auto fuzzy_op : {op->parent, op->path}) {
           for (auto match : node_map[fuzzy_op]) {
@@ -608,6 +626,14 @@ class PatternGrouper {
           }
         }
       }
+      // Don't treat Function params as input variables for partition
+      if (auto op = node->ref_.as<FunctionPatternNode>()) {
+        for (auto fuzzy_op : op->params) {
+          for (auto match : node_map[fuzzy_op]) {
+            fuzzy_matches.insert(match);
+          }
+        }
+      }
     }
 
     // Create input variables
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index 4664e5f..46c53c8 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -81,27 +81,41 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       p->stream << "ConstantPattern()";
     });
 
-CallPattern::CallPattern(DFPattern op, Array<DFPattern> args, Attrs attrs, 
Array<Type> type_args) {
+CallPattern::CallPattern(DFPattern op, Array<DFPattern> args) {
   ObjectPtr<CallPatternNode> n = make_object<CallPatternNode>();
   n->op = std::move(op);
   n->args = std::move(args);
-  n->attrs = std::move(attrs);
-  n->type_args = std::move(type_args);
   data_ = std::move(n);
 }
 
 TVM_REGISTER_NODE_TYPE(CallPatternNode);
 
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.CallPattern")
-    .set_body_typed([](DFPattern op, Array<DFPattern> args, Attrs attrs, 
Array<Type> type_args) {
-      return CallPattern(op, args, attrs, type_args);
-    });
+    .set_body_typed([](DFPattern op, Array<DFPattern> args) { return 
CallPattern(op, args); });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<CallPatternNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const CallPatternNode*>(ref.get());
-      p->stream << "CallPatternNode(" << node->op << ", " << node->args << ", 
" << node->attrs
-                << ", " << node->type_args << ")";
+      p->stream << "CallPatternNode(" << node->op << ", " << node->args << ")";
+    });
+
+FunctionPattern::FunctionPattern(Array<DFPattern> params, DFPattern body) {
+  ObjectPtr<FunctionPatternNode> n = make_object<FunctionPatternNode>();
+  n->params = std::move(params);
+  n->body = std::move(body);
+  data_ = std::move(n);
+}
+TVM_REGISTER_NODE_TYPE(FunctionPatternNode);
+
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.FunctionPattern")
+    .set_body_typed([](Array<DFPattern> params, DFPattern body) {
+      return FunctionPattern(params, body);
+    });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+    .set_dispatch<FunctionPatternNode>([](const ObjectRef& ref, ReprPrinter* 
p) {
+      auto* node = static_cast<const FunctionPatternNode*>(ref.get());
+      p->stream << "FunctionPatternNode(" << node->params << ", " << 
node->body << ")";
     });
 
 TuplePattern::TuplePattern(tvm::Array<DFPattern> fields) {
diff --git a/src/relay/ir/dataflow_pattern_functor.cc 
b/src/relay/ir/dataflow_pattern_functor.cc
index 7e9f828..aaa4f84 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -62,6 +62,13 @@ void DFPatternVisitor::VisitDFPattern_(const 
DominatorPatternNode* op) {
 
 void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {}
 
+void DFPatternVisitor::VisitDFPattern_(const FunctionPatternNode* op) {
+  for (auto param : op->params) {
+    VisitDFPattern(param);
+  }
+  VisitDFPattern(op->body);
+}
+
 void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { 
VisitDFPattern(op->pattern); }
 
 void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) {
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 456bf02..4ba053c 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -261,6 +261,13 @@ IndexedGraph<DFPattern> CreateIndexedGraph(const 
DFPattern& pattern) {
 
     void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {}
 
+    void VisitDFPattern_(const FunctionPatternNode* op, NodePtr parent) 
override {
+      for (auto param : op->params) {
+        VisitDFPattern(param, graph_.node_map_[GetRef<DFPattern>(op)]);
+      }
+      VisitDFPattern(op->body, graph_.node_map_[GetRef<DFPattern>(op)]);
+    }
+
     void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override {
       VisitDFPattern(op->pattern, graph_.node_map_[GetRef<DFPattern>(op)]);
     }
diff --git a/src/relay/transforms/simplify_expr.cc 
b/src/relay/transforms/simplify_expr.cc
index 079b867..cb42ab0 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -46,7 +46,7 @@ class SimplifyReshape {
     x_ = WildcardPattern(make_object<WildcardPatternNode>());
     auto reshape1 = AltPattern(ExprPattern(reshape_op), 
ExprPattern(reverse_reshape_op));
     auto reshape2 = AltPattern(ExprPattern(reshape_op), 
ExprPattern(reverse_reshape_op));
-    pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_}, Attrs{}, 
{})}, Attrs{}, {});
+    pattern_ = CallPattern(reshape1, {CallPattern(reshape2, {x_})});
   }
 
   Expr callback(const Expr& pre, const Expr& post, const Map<DFPattern, 
Array<Expr>>& node_map) {
diff --git a/tests/python/relay/test_dataflow_pattern.py 
b/tests/python/relay/test_dataflow_pattern.py
index d4c169b..d99e55b 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -62,6 +62,19 @@ def test_CallPattern():
     assert isinstance(c.args[1], WildcardPattern)
 
 
+def test_FunctionPattern():
+    wc1 = wildcard()
+    wc2 = wildcard()
+    c = is_op("add")(wc1, wc2)
+    f = FunctionPattern([wc1, wc2], c)
+    assert isinstance(f, FunctionPattern)
+    assert isinstance(f.params[0], WildcardPattern)
+    assert isinstance(f.params[1], WildcardPattern)
+    assert isinstance(f.body, CallPattern)
+    assert isinstance(f.body.args[0], WildcardPattern)
+    assert isinstance(f.body.args[1], WildcardPattern)
+
+
 def test_TuplePattern():
     wc1 = wildcard()
     wc2 = wildcard()
@@ -167,6 +180,24 @@ def test_no_match_call():
     assert not add_pattern.match(x - y)
 
 
+def test_match_func():
+    x = relay.var("x")
+    y = relay.var("y")
+    wc1 = wildcard()
+    wc2 = wildcard()
+    func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
+    assert func_pattern.match(relay.Function([x, y], x + y))
+
+
+def test_no_match_func():
+    x = relay.var("x")
+    y = relay.var("y")
+    wc1 = wildcard()
+    wc2 = wildcard()
+    func_pattern = FunctionPattern([wc1, wc2], wc1 + wc2)
+    assert not func_pattern.match(relay.Function([x, y], x - y))
+
+
 def test_match_option():
     x = relay.var("x")
     w = relay.var("w")
@@ -1300,6 +1331,36 @@ def test_partition_option():
     assert tvm.ir.structural_equal(func(x, w, b), pattern2.partition(relu))
 
 
+def test_partition_function():
+    x = relay.var("x")
+    w = relay.var("w")
+    b = relay.var("b")
+
+    x1 = relay.var("x1")
+    w1 = relay.var("w1")
+
+    wc_x = wildcard()
+    wc_w = wildcard()
+    wc_b = wildcard()
+    wc_x1 = wildcard()
+    wc_w1 = wildcard()
+
+    func_pattern = FunctionPattern([wc_x1, wc_w1], is_op("nn.conv2d")(wc_x1, 
wc_w1))
+    pattern = func_pattern(wc_x, wc_w) + wc_b
+
+    func = relay.Function([x1, w1], relay.nn.conv2d(x1, w1))
+    expr = func(x, w) + b + b
+
+    x2 = relay.var("x2")
+    w2 = relay.var("w2")
+    b2 = relay.var("b2")
+    func2 = relay.Function([x2, w2, b2], func(x2, w2) + b2).with_attr(
+        "PartitionedFromPattern", "nn.conv2d_FunctionCall_add_"
+    )
+    expr2 = func2(x, w, b) + b
+    assert tvm.ir.structural_equal(pattern.partition(expr), expr2)
+
+
 def test_match_match():
     add_pattern = is_op("add")(wildcard(), wildcard())
 

Reply via email to