This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9bcf0bc107 [Relay] add redirecting operation to dataflow pattern graph (#15392)
9bcf0bc107 is described below
commit 9bcf0bc107b6aacf901617a74595b34c0fa7c0df
Author: 电线杆 <[email protected]>
AuthorDate: Sat Aug 5 06:27:52 2023 +0800
[Relay] add redirecting operation to dataflow pattern graph (#15392)
* Add redirecting operation to dataflow pattern graph
* Lint
---
include/tvm/relay/dataflow_pattern.h | 6 +++
python/tvm/relay/dataflow_pattern/__init__.py | 13 +++++++
src/relay/ir/dataflow_matcher.cc | 6 ++-
src/relay/ir/dataflow_pattern.cc | 10 +++++
src/relay/ir/dataflow_pattern_functor.cc | 6 ++-
src/relay/ir/indexed_graph.cc | 7 +++-
tests/python/relay/test_dataflow_pattern.py | 56 +++++++++++++++++++++++++++
7 files changed, 101 insertions(+), 3 deletions(-)
diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h
index 8c30a0df9f..040372db35 100644
--- a/include/tvm/relay/dataflow_pattern.h
+++ b/include/tvm/relay/dataflow_pattern.h
@@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}
+  /*! \brief If set, the wildcard is redirected: instead of matching any expression it matches
+   * only what this pattern matches. Unset (nullptr) for an ordinary wildcard. */
+ Optional<DFPattern> pattern{nullptr};
+
static constexpr const char* _type_key =
"relay.dataflow_pattern.WildcardPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
};
@@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode {
class WildcardPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);
+
+  /*!
+   * \brief Redirect this wildcard to another pattern.
+   *
+   * The shared WildcardPatternNode is updated in place (which is why the method is `const` on
+   * the reference), so patterns built from this wildcard before the call observe the new
+   * target. Once redirected, the wildcard matches only what \p pat matches; \p pat may itself
+   * contain this wildcard, which is how recursive patterns are expressed.
+   *
+   * NOTE(review): a recursive redirect makes the pattern nodes reference each other; confirm
+   * whether such a pattern graph is ever released.
+   *
+   * \param pat The pattern the wildcard is redirected to.
+   */
+  void redirect_to(DFPattern pat) const;
};
class TypePattern;
diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py
index 96950a2e47..76a24c048c 100644
--- a/python/tvm/relay/dataflow_pattern/__init__.py
+++ b/python/tvm/relay/dataflow_pattern/__init__.py
@@ -722,6 +722,19 @@ class WildcardPattern(DFPattern):
def __init__(self):
self.__init_handle_by_constructor__(ffi.WildcardPattern)
+ def redirect_to(
+ self,
+ pat: "DFPattern",
+ ):
+ """Redirect the WildcardPattern to another pattern
+
+ Parameters
+ ----------
+ pat: relay.dataflow_pattern.DFPattern
+ The pattern that wildcard is redirected to.
+ """
+ ffi.WildcardPattern_redirect_to(self, pat)
+
@register_df_node
class TypePattern(DFPattern):
diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc
index 249f4ccf7a..ee585446cb 100644
--- a/src/relay/ir/dataflow_matcher.cc
+++ b/src/relay/ir/dataflow_matcher.cc
@@ -488,7 +488,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr
}
bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
- return true;
+  // An ordinary wildcard matches any expression; a redirected one matches only what its
+  // target pattern matches.
+  return !op->pattern || VisitDFPattern(op->pattern.value(), expr);
}
bool MatchPattern(DFPattern pattern, Expr expr) {
diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc
index c141ca51ef..637cb0665d 100644
--- a/src/relay/ir/dataflow_pattern.cc
+++ b/src/relay/ir/dataflow_pattern.cc
@@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
<< ")";
});
+void WildcardPattern::redirect_to(DFPattern pat) const {
+  // The reference is const but the shared node is mutated in place, so every pattern that
+  // already refers to this wildcard observes the new target.
+  WildcardPatternNode* ptr = static_cast<WildcardPatternNode*>(get_mutable());
+  ICHECK(ptr != nullptr) << "Cannot redirect an undefined WildcardPattern";
+  // Redirecting a wildcard straight to itself would make the matcher revisit the same
+  // pattern on the same expression without ever terminating.
+  ICHECK(!pat.same_as(*this)) << "Cannot redirect a WildcardPattern to itself";
+  ptr->pattern = pat;
+}
+
TVM_REGISTER_NODE_TYPE(WildcardPatternNode);
+TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to")
+    .set_body_typed([](WildcardPattern wc, DFPattern target) {
+      // Packed-function entry point called by the Python WildcardPattern.redirect_to.
+      wc.redirect_to(target);
+    });
+
TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]()
{
auto w = WildcardPattern(make_object<WildcardPatternNode>());
return w;
diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc
index 290f72df1d..76b3fe068e 100644
--- a/src/relay/ir/dataflow_pattern_functor.cc
+++ b/src/relay/ir/dataflow_pattern_functor.cc
@@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}
void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}
-void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
+void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {
+  // A plain wildcard has nothing to traverse; a redirected one continues into its target.
+  if (!op->pattern) return;
+  VisitDFPattern(op->pattern.value());
+}
} // namespace relay
} // namespace tvm
diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc
index 044884f87e..f10920769d 100644
--- a/src/relay/ir/indexed_graph.cc
+++ b/src/relay/ir/indexed_graph.cc
@@ -537,7 +537,12 @@ std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pat
void VisitDFPattern_(const VarPatternNode* op) override {}
- void VisitDFPattern_(const WildcardPatternNode* op) override {}
+ void VisitDFPattern_(const WildcardPatternNode* op) override {
+   // Only a redirected wildcard has an edge to record: its target pattern is registered via
+   // AddOutput against this wildcard's graph node.
+   if (!op->pattern) return;
+   AddOutput(op->pattern.value(), graph_->item_to_node(GetRef<WildcardPattern>(op)));
+ }
std::unique_ptr<IndexedGraph<DFPattern>> graph_;
};
diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py
index c4a83735ce..3950c02c08 100644
--- a/tests/python/relay/test_dataflow_pattern.py
+++ b/tests/python/relay/test_dataflow_pattern.py
@@ -1995,5 +1995,61 @@ def test_partition_parallel_branch_with_same_input():
assert tvm.ir.structural_equal(partitioned, reference)
+def test_rewrite_with_pattern_recursion():
+ data = relay.var("data", relay.TensorType((2, 8), "float32"))
+ dense_weight = relay.const(np.zeros((4, 8)))
+ feat = relay.nn.dense(data, dense_weight)
+ feat = relay.cast(feat, "float32")
+ feat = relay.cast(feat, "float32")
+ feat = relay.cast(feat, "float32")
+ feat = relay.cast(feat, "float32")
+ feat = relay.cast(feat, "float32")
+ oup = relay.cast(feat, "float32")
+
+ expected = relay.nn.relu(oup)
+
+ class TheRewrite(DFPatternCallback):
+ def __init__(self, pattern):
+ super(TheRewrite, self).__init__(rewrite_once=True)
+ self.pattern = pattern
+
+ def callback(self, pre, post, node_map):
+ return relay.nn.relu(post)
+
+ def test_reset_call_args():
+ dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+ wildcard_redirect = wildcard()
+ the_pattern = is_op("cast")(wildcard_redirect)
+ the_pattern2 = the_pattern | dense_pattern
+ wildcard_redirect.redirect_to(the_pattern2)
+
+ actual = rewrite(TheRewrite(the_pattern), oup)
+ tvm.ir.assert_structural_equal(actual, expected)
+
+ def test_reset_alt_left():
+ dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+ wildcard_redirect = wildcard()
+ or_pattern = wildcard_redirect | dense_pattern
+ the_pattern = is_op("cast")(or_pattern)
+ wildcard_redirect.redirect_to(the_pattern)
+
+ actual = rewrite(TheRewrite(the_pattern), oup)
+ tvm.ir.assert_structural_equal(actual, expected)
+
+ def test_reset_alt_right():
+ dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
+ wildcard_redirect = wildcard()
+ or_pattern = dense_pattern | wildcard_redirect
+ the_pattern = is_op("cast")(or_pattern)
+ wildcard_redirect.redirect_to(the_pattern)
+
+ actual = rewrite(TheRewrite(the_pattern), oup)
+ tvm.ir.assert_structural_equal(actual, expected)
+
+ test_reset_call_args()
+ test_reset_alt_left()
+ test_reset_alt_right()
+
+
if __name__ == "__main__":
tvm.testing.main()