This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 92593cd11f Add pattern-based dataflow block rewriting
92593cd11f is described below
commit 92593cd11f74b0e8e872926907d634978148dda5
Author: Masahiro Masuda <[email protected]>
AuthorDate: Sat Apr 1 04:13:15 2023 +0900
Add pattern-based dataflow block rewriting
---
python/tvm/relax/dpl/pattern.py | 34 ++++++-
src/relax/ir/dataflow_matcher.cc | 113 ++++++++++++++++++--
tests/python/relax/test_dataflow_pattern.py | 153 +++++++++++++++++++++++++++-
3 files changed, 285 insertions(+), 15 deletions(-)
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index acabac2dcb..3026213ba2 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -1125,7 +1125,7 @@ def make_fused_bias_activation_pattern(op_name,
with_bias=False, activation=None
return out
-def rewrite(
+def rewrite_call(
pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]],
Expr], func: Function
) -> Function:
"""
@@ -1158,4 +1158,34 @@ def rewrite(
rewritten_func: Function
The rewritten or the input function, depending on the pattern matching
result.
"""
- return ffi.rewrite(pattern, rewriter, func)
+ return ffi.rewrite_call(pattern, rewriter, func)
+
+
+def rewrite_bindings(
+ ctx, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func:
Function
+) -> Function:
+ """
+ Rewrite a function with the given pattern and the rewriter function.
+ Parameters
+ ----------
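+ ctx: PatternContext
+ The pattern constraint context under which each dataflow block is matched.
+ rewriter: Callable[[Dict[DFPattern, Var]], Dict[Var, Expr]]
+ The function to be called on a successful matching for rewriting. Given the map of patterns
+ and corresponding variables (bound variables or parameters), it should return a map that
+ specifies new values for matched bound variables.
+ For example, to rewrite three matmuls that share the same input into one matmul followed by
+ slicing, we can write the patterns and the rewriter as follows:
+ ```
+ with PatternContext() as ctx:
+     inp_pat = wildcard()
+     Q_weight_pat, K_weight_pat, V_weight_pat = wildcard(), wildcard(), wildcard()
+     matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+     matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+     matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+ def rewriter(matchings):
+     inp = matchings[inp_pat]
+     Q_weight = matchings[Q_weight_pat]
+     K_weight = matchings[K_weight_pat]
+     V_weight = matchings[V_weight_pat]
+     width = Q_weight.struct_info.shape[1]
+     concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
+     matmul = R.matmul(inp, concat)
+     Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
+     K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
+     V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width * 3])
+     return {matchings[matmul1]: Q, matchings[matmul2]: K, matchings[matmul3]: V}
+ ```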
+ func: Function
+ The function to rewrite.
+ Returns
+ -------
+ rewritten_func: Function
+ The rewritten or the input function, depending on the pattern matching result.
+ """
+ return ffi.rewrite_bindings(ctx, rewriter, func)
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index 88381d6e26..0055929a78 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -780,18 +780,29 @@ Optional<Map<DFPattern, Var>> MatchGraph(const
PatternContext& ctx, const Datafl
TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
/*!
- * \brief Apply pattern matching to each call node and replace matching ones with the output of
- * a user-provided rewriter function.
+ * \brief Apply pattern matching to each call node and dataflow block, and replace matching ones
+ * with the output of a user-provided rewriter function.
*/
class PatternRewriter : ExprMutator {
public:
+ using ExprMutator::VisitBindingBlock_;
using ExprMutator::VisitExpr_;
- PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
- : pattern_(pat), rewriter_func_(rewriter_func) {}
+ PatternRewriter(DFPattern pat, PackedFunc rewriter_func,
+ const std::unordered_set<const VarNode*>& params)
+ : pattern_(pat), rewriter_func_(rewriter_func), params_(params) {}
- static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
- PatternRewriter rewriter(pat, rewriter_func);
+ PatternRewriter(const PatternContext& ctx, PackedFunc rewriter_func,
+ const std::unordered_set<const VarNode*>& params)
+ : ctx_(ctx), rewriter_func_(rewriter_func), params_(params) {}
+
+ template <typename PatternType>
+ static Expr Run(PatternType pat, PackedFunc rewriter_func, Function f) {
+ std::unordered_set<const VarNode*> params;
+ for (const auto& p : f->params) {
+ params.insert(p.get());
+ }
+ PatternRewriter rewriter(pat, rewriter_func, params);
return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
}
@@ -807,7 +818,9 @@ class PatternRewriter : ExprMutator {
Expr VisitExpr_(const CallNode* call_node) final {
auto call = ExprMutator::VisitExpr_(call_node);
- if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+ if (!pattern_) {
+ return call;
+ } else if (auto matches_opt = ExtractMatchedExpr(pattern_.value(), call,
bindings_)) {
auto rewriten_expr = rewriter_func_(call, matches_opt.value());
memo_[call_node] = rewriten_expr;
return rewriten_expr;
@@ -815,17 +828,99 @@ class PatternRewriter : ExprMutator {
return call;
}
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* block_node) final {
+ if (!ctx_) {
+ return ExprMutator::VisitBindingBlock_(block_node);
+ }
+ return RewriteDataflowBlockFixedPoint(GetRef<DataflowBlock>(block_node));
+ }
+
private:
- DFPattern pattern_;
+ void EmitUsedVars(Expr val, const Array<Binding>& pending_bindings,
+ std::unordered_set<const VarNode*>* emitted_vars) {
+ std::unordered_set<const VarNode*> unemitted_vars;
+ PostOrderVisit(val, [=, &unemitted_vars](Expr e) {
+ if (auto v = e.as<VarNode>(); v && !emitted_vars->count(v)) {
+ unemitted_vars.insert(v);
+ }
+ });
+
+ if (unemitted_vars.empty()) {
+ return;
+ }
+
+ size_t num_unemitted = unemitted_vars.size();
+ for (const auto& binding : pending_bindings) {
+ if (auto var_bind = binding.as<VarBindingNode>();
+ var_bind && unemitted_vars.count(var_bind->var.get())) {
+ EmitUsedVars(var_bind->value, pending_bindings, emitted_vars);
+ this->VisitBinding(binding);
+ emitted_vars->insert(var_bind->var.get());
+ if (--num_unemitted == 0) {
+ return;
+ }
+ }
+ }
+ }
+
+ // Repeat until all matchable subsets of bindings are rewritten.
+ BindingBlock RewriteDataflowBlockFixedPoint(BindingBlock block) {
+ if (auto matches = MatchGraph(ctx_.value(),
Downcast<DataflowBlock>(block))) {
+ builder_->BeginDataflowBlock();
+ Map<Var, Expr> replacements = rewriter_func_(matches.value());
+
+ std::unordered_set<const VarNode*> emitted_vars;
+
+ for (size_t i = 0; i < block->bindings.size(); ++i) {
+ const auto& binding = block->bindings[i];
+ if (auto var_bind = binding.as<VarBindingNode>()) {
+ if (replacements.count(var_bind->var)) {
+ auto new_val = replacements[var_bind->var];
+ Array<Binding> pending_bindings(block->bindings.begin() + i + 1,
block->bindings.end());
+ // Make sure there is no unbound variable used in the new value before it is emitted
+ EmitUsedVars(new_val, pending_bindings, &emitted_vars);
+ this->ReEmitBinding(var_bind, builder_->Normalize(new_val));
+ } else if (!emitted_vars.count(var_bind->var.get())) {
+ this->VisitBinding(binding);
+ emitted_vars.insert(var_bind->var.get());
+ }
+ } else {
+ this->VisitBinding(binding);
+ }
+ }
+ return RewriteDataflowBlockFixedPoint(builder_->EndBlock());
+ }
+ return block;
+ }
+
+ /*! \brief The pattern for rewriting call nodes */
+ Optional<DFPattern> pattern_;
+ /*! \brief The pattern constraint contexts for rewriting dataflow blocks */
+ Optional<PatternContext> ctx_;
+ /*!
+ * \brief The user-provided rewriter function. Its signature and semantics are:
+ * - (Call, Map<DFPattern, Expr>) -> Call for call node rewriting. Given the matched
+ * call node and the map of patterns and matched expressions, it should return a new call node
+ * to replace the original one or the original matched call node as is.
+ * - Map<DFPattern, Var> -> Map<Var, Expr> for dataflow block rewriting. Given the map of patterns
+ * and corresponding variables (bound variables or parameters), it should return a map that
+ * specifies new values for matched bound variables.
+ */
PackedFunc rewriter_func_;
+ std::unordered_set<const VarNode*> params_;
Map<Var, Expr> bindings_;
std::unordered_map<const Object*, Expr> memo_;
};
-TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_call")
.set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
return PatternRewriter::Run(pat, rewriter, f);
});
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite_bindings")
+ .set_body_typed([](const PatternContext& ctx, PackedFunc rewriter,
Function f) {
+ return PatternRewriter::Run(ctx, rewriter, f);
+ });
+
} // namespace relax
} // namespace tvm
diff --git a/tests/python/relax/test_dataflow_pattern.py
b/tests/python/relax/test_dataflow_pattern.py
index f18244096e..e4d7f7972c 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -918,7 +918,7 @@ def test_rewrite_simple():
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(2, "float32"))
- rewritten = rewrite(pattern, rewriter, main)
+ rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected1)
add1 = is_op("relax.add")(x, x)
@@ -927,14 +927,14 @@ def test_rewrite_simple():
def rewriter(_, matchings):
return R.multiply(matchings[x], R.const(4, "float32"))
- rewritten = rewrite(pattern, rewriter, main)
+ rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected2)
# No rewriting, return the original call node as is
def rewriter(orig, _):
return orig
- rewritten = rewrite(pattern, rewriter, main)
+ rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, main)
@@ -1002,7 +1002,7 @@ def test_rewrite_attention():
def rewriter(_, matchings):
return R.nn.attention(matchings[Q], matchings[K], matchings[V])
- rewritten = rewrite(pattern, rewriter, main)
+ rewritten = rewrite_call(pattern, rewriter, main)
tvm.ir.assert_structural_equal(rewritten, expected)
@@ -1075,5 +1075,150 @@ def test_attention_fake_qkv():
assert ctx.match_dfb(dfb) is None
+def get_qkv_proj_rewriter(
+ inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1, matmul2,
matmul3
+):
+ def qkv_proj_rewriter(matchings):
+ inp = matchings[inp_pat]
+ Q_weight = matchings[Q_weight_pat]
+ K_weight = matchings[K_weight_pat]
+ V_weight = matchings[V_weight_pat]
+ width = Q_weight.struct_info.shape[1]
+
+ concat = R.concat([Q_weight, K_weight, V_weight], axis=1)
+ matmul = R.matmul(inp, concat)
+ Q = R.strided_slice(matmul, axes=[2], begin=[0], end=[width])
+ K = R.strided_slice(matmul, axes=[2], begin=[width], end=[width * 2])
+ V = R.strided_slice(matmul, axes=[2], begin=[width * 2], end=[width *
3])
+
+ return {matchings[matmul1]: Q, matchings[matmul2]: K,
matchings[matmul3]: V}
+
+ return qkv_proj_rewriter
+
+
+def test_combine_matmul_twice():
+ @R.function
+ def qkv_x2(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ w5: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv0 = R.matmul(x1, w0)
+ lv1 = R.matmul(x1, w1)
+ lv2 = R.matmul(x1, w2)
+ lv3 = R.matmul(x2, w3)
+ lv4 = R.matmul(x2, w4)
+ lv5 = R.matmul(x2, w5)
+ out = (lv0, lv1, lv2, lv3, lv4, lv5)
+ R.output(out)
+ return out
+
+ @R.function
+ def expected(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ x2: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ w3: R.Tensor((640, 640), "float32"),
+ w4: R.Tensor((640, 640), "float32"),
+ w5: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ lv = R.concat((w0, w1, w2), axis=1)
+ lv1 = R.matmul(x1, lv)
+ lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+ lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+ lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+ lv2_1 = R.concat((w3, w4, w5), axis=1)
+ lv3 = R.matmul(x2, lv2_1, out_dtype="void")
+ lv3_1 = R.strided_slice(lv3, axes=[2], begin=[0], end=[640])
+ lv4 = R.strided_slice(lv3, axes=[2], begin=[640], end=[1280])
+ lv5 = R.strided_slice(lv3, axes=[2], begin=[1280], end=[1920])
+ out = lv0, lv1_1, lv2, lv3_1, lv4, lv5
+ R.output(out)
+ return out
+
+ with PatternContext() as ctx:
+ inp_pat = wildcard()
+ Q_weight_pat = wildcard()
+ K_weight_pat = wildcard()
+ V_weight_pat = wildcard()
+
+ matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+ matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+ matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+ rewriter = get_qkv_proj_rewriter(
+ inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
+ )
+ rewritten = rewrite_bindings(ctx, rewriter, qkv_x2)
+ tvm.ir.assert_structural_equal(rewritten, expected)
+
+
+def test_combine_matmul_emit_order():
+ @R.function
+ def main(
+ x1: R.Tensor((2, 1024, 640), "float32"),
+ w0: R.Tensor((640, 640), "float32"),
+ w1: R.Tensor((640, 640), "float32"),
+ w2: R.Tensor((640, 640), "float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ w0_t = R.permute_dims(w0, axes=None)
+ lv0 = R.matmul(x1, w0_t)
+ w1_t = R.permute_dims(w1, axes=None)
+ w1_t_t = R.permute_dims(w1_t, axes=None)
+ lv1 = R.matmul(x1, w1_t_t)
+ w2_t = R.permute_dims(w2, axes=None)
+ lv2 = R.matmul(x1, w2_t)
+ out = (lv0, lv1, lv2)
+ R.output(out)
+ return out
+
+ @R.function
+ def expected(
+ x1: R.Tensor((2, 1024, 640), dtype="float32"),
+ w0: R.Tensor((640, 640), dtype="float32"),
+ w1: R.Tensor((640, 640), dtype="float32"),
+ w2: R.Tensor((640, 640), dtype="float32"),
+ ) -> R.Tensor:
+ with R.dataflow():
+ w0_t = R.permute_dims(w0, axes=None)
+ w1_t = R.permute_dims(w1, axes=None)
+ w1_t_t = R.permute_dims(w1_t, axes=None)
+ w2_t = R.permute_dims(w2, axes=None)
+ lv = R.concat((w0_t, w1_t_t, w2_t), axis=1)
+ lv1 = R.matmul(x1, lv, out_dtype="void")
+ lv0 = R.strided_slice(lv1, axes=[2], begin=[0], end=[640])
+ lv1_1 = R.strided_slice(lv1, axes=[2], begin=[640], end=[1280])
+ lv2 = R.strided_slice(lv1, axes=[2], begin=[1280], end=[1920])
+ out = lv0, lv1_1, lv2
+ R.output(out)
+ return out
+
+ with PatternContext() as ctx:
+ inp_pat = wildcard()
+ Q_weight_pat = wildcard()
+ K_weight_pat = wildcard()
+ V_weight_pat = wildcard()
+
+ matmul1 = is_op("relax.matmul")(inp_pat, Q_weight_pat)
+ matmul2 = is_op("relax.matmul")(inp_pat, K_weight_pat)
+ matmul3 = is_op("relax.matmul")(inp_pat, V_weight_pat)
+
+ rewriter = get_qkv_proj_rewriter(
+ inp_pat, Q_weight_pat, K_weight_pat, V_weight_pat, matmul1,
matmul2, matmul3
+ )
+ rewritten = rewrite_bindings(ctx, rewriter, main)
+ tvm.ir.assert_structural_equal(rewritten, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()