This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new cb6efae413 [Unity] Support pattern-based rewriting (#14312)
cb6efae413 is described below
commit cb6efae413a1068efbb1b3509c3198c96cade138
Author: masahi <[email protected]>
AuthorDate: Fri Mar 17 04:20:51 2023 +0900
[Unity] Support pattern-based rewriting (#14312)
* stub
* wip
* works
* restore binding
* attention test work
* use RemoveAllUnused
* simplified callback api
* pass original call node to callback
* clean test
* add doc
* add test for the case where the original call is returned
* callback -> rewriter and other doc improvement
---
python/tvm/relax/dpl/pattern.py | 40 +++++++++-
src/relax/ir/dataflow_matcher.cc | 48 +++++++++++
tests/python/relax/test_dataflow_pattern.py | 118 ++++++++++++++++++++++++++++
3 files changed, 204 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/dpl/pattern.py b/python/tvm/relax/dpl/pattern.py
index 1ca41b378d..248e957726 100644
--- a/python/tvm/relax/dpl/pattern.py
+++ b/python/tvm/relax/dpl/pattern.py
@@ -20,7 +20,7 @@
# pylint: disable=pointless-statement
import typing
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Tuple, Union, Callable
import tvm
import tvm._ffi as tvm_ffi
@@ -31,7 +31,7 @@ from tvm.relay.op import get
from ...ir import make_node
from ...ir.base import Node
from ...runtime import Object
-from ..expr import Expr, Var
+from ..expr import Expr, Var, Function
from . import _ffi as ffi
@@ -1115,3 +1115,39 @@ def make_fused_bias_activation_pattern(op_name, with_bias=False, activation=None
return is_op(activation)(out)
return out
+
+
+def rewrite(
+    pattern: DFPattern, rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr], func: Function
+) -> Function:
+ """
+ Rewrite a function with the given pattern and the rewriter function.
+
+ Parameters
+ ----------
+ pattern: DFPattern
+ The pattern to match.
+
+    rewriter: Callable[[Expr, Dict[DFPattern, Expr]], Expr]
+        The function to be called on a successful match to perform the rewriting.
+        Given the matched call node and the map from patterns to matched expressions,
+        it should return a new call node to replace the original one, or the
+        original matched call node as is.
+
+        For example, to replace x + x with 2 * x, we can write the rewriter as follows:
+ ```
+ x = wildcard()
+ pattern = is_op("relax.add")(x, x)
+
+ def rewriter(orig, matchings):
+ return R.multiply(matchings[x], R.const(2, "float32"))
+ ```
+
+ func: Function
+ The function to rewrite.
+
+ Returns
+ -------
+ rewritten_func: Function
+        The rewritten or the input function, depending on the pattern matching result.
+ """
+ return ffi.rewrite(pattern, rewriter, func)
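For context, here is how the new entry point fits together end to end. This is a
minimal sketch distilled from test_rewrite_simple below, assuming `rewrite`,
`wildcard`, and `is_op` are all re-exported from tvm.relax.dpl:
```
import tvm
from tvm.script import relax as R
from tvm.relax.dpl import wildcard, is_op, rewrite

@R.function
def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16), "float32"):
    with R.dataflow():
        x2 = R.add(x, x)
        R.output(x2)
    return x2

# Match x + x anywhere in the function.
x = wildcard()
pattern = is_op("relax.add")(x, x)

def rewriter(orig, matchings):
    # Return a new call to replace the match, or `orig` to leave it untouched.
    return R.multiply(matchings[x], R.const(2, "float32"))

rewritten = rewrite(pattern, rewriter, main)  # x2 becomes R.multiply(x, 2.0)
```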
diff --git a/src/relax/ir/dataflow_matcher.cc b/src/relax/ir/dataflow_matcher.cc
index da8c6ce2da..c6d705b5b4 100644
--- a/src/relax/ir/dataflow_matcher.cc
+++ b/src/relax/ir/dataflow_matcher.cc
@@ -766,5 +766,53 @@ Map<DFPattern, Var> MatchGraph(const PatternContext& ctx, const DataflowBlock& d
TVM_REGISTER_GLOBAL("relax.dpl.match_dfb").set_body_typed(MatchGraph);
+/*!
+ * \brief Apply pattern matching to each call node and replace matching ones with the
+ * output of a user-provided rewriter function.
+ */
+class PatternRewriter : ExprMutator {
+ public:
+ using ExprMutator::VisitExpr_;
+
+ PatternRewriter(DFPattern pat, PackedFunc rewriter_func)
+ : pattern_(pat), rewriter_func_(rewriter_func) {}
+
+ static Expr Run(DFPattern pat, PackedFunc rewriter_func, Function f) {
+ PatternRewriter rewriter(pat, rewriter_func);
+ return RemoveAllUnused(Downcast<Function>(rewriter.VisitExpr(f)));
+ }
+
+ void VisitBinding_(const VarBindingNode* binding) final {
+ bindings_.Set(binding->var, binding->value);
+ ExprMutator::VisitBinding_(binding);
+ if (auto it = memo_.find(binding->value.get()); it != memo_.end()) {
+      // We need to update the binding to pass to ExtractMatchedExpr, so that the
+      // rewritten expression can be subject to further pattern matching.
+ bindings_.Set(binding->var, it->second);
+ }
+ }
+
+ Expr VisitExpr_(const CallNode* call_node) final {
+ auto call = ExprMutator::VisitExpr_(call_node);
+ if (auto matches_opt = ExtractMatchedExpr(pattern_, call, bindings_)) {
+      auto rewritten_expr = rewriter_func_(call, matches_opt.value());
+      memo_[call_node] = rewritten_expr;
+      return rewritten_expr;
+ }
+ return call;
+ }
+
+ private:
+ DFPattern pattern_;
+ PackedFunc rewriter_func_;
+ Map<Var, Expr> bindings_;
+ std::unordered_map<const Object*, Expr> memo_;
+};
+
+TVM_REGISTER_GLOBAL("relax.dpl.rewrite")
+ .set_body_typed([](DFPattern pat, PackedFunc rewriter, Function f) {
+ return PatternRewriter::Run(pat, rewriter, f);
+ });
+
} // namespace relax
} // namespace tvm
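The TVM_REGISTER_GLOBAL call above is what backs ffi.rewrite on the Python side.
As a quick sanity check (a sketch, not part of this change), the packed function
can also be looked up directly; Python callables passed to it are converted to
PackedFunc by the FFI:
```
import tvm

rewrite_global = tvm.get_global_func("relax.dpl.rewrite")
# Equivalent to ffi.rewrite(pattern, rewriter, func) in pattern.py:
# rewritten = rewrite_global(pattern, rewriter, func)
```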
diff --git a/tests/python/relax/test_dataflow_pattern.py b/tests/python/relax/test_dataflow_pattern.py
index b57dca19f2..a40faf3bcb 100644
--- a/tests/python/relax/test_dataflow_pattern.py
+++ b/tests/python/relax/test_dataflow_pattern.py
@@ -888,5 +888,123 @@ def test_incremental_solving_counter():
assert not ctx1.match_dfb(simple_chain.body.blocks[0])
+def test_rewrite_simple():
+ @R.function
+ def main(x: R.Tensor((16, 16), "float32")) -> R.Tensor((16, 16),
"float32"):
+ with R.dataflow():
+ x2 = R.add(x, x)
+ x4 = R.add(x2, x2)
+ R.output(x4)
+ return x4
+
+ @R.function
+ def expected1(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ with R.dataflow():
+ lv: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(2,
"float32"))
+ x4: R.Tensor((16, 16), dtype="float32") = R.multiply(lv,
R.const(2, "float32"))
+ R.output(x4)
+ return x4
+
+ @R.function
+ def expected2(x: R.Tensor((16, 16), dtype="float32")) -> R.Tensor((16,
16), dtype="float32"):
+ with R.dataflow():
+ x4: R.Tensor((16, 16), dtype="float32") = R.multiply(x, R.const(4,
"float32"))
+ R.output(x4)
+ return x4
+
+ x = wildcard()
+ pattern = is_op("relax.add")(x, x)
+
+ def rewriter(_, matchings):
+ return R.multiply(matchings[x], R.const(2, "float32"))
+
+ rewritten = rewrite(pattern, rewriter, main)
+ tvm.ir.assert_structural_equal(rewritten, expected1)
+
+ add1 = is_op("relax.add")(x, x)
+ pattern = is_op("relax.add")(add1, add1)
+
+ def rewriter(_, matchings):
+ return R.multiply(matchings[x], R.const(4, "float32"))
+
+ rewritten = rewrite(pattern, rewriter, main)
+ tvm.ir.assert_structural_equal(rewritten, expected2)
+
+ # No rewriting, return the original call node as is
+ def rewriter(orig, _):
+ return orig
+
+ rewritten = rewrite(pattern, rewriter, main)
+ tvm.ir.assert_structural_equal(rewritten, main)
+
+
+def test_rewrite_attention():
+ @R.function
+ def main(
+ Q: R.Tensor((2, 4096, 8, 40), "float32"),
+ K: R.Tensor((2, 4096, 8, 40), "float32"),
+ V: R.Tensor((2, 4096, 8, 40), "float32"),
+ ) -> R.Tensor((2, 4096, 8, 40), "float32"):
+ with R.dataflow():
+ lv58 = R.permute_dims(Q, axes=[0, 2, 1, 3])
+ lv59 = R.reshape(lv58, R.shape([16, 4096, 40]))
+
+ lv61 = R.permute_dims(K, axes=[0, 2, 1, 3])
+ lv62 = R.reshape(lv61, R.shape([16, 4096, 40]))
+
+ lv64 = R.permute_dims(V, axes=[0, 2, 1, 3])
+ lv65 = R.reshape(lv64, R.shape([16, 4096, 40]))
+
+ lv62_transposed = R.permute_dims(lv62, axes=[0, 2, 1])
+ lv3_1 = R.matmul(lv59, lv62_transposed)
+ lv68 = R.multiply(lv3_1, R.const(0.15811388194561005, "float32"))
+ lv69 = R.nn.softmax(lv68, axis=-1)
+ lv_3 = R.matmul(lv69, lv65)
+
+ lv71 = R.reshape(lv_3, R.shape([2, 8, 4096, 40]))
+ lv72 = R.permute_dims(lv71, axes=[0, 2, 1, 3])
+ R.output(lv72)
+
+ return lv72
+
+ @R.function
+ def expected(
+ Q: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ K: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ V: R.Tensor((2, 4096, 8, 40), dtype="float32"),
+ ) -> R.Tensor((2, 4096, 8, 40), dtype="float32"):
+ with R.dataflow():
+ lv72: R.Tensor((2, 4096, 8, 40), dtype="float32") =
R.nn.attention(Q, V, K)
+ R.output(lv72)
+ return lv72
+
+ def BSNH_to_BSH(tensor):
+ return is_op("relax.reshape")(is_op("relax.permute_dims")(tensor),
wildcard())
+
+ def BSH_to_BSNH(tensor):
+ return is_op("relax.permute_dims")(is_op("relax.reshape")(tensor,
wildcard()))
+
+ Q = wildcard()
+ K = wildcard()
+ V = wildcard()
+
+ Q_3D = BSNH_to_BSH(Q)
+ V_3D = BSNH_to_BSH(V)
+ K_3D = BSNH_to_BSH(K)
+
+ matmul1 = is_op("relax.matmul")(Q_3D, is_op("relax.permute_dims")(V_3D))
+ multiply = is_op("relax.multiply")(matmul1, is_const())
+ softmax = is_op("relax.nn.softmax")(multiply)
+ matmul2 = is_op("relax.matmul")(softmax, K_3D)
+
+ pattern = BSH_to_BSNH(matmul2)
+
+ def rewriter(_, matchings):
+ return R.nn.attention(matchings[Q], matchings[K], matchings[V])
+
+ rewritten = rewrite(pattern, rewriter, main)
+ tvm.ir.assert_structural_equal(rewritten, expected)
+
+
if __name__ == "__main__":
tvm.testing.main()