This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9eeb5bcbcb [Unity] Support Regular expression matching in globalvar dataflow pattern (#16085)
9eeb5bcbcb is described below
commit 9eeb5bcbcb94da8bbe87428274d1a0d6f7a655a7
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 7 18:55:16 2023 -0800
[Unity] Support Regular expression matching in globalvar dataflow pattern (#16085)
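* GlobalVarPattern: treat the pattern's name hint as a regular expression and
  test it against the GlobalVar's name_hint with std::regex_search (an empty
  name hint still matches any GlobalVar). Because regex_search finds a match
  anywhere in the name, a plain hint such as "x" now also matches names that
  merely contain it (e.g. "x_2"), and regex metacharacters such as "." in a
  hint are no longer matched literally.
* Code formatting fixes.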
---
src/relax/ir/dataflow_matcher.cc | 7 +++++--
tests/python/relax/test_dataflow_pattern.py | 1 +
2 files changed, 6 insertions(+), 2 deletions(-)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index d1edb945ba..9524c90b57 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -36,6 +36,7 @@
#include <cstddef>
#include <limits>
#include <optional>
+#include <regex>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
@@ -555,8 +556,10 @@ bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const E
bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const
Expr& expr) {
// GlobalVarPattern is not inherited from Var, so we need to handle it
separately.
- if (const auto* var_node = expr.as<GlobalVarNode>())
- return "" == op->name_hint() || op->name_hint() == var_node->name_hint;
+ if (const auto* var_node = expr.as<GlobalVarNode>()) {
+ std::regex pat{std::string(op->name_hint())};
+ return "" == op->name_hint() ||
std::regex_search(std::string(var_node->name_hint), pat);
+ }
return false;
}
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 520fb87322..685a382ad7 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -97,6 +97,7 @@ def test_dataflow_var_pattern():
def test_global_var_pattern():
assert is_gv("x").match(rx.GlobalVar("x"))
+ assert is_gv("x.*").match(rx.GlobalVar("x_2"))
assert is_gv().match(rx.GlobalVar("x"))
assert not is_gv("x").match(rx.GlobalVar("y"))
assert not is_gv("x").match(rx.Var("x"))