This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9bcf0bc107 [Relay] add redirecting operation to dataflow pattern graph (#15392)
9bcf0bc107 is described below

commit 9bcf0bc107b6aacf901617a74595b34c0fa7c0df
Author: 电线杆 <[email protected]>
AuthorDate: Sat Aug 5 06:27:52 2023 +0800

    [Relay] add redirecting operation to dataflow pattern graph (#15392)
    
    * Add redirecting operation to dataflow pattern graph
    
    * Lint
---
 include/tvm/relay/dataflow_pattern.h          |  6 +++
 python/tvm/relay/dataflow_pattern/__init__.py | 13 +++++++
 src/relay/ir/dataflow_matcher.cc              |  6 ++-
 src/relay/ir/dataflow_pattern.cc              | 10 +++++
 src/relay/ir/dataflow_pattern_functor.cc      |  6 ++-
 src/relay/ir/indexed_graph.cc                 |  7 +++-
 tests/python/relay/test_dataflow_pattern.py   | 56 +++++++++++++++++++++++++++
 7 files changed, 101 insertions(+), 3 deletions(-)

diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h
index 8c30a0df9f..040372db35 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode {
  public:
   void VisitAttrs(tvm::AttrVisitor* v) {}
 
+  /*! \brief If the wildcard is redirected, then pattern is not nullptr, and the wildcard
+   * redirects to the pattern. */
+  Optional<DFPattern> pattern{nullptr};
+
   static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern";
   TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
 };
@@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode {
 class WildcardPattern : public DFPattern {
  public:
   TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
+
+  void redirect_to(DFPattern pat) const;
 };
 
 class TypePattern;
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index 96950a2e47..76a24c048c 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -722,6 +722,19 @@ class WildcardPattern(DFPattern):
     def __init__(self):
         self.__init_handle_by_constructor__(ffi.WildcardPattern)
 
+    def redirect_to(
+        self,
+        pat: "DFPattern",
+    ):
+        """Redirect the WildcardPattern to another pattern
+
+        Parameters
+        ----------
+        pat: relay.dataflow_pattern.DFPattern
+            The pattern that wildcard is redirected to.
+        """
+        ffi.WildcardPattern_redirect_to(self, pat)
+
 
 @register_df_node
 class TypePattern(DFPattern):
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 249f4ccf7a..ee585446cb 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -488,7 +488,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr
 }
 
 bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
-  return true;
+  if (op->pattern) {
+    return VisitDFPattern(op->pattern.value(), expr);
+  } else {
+    return true;
+  }
 }
 
 bool MatchPattern(DFPattern pattern, Expr expr) {
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index c141ca51ef..637cb0665d 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
                        << ")";
     });
 
+void WildcardPattern::redirect_to(DFPattern pat) const {
+  WildcardPatternNode* ptr = static_cast<WildcardPatternNode*>(get_mutable());
+  ptr->pattern = pat;
+}
+
 TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
 
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to")
+    .set_body_typed([](WildcardPattern wildcard, DFPattern pat) {
+      return wildcard.redirect_to(pat);
+    });
+
 TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
   auto w = WildcardPattern(make_object<WildcardPatternNode>());
   return w;
diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc
index 290f72df1d..76b3fe068e 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
 
 void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}
 
-void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
+void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {
+  if (op->pattern) {
+    VisitDFPattern(op->pattern.value());
+  }
+}
 
 }  // namespace relay
 }  // namespace tvm
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 044884f87e..f10920769d 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -537,7 +537,12 @@ std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pat
 
     void VisitDFPattern_(const VarPatternNode* op) override {}
 
-    void VisitDFPattern_(const WildcardPatternNode* op) override {}
+    void VisitDFPattern_(const WildcardPatternNode* op) override {
+      if (op->pattern) {
+        auto node = graph_->item_to_node(GetRef<WildcardPattern>(op));
+        AddOutput(op->pattern.value(), node);
+      }
+    }
 
     std::unique_ptr<IndexedGraph<DFPattern>> graph_;
   };
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index c4a83735ce..3950c02c08 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -1995,5 +1995,61 @@ def test_partition_parallel_branch_with_same_input():
     assert tvm.ir.structural_equal(partitioned, reference)
 
 
+def test_rewrite_with_pattern_recursion():
+    data = relay.var("data", relay.TensorType((2, 8), "float32"))
+    dense_weight = relay.const(np.zeros((4, 8)))
+    feat = relay.nn.dense(data, dense_weight)
+    feat = relay.cast(feat, "float32")
+    feat = relay.cast(feat, "float32")
+    feat = relay.cast(feat, "float32")
+    feat = relay.cast(feat, "float32")
+    feat = relay.cast(feat, "float32")
+    oup = relay.cast(feat, "float32")
+
+    expected = relay.nn.relu(oup)
+
+    class TheRewrite(DFPatternCallback):
+        def __init__(self, pattern):
+            super(TheRewrite, self).__init__(rewrite_once=True)
+            self.pattern = pattern
+
+        def callback(self, pre, post, node_map):
+            return relay.nn.relu(post)
+
+    def test_reset_call_args():
+        dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+        wildcard_redirect = wildcard()
+        the_pattern = is_op("cast")(wildcard_redirect)
+        the_pattern2 = the_pattern | dense_pattern
+        wildcard_redirect.redirect_to(the_pattern2)
+
+        actual = rewrite(TheRewrite(the_pattern), oup)
+        tvm.ir.assert_structural_equal(actual, expected)
+
+    def test_reset_alt_left():
+        dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+        wildcard_redirect = wildcard()
+        or_pattern = wildcard_redirect | dense_pattern
+        the_pattern = is_op("cast")(or_pattern)
+        wildcard_redirect.redirect_to(the_pattern)
+
+        actual = rewrite(TheRewrite(the_pattern), oup)
+        tvm.ir.assert_structural_equal(actual, expected)
+
+    def test_reset_alt_right():
+        dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+        wildcard_redirect = wildcard()
+        or_pattern = dense_pattern | wildcard_redirect
+        the_pattern = is_op("cast")(or_pattern)
+        wildcard_redirect.redirect_to(the_pattern)
+
+        actual = rewrite(TheRewrite(the_pattern), oup)
+        tvm.ir.assert_structural_equal(actual, expected)
+
+    test_reset_call_args()
+    test_reset_alt_left()
+    test_reset_alt_right()
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to