This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new f78d9ed0ed [UNITY][Pass] Remove redundant reshape (#15806)
f78d9ed0ed is described below

commit f78d9ed0ed1f57b5828f5429744b375de5609ce3
Author: sjain58 <[email protected]>
AuthorDate: Thu Sep 28 01:20:24 2023 +0530

    [UNITY][Pass] Remove redundant reshape (#15806)
    
    * Remove Redundant Reshape ops
    
    * Added support for reshape(reshape(arg, shape1), shape2)
    
    * Fix spaces by running black
    
    * Fixed review comments & Added a testcase for no_op_reshape pattern
    
    ---------
    
    Co-authored-by: shaljain <[email protected]>
---
 python/tvm/relax/transform/__init__.py             |   1 +
 .../relax/transform/remove_redundant_reshape.py    |  80 ++++++++++++++
 .../python/relax/test_remove_redundant_reshape.py  | 118 +++++++++++++++++++++
 3 files changed, 199 insertions(+)

diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index 68128db62d..7bbe6d52c2 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -20,6 +20,7 @@
 from .transform import *
 from .lazy_transform_params import LazyTransformParams
 from .optimize_layout_transform import OptimizeLayoutTransform
+from .remove_redundant_reshape import RemoveRedundantReshape
 
 # Import to register the legalization functions.
 from . import legalize_ops
diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py 
b/python/tvm/relax/transform/remove_redundant_reshape.py
new file mode 100644
index 0000000000..2274f8e5da
--- /dev/null
+++ b/python/tvm/relax/transform/remove_redundant_reshape.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, missing-function-docstring, 
abstract-method
+"""Relax Remove Redundant Reshape ops"""
+from tvm import IRModule, relax
+from tvm.ir.transform import PassContext
+from tvm.ir import structural_equal
+from tvm.relax import Expr, Function
+from tvm.relax.dpl import is_op, rewrite_call, wildcard
+from . import function_pass
+
+
+@function_pass(opt_level=0)
+class RemoveRedundantReshape:
+    """
+    Transformation pass to remove redundant reshape operator
+    """
+
+    def __init__(self):
+        self.input1 = wildcard()
+        shape1 = wildcard()
+        pattern_redundant_reshape = is_op("relax.reshape")(self.input1, shape1)
+        self.no_op_reshape = pattern_redundant_reshape
+        shape2 = wildcard()
+        self.repeated_reshape = 
is_op("relax.reshape")(pattern_redundant_reshape, shape2)
+        self.pattern = self.repeated_reshape | self.no_op_reshape
+
+    def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) 
-> IRModule:
+        """
+        Tarnsformation function to remove redundant reshape
+        where tensors before and after reshape are of same dimentions.
+
+        Parameters
+        --------------
+        func: Expr
+            The relax function to be optimized
+
+        mod: IRModule
+            The IR module
+
+        ctx: PassContext
+            Relax pass context
+        """
+
+        updated_func = func
+        for _, funct in mod.functions.items():
+            # Skip non-relax functions
+            if not isinstance(funct, Function):
+                continue
+            # Skip primitive functions
+            if "Primitive" in funct.attrs.keys() and funct.attrs["Primitive"] 
!= 0:
+                continue
+
+            def rewriter(expr, matches):
+                args = matches[self.pattern]
+                if self.repeated_reshape in matches:
+                    return relax.op.reshape(matches[self.input1], args.args[1])
+                elif self.no_op_reshape in matches:
+                    if args.args[0].struct_info.shape:
+                        if structural_equal(args.args[0].struct_info.shape, 
args.args[1]):
+                            return args.args[0]
+                return expr
+
+            updated_func = rewrite_call(self.pattern, rewriter, funct)
+
+        return updated_func
diff --git a/tests/python/relax/test_remove_redundant_reshape.py 
b/tests/python/relax/test_remove_redundant_reshape.py
new file mode 100644
index 0000000000..806b563cb8
--- /dev/null
+++ b/tests/python/relax/test_remove_redundant_reshape.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Test relax transform - Eliminate redundant reshape operations
+"""
+import tvm.testing
+from tvm import relax
+from tvm.relax.transform import DeadCodeElimination
+from tvm.relax.transform import RemoveRedundantReshape
+from tvm.script import ir as I, relax as R
+
+
+def _run_pass_compare_output(Before, Expected):
+    fused_mod = RemoveRedundantReshape()(Before)
+    fused_mod = DeadCodeElimination()(fused_mod)
+    tvm.ir.assert_structural_equal(Expected, fused_mod)
+
+
+def test_remove_redundant_reshape_pass_one_arg():
+    """A chain of three reshapes of ``x`` to (1, 1001) collapses to a single reshape."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, 
R.shape([1, 1001]))
+                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, 
R.shape([1, 1001]))
+                lv2: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, 
R.shape([1, 1001]))
+                R.output(lv2)
+            return lv2
+
+    # After the pass (and DCE) only one reshape applied directly to ``x`` remains.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001), dtype="float16"):
+            with R.dataflow():
+                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, 
R.shape([1, 1001]))
+                R.output(lv1)
+            return lv1
+
+    _run_pass_compare_output(Before, Expected)
+
+
+def test_remove_redundant_reshape_pass_two_arg():
+    """reshape(reshape(x, (1, 1001, 1)), (1, 1001)) collapses to reshape(x, (1, 1001))."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((1, 1001, 1), dtype="float16") = R.reshape(x, 
R.shape([1, 1001, 1]))
+                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, 
R.shape([1, 1001]))
+                R.output(lv1)
+            return lv1
+
+    # The intermediate (1, 1001, 1) reshape is bypassed; only the outer target shape is kept.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001), dtype="float16"):
+            with R.dataflow():
+                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, 
R.shape([1, 1001]))
+                R.output(lv1)
+            return lv1
+
+    _run_pass_compare_output(Before, Expected)
+
+
+def test_remove_redundant_reshape_pass_three_arg():
+    """A reshape of ``x`` to its own shape (1, 1001, 1, 1) is replaced by ``x`` itself."""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = R.reshape(
+                    x, R.shape([1, 1001, 1, 1])
+                )
+                R.output(lv)
+            return lv
+
+    # The no-op reshape disappears; ``lv`` is bound directly to the input ``x``.
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
+        ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
+            with R.dataflow():
+                lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = x
+                R.output(lv)
+            return lv
+
+    _run_pass_compare_output(Before, Expected)
+
+
+# Allow running this test file directly as a script via tvm.testing.main().
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to