This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new f6bbe946e1 [relay][simplify_expr]: Add pass to remove trivial transpose ops (#14858)
f6bbe946e1 is described below
commit f6bbe946e123f44d6078cdbc70d5ab5affffbf67
Author: Krishna Bindumadhavan <[email protected]>
AuthorDate: Tue May 16 13:29:07 2023 +0530
[relay][simplify_expr]: Add pass to remove trivial transpose ops (#14858)
[relay][simplify_expr]: Add pattern to remove trivial transpose ops
---
src/relay/transforms/simplify_expr.cc | 111 +++++++++++++++++---------
tests/python/relay/test_pass_simplify_expr.py | 22 +++++
2 files changed, 97 insertions(+), 36 deletions(-)
diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 4f255647df..a557f2496b 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -260,6 +260,38 @@ class SimplifyCastClip : public DFPatternRewrite {
DFPattern clip_, cast_;
};
+/*!
+ * \brief Return the axis order of a transpose or layout_transform call:
+ * entry i is the input axis that output axis i is taken from.
+ */
+static std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) {
+ std::vector<int> attr_axes;
+ if (auto attr = call->attrs.as<TransposeAttrs>()) {
+ if (attr->axes.defined()) {
+ for (int i = 0; i < ndim; ++i) {
+ int64_t axis = attr->axes[i].IntValue();
+ axis += (axis < 0) ? ndim : 0;  // normalize negative axis indices
+ attr_axes.push_back(axis);
+ }
+ } else {
+ // Undefined (omitted) axes means reverse the axis order
+ for (int i = ndim - 1; i >= 0; --i) {
+ attr_axes.push_back(i);
+ }
+ }
+ } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+ Layout src_layout(attr->src_layout);
+ Layout dst_layout(attr->dst_layout);
+ for (int i = 0; i < ndim; ++i) {
+ attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));  // position of dst axis i in src
+ }
+ } else {
+ CHECK(false) << "Expected transpose or layout_transform, but got "
+ << Downcast<Op>(call->op)->name;
+ }
+ return attr_axes;  // plain return: std::move on a local would inhibit NRVO
+}
+
/*!
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
* and merges or cancels them.
@@ -316,19 +348,7 @@ class SimplifyTranspose : public DFPatternRewrite {
it++;
}
- // Check if the transpose is still required
- bool need_transpose = false;
- for (int i = 0; i < ndim; ++i) {
- if (axes[i] != i) {
- need_transpose = true;
- break;
- }
- }
-
- if (need_transpose) {
- return MakeTranspose(x, axes);
- }
- return x;
+ return MakeTranspose(x, axes);
}
String PermuteLayout(const String& layout, std::vector<int> axes_order)
const {
@@ -431,32 +451,50 @@ class SimplifyTranspose : public DFPatternRewrite {
return Downcast<Call>(output_layout_trans);
}
- std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
- std::vector<int> attr_axes;
- if (auto attr = call->attrs.as<TransposeAttrs>()) {
- if (attr->axes.defined()) {
- for (int i = 0; i < ndim; ++i) {
- int64_t axis = attr->axes[i].IntValue();
- axis += (axis < 0) ? ndim : 0;
- attr_axes.push_back(axis);
- }
- } else {
- // Empty axes means reverse
- for (int i = ndim - 1; i >= 0; --i) {
- attr_axes.push_back(i);
- }
+ private:
+ /*! \brief Pattern input */
+ DFPattern x_;
+};
+
+/*!
+ * \brief SimplifyNoOpTranspose matches a single transpose or layout_transform
+ * whose axis order is the identity (e.g. transpose axes [0, 1, ..., n-1], or a
+ * layout_transform whose src_layout equals dst_layout) and returns its input.
+ */
+class SimplifyNoOpTranspose : public DFPatternRewrite {
+ public:
+ SimplifyNoOpTranspose() {
+ x_ = IsWildcard();
+ auto trans1 = IsOp("transpose") || IsOp("layout_transform");
+ pattern_ = trans1({x_});
+ }
+
+ Expr Callback(const Expr& pre, const Expr& post,
+ const Map<DFPattern, Array<Expr>>& node_map) const override {
+ auto x = node_map[x_][0];  // input of the matched op
+ Call trans_call = Downcast<Call>(post);
+
+ // Keep any layout_transform whose src and dst layouts differ (it permutes or changes rank)
+ if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
+ if (attr->src_layout != attr->dst_layout) {
+ return post;
+ }
- } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
- Layout src_layout(attr->src_layout);
- Layout dst_layout(attr->dst_layout);
- for (int i = 0; i < ndim; ++i) {
- attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+ }
+
+ int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();  // rank of the matched op's result
+ auto axes = GetTransposeAxisOrder(trans_call, ndim);
+
+ bool need_transpose = false;  // set if any output axis comes from a different input axis
+ for (int i = 0; i < ndim; ++i) {
+ if (axes[i] != i) {
+ need_transpose = true;
+ break;
}
- } else {
- CHECK(false) << "Expected transpose or layout_transform, but got "
- << Downcast<Op>(call->op)->name;
}
- return std::move(attr_axes);
+
+ if (!need_transpose) return x;  // identity axis order: drop the op, forward its input
+
+ return post;
}
private:
@@ -1037,6 +1075,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
composer.AddRewrite<EliminateIdentityRewrite>();
composer.AddRewrite<SimplifyReshape>();
composer.AddRewrite<SimplifyTranspose>();
+ composer.AddRewrite<SimplifyNoOpTranspose>();
composer.AddRewrite<SimplifySameCast>();
composer.AddRewrite<SimplifyConsecutiveCast>();
composer.AddRewrite<FullElementwise>();
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d11242dbd8..4edb85f2d7 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -266,6 +266,27 @@ def test_simplify_transpose():
y = relay.nn.relu(y)
return relay.Function([x], y)
+ def before11():
+ """
+ Remove trivial no-op transpose and layout_transform ops
+
+ Input:
+ x -> relay.transpose(axes=[0, 1, 2, 3]) -> relu -> relay.layout_transform("NCHW", "NCHW")
+
+ Simplified:
+ x -> relu
+ """
+ x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ y = relay.transpose(x, axes=[0, 1, 2, 3])
+ y = relay.nn.relu(y)
+ y = relay.layout_transform(y, "NCHW", "NCHW")
+ return relay.Function([x], y)
+
+ def expected11():
+ inp = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+ out = relay.nn.relu(inp)
+ return relay.Function([inp], out)
+
for before, expected in [
[before1(), expected1()],
[before2(), expected2()],
@@ -277,6 +298,7 @@ def test_simplify_transpose():
[before8(), expected8()],
[before9(), expected9()],
[before10(), expected10()],
+ [before11(), expected11()],
]:
after = run_opt_pass(before, transform.SimplifyExpr())
expected = run_opt_pass(expected, transform.InferType())