This is an automated email from the ASF dual-hosted git repository.

yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new adf256df79 [Unity][Graph matching] Automatically add `used-by` constraints for `is_op` pattern (#14439)
adf256df79 is described below

commit adf256df791c263185ca376973d0a26337fdb061
Author: masahi <[email protected]>
AuthorDate: Fri Mar 31 07:22:04 2023 +0900

    [Unity][Graph matching] Automatically add `used-by` constraints for `is_op` pattern (#14439)
    
    * Automatically add used-by constraints for is_op pattern
    
    * [DOC fix] pass context -> constraint context
---
 include/tvm/relax/dataflow_matcher.h        |  9 ---------
 include/tvm/relax/dataflow_pattern.h        |  8 ++++----
 python/tvm/relax/dpl/pattern.py             | 18 +++++++++++++-----
 src/relax/ir/dataflow_pattern.cc            |  8 +++++---
 tests/python/relax/test_dataflow_pattern.py |  9 ---------
 5 files changed, 22 insertions(+), 30 deletions(-)

diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index fa58308fac..498f77a3f7 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -65,15 +65,6 @@ TVM_DLL tvm::runtime::Map<DFPattern, Var> MatchGraph(const 
PatternContext& ctx,
                                                      Optional<Var> start_hint 
= NullOpt,
                                                      bool must_include_hint = 
false);
 
-/**
- * \brief Match a graph-wise pattern with the current context 
(PatternContext::Current()).
- */
-inline tvm::runtime::Map<DFPattern, Var> MatchGraphDefault(const 
DataflowBlock& dfb,
-                                                           Optional<Var> 
start_hint = NullOpt,
-                                                           bool 
must_include_hint = false) {
-  return MatchGraph(PatternContext::Current(), dfb, start_hint, 
must_include_hint);
-}
-
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 144a7f45bf..e4c27f3558 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -245,15 +245,15 @@ class PatternContext : public ObjectRef {
     }
   }
 
-  /*! \brief Get the pass context object on the top of the stack */
-  TVM_DLL static PatternContext Current();
+  /*! \brief Get the constraint context object on the top of the stack */
+  TVM_DLL static Optional<PatternContext> Current();
 
   class Internal;
 
  private:
-  /*! \brief The RAII-like entry of a pass context scope */
+  /*! \brief The RAII-like entry of a constraint context scope */
   TVM_DLL void EnterWithScope();
-  /*! \brief The RAII-like exit of a pass context scope */
+  /*! \brief The RAII-like exit of a constraint context scope */
   TVM_DLL void ExitWithScope();
   friend class Internal;
   friend class With<PatternContext>;
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 248e957726..acabac2dcb 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -52,7 +52,7 @@ def register_df_node(type_key=None):
 class DFPattern(Node):
     """Base class of all Patterns."""
 
-    def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern":
+    def __call__(self, *args, varg_default_wildcard=False, 
add_constraint=True) -> "CallPattern":
         """
         Syntax sugar for creating a CallPattern with argument patterns
 
@@ -61,7 +61,7 @@ class DFPattern(Node):
         result: CallPattern
             The resulting CallPattern
         """
-        return CallPattern(self, args, varg_default_wildcard)
+        return CallPattern(self, args, varg_default_wildcard, add_constraint)
 
     def __or__(self, other: "DFPattern") -> "OrPattern":
         """
@@ -387,6 +387,9 @@ class CallPattern(DFPattern):
     varg_default_wildcard: bool
         If True, args can be fewer than actual provided arguments.
 
+    add_constraint: bool
+        If True, automatically add "used-by" constraints between caller and 
callee expressions.
+
     Note
     ----
     By setting varg_default_wildcard to True, we can only focus on the argument
@@ -400,11 +403,16 @@ class CallPattern(DFPattern):
         op: "DFPattern",
         args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]],
         varg_default_wildcard: bool = False,
+        add_constraint=True,
     ):
         self.__init_handle_by_constructor__(
             ffi.CallPattern, op, args, varg_default_wildcard  # type: ignore
         )
 
+        if add_constraint:
+            for i, arg in enumerate(args):
+                arg.used_by(self, i)
+
 
 @register_df_node
 class FunctionPattern(DFPattern):
@@ -835,7 +843,7 @@ def _is_call_tir(
     elif isinstance(args, (list, tuple)):
         args = TuplePattern(args)
 
-    return is_op("relax.call_tir")(func_pattern, args)
+    return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
 
 
 # Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
@@ -871,7 +879,7 @@ def _is_call_dps_packed(
     elif isinstance(args, (list, tuple)):
         args = TuplePattern(args)
 
-    return is_op("relax.call_dps_packed")(func_pattern, args)
+    return is_op("relax.call_dps_packed")(func_pattern, args, 
add_constraint=False)
 
 
 def is_call_dps_packed(
@@ -915,7 +923,7 @@ def is_call_packed(
         The resulting CallPattern
     """
     if args is None:
-        return ExternFuncPattern(func_name)(varg_default_wildcard=True)
+        return ExternFuncPattern(func_name)(varg_default_wildcard=True, 
add_constraint=False)
     return ExternFuncPattern(func_name)(*args)
 
 
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 5eb1bf3ea6..5580f6a1ab 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -394,8 +394,8 @@ std::stack<PatternContext>& pattern_ctx_stack() {
   return graph_pattern_managers;
 }
 
-PatternContext PatternContext::Current() {
-  ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found.";
+Optional<PatternContext> PatternContext::Current() {
+  if (pattern_ctx_stack().empty()) return NullOpt;
   return pattern_ctx_stack().top();
 }
 
@@ -419,7 +419,9 @@ void PatternContext::ExitWithScope() {
 }
 
 static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs, 
PairCons pcon) {
-  PatternContext::Current().add_constraint(lhs, rhs, pcon);
+  if (auto ctx = PatternContext::Current()) {
+    ctx.value().add_constraint(lhs, rhs, pcon);
+  }
 }
 
 TVM_REGISTER_NODE_TYPE(PatternSeqNode);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 9679e14fff..76bce47f7f 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1034,15 +1034,6 @@ def test_attention_qkv():
         matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
         matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
 
-        # TODO(masahi): Automate addition of used_by constraints during is_op
-        inp_pat.used_by(matmul1, 0)
-        inp_pat.used_by(matmul2, 0)
-        inp_pat.used_by(matmul3, 0)
-
-        Q_weight_pat.only_used_by(matmul1, 1)
-        K_weight_pat.only_used_by(matmul2, 1)
-        V_weight_pat.only_used_by(matmul3, 1)
-
         dfb = QKV_proj["main"].body.blocks[0]
         out = ctx.match_dfb(dfb)
 

Reply via email to