This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9eeb5bcbcb [Unity] Support Regular expression matching in globalvar dataflow pattern (#16085)
9eeb5bcbcb is described below

commit 9eeb5bcbcb94da8bbe87428274d1a0d6f7a655a7
Author: Hongyi Jin <[email protected]>
AuthorDate: Tue Nov 7 18:55:16 2023 -0800

    [Unity] Support Regular expression matching in globalvar dataflow pattern (#16085)
    
    * global var pattern support regex
    
    * format
---
 src/relax/ir/dataflow_matcher.cc            | 7 +++++--
 tests/python/relax/test_dataflow_pattern.py | 1 +
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index d1edb945ba..9524c90b57 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -36,6 +36,7 @@
 #include <cstddef>
 #include <limits>
 #include <optional>
+#include <regex>
 #include <type_traits>
 #include <unordered_map>
 #include <unordered_set>
@@ -555,8 +556,10 @@ bool DFPatternMatcher::VisitDFPattern_(const DataflowVarPatternNode* op, const E
 
 bool DFPatternMatcher::VisitDFPattern_(const GlobalVarPatternNode* op, const Expr& expr) {
   // GlobalVarPattern is not inherited from Var, so we need to handle it separately.
-  if (const auto* var_node = expr.as<GlobalVarNode>())
-    return "" == op->name_hint() || op->name_hint() == var_node->name_hint;
+  if (const auto* var_node = expr.as<GlobalVarNode>()) {
+    std::regex pat{std::string(op->name_hint())};
+    return "" == op->name_hint() || std::regex_search(std::string(var_node->name_hint), pat);
+  }
   return false;
 }
 
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 520fb87322..685a382ad7 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -97,6 +97,7 @@ def test_dataflow_var_pattern():
 
 def test_global_var_pattern():
     assert is_gv("x").match(rx.GlobalVar("x"))
+    assert is_gv("x.*").match(rx.GlobalVar("x_2"))
     assert is_gv().match(rx.GlobalVar("x"))
     assert not is_gv("x").match(rx.GlobalVar("y"))
     assert not is_gv("x").match(rx.Var("x"))

Reply via email to