This is an automated email from the ASF dual-hosted git repository.
yaxingcai pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new adf256df79 [Unity][Graph matching] Automatically add `used-by` constraints for `is_op` pattern (#14439)
adf256df79 is described below
commit adf256df791c263185ca376973d0a26337fdb061
Author: masahi <[email protected]>
AuthorDate: Fri Mar 31 07:22:04 2023 +0900
[Unity][Graph matching] Automatically add `used-by` constraints for `is_op` pattern (#14439)
* Automatically add used-by constraints for is_op pattern
* [DOC fix] pass context -> constraint context
---
include/tvm/relax/dataflow_matcher.h | 9 ---------
include/tvm/relax/dataflow_pattern.h | 8 ++++----
python/tvm/relax/dpl/pattern.py | 18 +++++++++++++-----
src/relax/ir/dataflow_pattern.cc | 8 +++++---
tests/python/relax/test_dataflow_pattern.py | 9 ---------
5 files changed, 22 insertions(+), 30 deletions(-)
diff --git a/include/tvm/relax/dataflow_matcher.h b/include/tvm/relax/dataflow_matcher.h
index fa58308fac..498f77a3f7 100644
--- a/include/tvm/relax/dataflow_matcher.h
+++ b/include/tvm/relax/dataflow_matcher.h
@@ -65,15 +65,6 @@ TVM_DLL tvm::runtime::Map<DFPattern, Var> MatchGraph(const PatternContext& ctx,
Optional<Var> start_hint = NullOpt,
bool must_include_hint = false);
-/**
- * \brief Match a graph-wise pattern with the current context (PatternContext::Current()).
- */
-inline tvm::runtime::Map<DFPattern, Var> MatchGraphDefault(const DataflowBlock& dfb,
- Optional<Var> start_hint = NullOpt,
- bool must_include_hint = false) {
- return MatchGraph(PatternContext::Current(), dfb, start_hint, must_include_hint);
-}
-
} // namespace relax
} // namespace tvm
diff --git a/include/tvm/relax/dataflow_pattern.h b/include/tvm/relax/dataflow_pattern.h
index 144a7f45bf..e4c27f3558 100644
--- a/include/tvm/relax/dataflow_pattern.h
+++ b/include/tvm/relax/dataflow_pattern.h
@@ -245,15 +245,15 @@ class PatternContext : public ObjectRef {
}
}
- /*! \brief Get the pass context object on the top of the stack */
- TVM_DLL static PatternContext Current();
+ /*! \brief Get the constraint context object on the top of the stack */
+ TVM_DLL static Optional<PatternContext> Current();
class Internal;
private:
- /*! \brief The RAII-like entry of a pass context scope */
+ /*! \brief The RAII-like entry of a constraint context scope */
TVM_DLL void EnterWithScope();
- /*! \brief The RAII-like exit of a pass context scope */
+ /*! \brief The RAII-like exit of a constraint context scope */
TVM_DLL void ExitWithScope();
friend class Internal;
friend class With<PatternContext>;
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 248e957726..acabac2dcb 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -52,7 +52,7 @@ def register_df_node(type_key=None):
class DFPattern(Node):
"""Base class of all Patterns."""
- def __call__(self, *args, varg_default_wildcard=False) -> "CallPattern":
+ def __call__(self, *args, varg_default_wildcard=False, add_constraint=True) -> "CallPattern":
"""
Syntax sugar for creating a CallPattern with argument patterns
@@ -61,7 +61,7 @@ class DFPattern(Node):
result: CallPattern
The resulting CallPattern
"""
- return CallPattern(self, args, varg_default_wildcard)
+ return CallPattern(self, args, varg_default_wildcard, add_constraint)
def __or__(self, other: "DFPattern") -> "OrPattern":
"""
@@ -387,6 +387,9 @@ class CallPattern(DFPattern):
varg_default_wildcard: bool
If True, args can be fewer than actual provided arguments.
+ add_constraint: bool
+ If True, automatically add "used-by" constraints between caller and callee expressions.
+
Note
----
By setting varg_default_wildcard to True, we can only focus on the argument
@@ -400,11 +403,16 @@ class CallPattern(DFPattern):
op: "DFPattern",
args: Union[List["DFPattern"], typing.Tuple["DFPattern", ...]],
varg_default_wildcard: bool = False,
+ add_constraint=True,
):
self.__init_handle_by_constructor__(
ffi.CallPattern, op, args, varg_default_wildcard # type: ignore
)
+ if add_constraint:
+ for i, arg in enumerate(args):
+ arg.used_by(self, i)
+
@register_df_node
class FunctionPattern(DFPattern):
@@ -835,7 +843,7 @@ def _is_call_tir(
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)
- return is_op("relax.call_tir")(func_pattern, args)
+ return is_op("relax.call_tir")(func_pattern, args, add_constraint=False)
# Todo(relax-team): Dataflow pattern for StructInfo, and match out_sinfo
@@ -871,7 +879,7 @@ def _is_call_dps_packed(
elif isinstance(args, (list, tuple)):
args = TuplePattern(args)
- return is_op("relax.call_dps_packed")(func_pattern, args)
+ return is_op("relax.call_dps_packed")(func_pattern, args, add_constraint=False)
def is_call_dps_packed(
@@ -915,7 +923,7 @@ def is_call_packed(
The resulting CallPattern
"""
if args is None:
- return ExternFuncPattern(func_name)(varg_default_wildcard=True)
+ return ExternFuncPattern(func_name)(varg_default_wildcard=True, add_constraint=False)
return ExternFuncPattern(func_name)(*args)
diff --git a/src/relax/ir/dataflow_pattern.cc b/src/relax/ir/dataflow_pattern.cc
index 5eb1bf3ea6..5580f6a1ab 100644
--- a/src/relax/ir/dataflow_pattern.cc
+++ b/src/relax/ir/dataflow_pattern.cc
@@ -394,8 +394,8 @@ std::stack<PatternContext>& pattern_ctx_stack() {
return graph_pattern_managers;
}
-PatternContext PatternContext::Current() {
- ICHECK(!pattern_ctx_stack().empty()) << "No active PatternContext found.";
+Optional<PatternContext> PatternContext::Current() {
+ if (pattern_ctx_stack().empty()) return NullOpt;
return pattern_ctx_stack().top();
}
@@ -419,7 +419,9 @@ void PatternContext::ExitWithScope() {
}
static void sync_graph_constraints(const DFPattern& lhs, const DFPattern& rhs,
PairCons pcon) {
- PatternContext::Current().add_constraint(lhs, rhs, pcon);
+ if (auto ctx = PatternContext::Current()) {
+ ctx.value().add_constraint(lhs, rhs, pcon);
+ }
}
TVM_REGISTER_NODE_TYPE(PatternSeqNode);
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index 9679e14fff..76bce47f7f 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -1034,15 +1034,6 @@ def test_attention_qkv():
matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
- # TODO(masahi): Automate addition of used_by constraints during is_op
- inp_pat.used_by(matmul1, 0)
- inp_pat.used_by(matmul2, 0)
- inp_pat.used_by(matmul3, 0)
-
- Q_weight_pat.only_used_by(matmul1, 1)
- K_weight_pat.only_used_by(matmul2, 1)
- V_weight_pat.only_used_by(matmul3, 1)
-
dfb = QKV_proj["main"].body.blocks[0]
out = ctx.match_dfb(dfb)