This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f6bbe946e1 [relay][simplify_expr]: Add pass to remove trivial transpose ops (#14858)
f6bbe946e1 is described below

commit f6bbe946e123f44d6078cdbc70d5ab5affffbf67
Author: Krishna Bindumadhavan <[email protected]>
AuthorDate: Tue May 16 13:29:07 2023 +0530

    [relay][simplify_expr]: Add pass to remove trivial transpose ops (#14858)
    
    [relay][simplify_expr]: Add pattern to remove trivial transpose ops
---
 src/relay/transforms/simplify_expr.cc         | 111 +++++++++++++++++---------
 tests/python/relay/test_pass_simplify_expr.py |  22 +++++
 2 files changed, 97 insertions(+), 36 deletions(-)

diff --git a/src/relay/transforms/simplify_expr.cc b/src/relay/transforms/simplify_expr.cc
index 4f255647df..a557f2496b 100644
--- a/src/relay/transforms/simplify_expr.cc
+++ b/src/relay/transforms/simplify_expr.cc
@@ -260,6 +260,38 @@ class SimplifyCastClip : public DFPatternRewrite {
   DFPattern clip_, cast_;
 };
 
+/*!
+ * \brief Return the axis order (length ndim) applied by a transpose or layout_transform
+ *  call. Negative transpose axes are wrapped by ndim; undefined axes mean reversed order.
+ */
+static std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) {
+  std::vector<int> attr_axes;
+  if (auto attr = call->attrs.as<TransposeAttrs>()) {
+    if (attr->axes.defined()) {
+      for (int i = 0; i < ndim; ++i) {
+        int64_t axis = attr->axes[i].IntValue();
+        axis += (axis < 0) ? ndim : 0;
+        attr_axes.push_back(static_cast<int>(axis));
+      }
+    } else {
+      // Undefined axes means the dimensions are reversed
+      for (int i = ndim - 1; i >= 0; --i) {
+        attr_axes.push_back(i);
+      }
+    }
+  } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
+    Layout src_layout(attr->src_layout);
+    Layout dst_layout(attr->dst_layout);
+    for (int i = 0; i < ndim; ++i) {
+      attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+    }
+  } else {
+    CHECK(false) << "Expected transpose or layout_transform, but got "
+                 << Downcast<Op>(call->op)->name;
+  }
+  return attr_axes;  // plain return of the local: eligible for NRVO / implicit move
+}
+
 /*!
  * \brief SimplifyTranspose matches the pattern of consecutive transpose op,
  *   and merges or cancels them.
@@ -316,19 +348,7 @@ class SimplifyTranspose : public DFPatternRewrite {
       it++;
     }
 
-    // Check if the transpose is still required
-    bool need_transpose = false;
-    for (int i = 0; i < ndim; ++i) {
-      if (axes[i] != i) {
-        need_transpose = true;
-        break;
-      }
-    }
-
-    if (need_transpose) {
-      return MakeTranspose(x, axes);
-    }
-    return x;
+    return MakeTranspose(x, axes);
   }
 
   String PermuteLayout(const String& layout, std::vector<int> axes_order) const {
@@ -431,32 +451,50 @@ class SimplifyTranspose : public DFPatternRewrite {
     return Downcast<Call>(output_layout_trans);
   }
 
-  std::vector<int> GetTransposeAxisOrder(const Call& call, int ndim) const {
-    std::vector<int> attr_axes;
-    if (auto attr = call->attrs.as<TransposeAttrs>()) {
-      if (attr->axes.defined()) {
-        for (int i = 0; i < ndim; ++i) {
-          int64_t axis = attr->axes[i].IntValue();
-          axis += (axis < 0) ? ndim : 0;
-          attr_axes.push_back(axis);
-        }
-      } else {
-        // Empty axes means reverse
-        for (int i = ndim - 1; i >= 0; --i) {
-          attr_axes.push_back(i);
-        }
+ private:
+  /*! \brief Pattern input */
+  DFPattern x_;
+};
+
+/*!
+ * \brief SimplifyNoOpTranspose matches the pattern of transpose or
+ *  layout transform ops which do not change the layout or rank and
+ *  removes the op.
+ */
+class SimplifyNoOpTranspose : public DFPatternRewrite {
+ public:
+  SimplifyNoOpTranspose() {
+    // Match a single transpose or layout_transform call applied to any input.
+    x_ = IsWildcard();
+    pattern_ = (IsOp("transpose") || IsOp("layout_transform"))({x_});
+  }
+
+  Expr Callback(const Expr& pre, const Expr& post,
+                const Map<DFPattern, Array<Expr>>& node_map) const override {
+    auto x = node_map[x_][0];  // input of the matched op
+    Call trans_call = Downcast<Call>(post);
+
+    // Keep a layout_transform whose src and dst layouts differ (it may change rank)
+    if (auto attr = trans_call->attrs.as<LayoutTransformAttrs>()) {
+      if (attr->src_layout != attr->dst_layout) {
+        return post;
       }
-    } else if (auto attr = call->attrs.as<LayoutTransformAttrs>()) {
-      Layout src_layout(attr->src_layout);
-      Layout dst_layout(attr->dst_layout);
-      for (int i = 0; i < ndim; ++i) {
-        attr_axes.push_back(src_layout.IndexOf(dst_layout[i]));
+    }
+
+    int ndim = Downcast<TensorType>(pre->checked_type())->shape.size();
+    auto axes = GetTransposeAxisOrder(trans_call, ndim);
+
+    bool need_transpose = false;  // stays false only for the identity axis order
+    for (int i = 0; i < ndim; ++i) {
+      if (axes[i] != i) {
+        need_transpose = true;
+        break;
       }
-    } else {
-      CHECK(false) << "Expected transpose or layout_transform, but got "
-                   << Downcast<Op>(call->op)->name;
     }
-    return std::move(attr_axes);
+
+    if (!need_transpose) return x;  // no-op: forward the input unchanged
+
+    return post;
   }
 
  private:
@@ -1037,6 +1075,7 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
   composer.AddRewrite<EliminateIdentityRewrite>();
   composer.AddRewrite<SimplifyReshape>();
   composer.AddRewrite<SimplifyTranspose>();
+  composer.AddRewrite<SimplifyNoOpTranspose>();
   composer.AddRewrite<SimplifySameCast>();
   composer.AddRewrite<SimplifyConsecutiveCast>();
   composer.AddRewrite<FullElementwise>();
diff --git a/tests/python/relay/test_pass_simplify_expr.py b/tests/python/relay/test_pass_simplify_expr.py
index d11242dbd8..4edb85f2d7 100644
--- a/tests/python/relay/test_pass_simplify_expr.py
+++ b/tests/python/relay/test_pass_simplify_expr.py
@@ -266,6 +266,27 @@ def test_simplify_transpose():
         y = relay.nn.relu(y)
         return relay.Function([x], y)
 
+    def before11():
+        """
+        No-op transpose and layout_transform ops should be removed
+
+        Input:
+        transpose(x, axes=[0, 1, 2, 3]) -> relu -> layout_transform("NCHW" -> "NCHW")
+
+        Simplified:
+        relu(x)
+        """
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        transposed = relay.transpose(x, axes=[0, 1, 2, 3])
+        activated = relay.nn.relu(transposed)
+        out = relay.layout_transform(activated, "NCHW", "NCHW")
+        return relay.Function([x], out)
+
+    def expected11():
+        # Only the relu remains once the no-op transpose and layout_transform are dropped.
+        x = relay.var("x", shape=(1, 128, 56, 56), dtype="float32")
+        return relay.Function([x], relay.nn.relu(x))
+
     for before, expected in [
         [before1(), expected1()],
         [before2(), expected2()],
@@ -277,6 +298,7 @@ def test_simplify_transpose():
         [before8(), expected8()],
         [before9(), expected9()],
         [before10(), expected10()],
+        [before11(), expected11()],
     ]:
         after = run_opt_pass(before, transform.SimplifyExpr())
         expected = run_opt_pass(expected, transform.InferType())

Reply via email to