This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new f78d9ed0ed [UNITY][Pass] Remove redundant reshape (#15806)
f78d9ed0ed is described below
commit f78d9ed0ed1f57b5828f5429744b375de5609ce3
Author: sjain58 <[email protected]>
AuthorDate: Thu Sep 28 01:20:24 2023 +0530
[UNITY][Pass] Remove redundant reshape (#15806)
* Remove Redundant Reshape ops
* Added support for reshape(reshape(arg, shape1), shape2)
* Fix spaces by running black
* Fixed review comments & Added a testcase for no_op_reshape pattern
---------
Co-authored-by: shaljain <[email protected]>
---
python/tvm/relax/transform/__init__.py | 1 +
.../relax/transform/remove_redundant_reshape.py | 80 ++++++++++++++
.../python/relax/test_remove_redundant_reshape.py | 118 +++++++++++++++++++++
3 files changed, 199 insertions(+)
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
index 68128db62d..7bbe6d52c2 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -20,6 +20,7 @@
from .transform import *
from .lazy_transform_params import LazyTransformParams
from .optimize_layout_transform import OptimizeLayoutTransform
+from .remove_redundant_reshape import RemoveRedundantReshape
# Import to register the legalization functions.
from . import legalize_ops
diff --git a/python/tvm/relax/transform/remove_redundant_reshape.py
b/python/tvm/relax/transform/remove_redundant_reshape.py
new file mode 100644
index 0000000000..2274f8e5da
--- /dev/null
+++ b/python/tvm/relax/transform/remove_redundant_reshape.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-argument, missing-function-docstring,
abstract-method
+"""Relax Remove Redundant Reshape ops"""
+from tvm import IRModule, relax
+from tvm.ir.transform import PassContext
+from tvm.ir import structural_equal
+from tvm.relax import Expr, Function
+from tvm.relax.dpl import is_op, rewrite_call, wildcard
+from . import function_pass
+
+
@function_pass(opt_level=0)
class RemoveRedundantReshape:
    """
    Transformation pass to remove redundant reshape operators.

    Two rewrites are applied:

    * ``reshape(reshape(x, shape1), shape2)`` -> ``reshape(x, shape2)``
      (only the outermost target shape matters, so the inner reshape is dead).
    * ``reshape(x, shape)`` -> ``x`` when ``x`` already has ``shape``
      (a no-op reshape).
    """

    def __init__(self):
        # Innermost tensor argument; stored on ``self`` so the rewriter can
        # reference what the pattern captured.
        self.input1 = wildcard()
        shape1 = wildcard()
        pattern_redundant_reshape = is_op("relax.reshape")(self.input1, shape1)
        # A single reshape: redundant only if the shape is unchanged.
        self.no_op_reshape = pattern_redundant_reshape
        shape2 = wildcard()
        # A reshape whose input is itself a reshape.
        self.repeated_reshape = is_op("relax.reshape")(pattern_redundant_reshape, shape2)
        # Try the two-reshape chain first, then the single no-op reshape.
        self.pattern = self.repeated_reshape | self.no_op_reshape

    def transform_function(self, func: Expr, mod: IRModule, ctx: PassContext) -> Expr:
        """
        Transformation function to remove redundant reshape ops
        where tensors before and after reshape have the same dimensions.

        Parameters
        --------------
        func: Expr
            The relax function to be optimized

        mod: IRModule
            The IR module

        ctx: PassContext
            Relax pass context
        """
        # A function_pass is invoked once per function, so rewrite only the
        # function handed to us.  (The previous implementation iterated over
        # every function in ``mod`` and returned only the last rewrite, which
        # was wrong for modules containing more than one relax function.)
        if not isinstance(func, Function):
            # Skip non-relax functions.
            return func
        # Skip primitive functions; ``attrs`` may be absent entirely.
        if func.attrs is not None and "Primitive" in func.attrs and func.attrs["Primitive"] != 0:
            return func

        def rewriter(expr, matches):
            args = matches[self.pattern]
            if self.repeated_reshape in matches:
                # Collapse the chain: reshape the original input directly to
                # the outermost target shape.
                return relax.op.reshape(matches[self.input1], args.args[1])
            if self.no_op_reshape in matches:
                # Drop the reshape when the input already has the target shape.
                arg = args.args[0]
                if arg.struct_info.shape and structural_equal(arg.struct_info.shape, args.args[1]):
                    return arg
            return expr

        return rewrite_call(self.pattern, rewriter, func)
diff --git a/tests/python/relax/test_remove_redundant_reshape.py
b/tests/python/relax/test_remove_redundant_reshape.py
new file mode 100644
index 0000000000..806b563cb8
--- /dev/null
+++ b/tests/python/relax/test_remove_redundant_reshape.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Test relax transform - Eliminate redundant reshape operations
+"""
+import tvm.testing
+from tvm import relax
+from tvm.relax.transform import DeadCodeElimination
+from tvm.relax.transform import RemoveRedundantReshape
+from tvm.script import ir as I, relax as R
+
+
def _run_pass_compare_output(Before, Expected):
    """Run RemoveRedundantReshape (then DCE) on *Before* and assert the result equals *Expected*."""
    after_pass = DeadCodeElimination()(RemoveRedundantReshape()(Before))
    tvm.ir.assert_structural_equal(Expected, after_pass)
+
+
def test_remove_redundant_reshape_pass_one_arg():
    """A chain of three reshapes to the same shape should fold into a single reshape."""

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001), dtype="float16"):
            with R.dataflow():
                # Three back-to-back reshapes; only the last target shape matters.
                lv: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
                lv2: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv1, R.shape([1, 1001]))
                R.output(lv2)
            return lv2

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001), dtype="float16"):
            with R.dataflow():
                # The chain collapses into one reshape of the original input.
                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
                R.output(lv1)
            return lv1

    _run_pass_compare_output(Before, Expected)
+
+
def test_remove_redundant_reshape_pass_two_arg():
    """Two stacked reshapes with different shapes should fold into one reshape."""

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001), dtype="float16"):
            with R.dataflow():
                # Intermediate reshape whose result is immediately reshaped again.
                lv: R.Tensor((1, 1001, 1), dtype="float16") = R.reshape(x, R.shape([1, 1001, 1]))
                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(lv, R.shape([1, 1001]))
                R.output(lv1)
            return lv1

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001), dtype="float16"):
            with R.dataflow():
                # Only the final target shape is kept.
                lv1: R.Tensor((1, 1001), dtype="float16") = R.reshape(x, R.shape([1, 1001]))
                R.output(lv1)
            return lv1

    _run_pass_compare_output(Before, Expected)
+
+
def test_remove_redundant_reshape_pass_three_arg():
    """A reshape whose target shape equals the input shape (no-op) should be removed."""

    @I.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
            with R.dataflow():
                # The reshape target equals the input's shape — a no-op.
                lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = R.reshape(
                    x, R.shape([1, 1001, 1, 1])
                )
                R.output(lv)
            return lv

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 1001, 1, 1), dtype="float16")
        ) -> R.Tensor((1, 1001, 1, 1), dtype="float16"):
            with R.dataflow():
                # The no-op reshape is replaced by its input.
                lv: R.Tensor((1, 1001, 1, 1), dtype="float16") = x
                R.output(lv)
            return lv

    _run_pass_compare_output(Before, Expected)
+
+
# Allow running this test file directly (outside of pytest collection).
if __name__ == "__main__":
    tvm.testing.main()