This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new bec2b8f48d [Unity] Preserve ShapeExpr in EliminateCommonSubexpr transform (#15701)
bec2b8f48d is described below
commit bec2b8f48d46f074beedde3bc7d20a0bca3aa9f9
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 7 21:20:04 2023 -0700
[Unity] Preserve ShapeExpr in EliminateCommonSubexpr transform (#15701)
Prior to this commit, an `R.ShapeExpr` occurring multiple times in a
Relax function would be de-duplicated by `EliminateCommonSubexpr`. Once
these shape expressions are de-duplicated, later use of `emit_te` may
fail, as it expects operations to have explicit `ShapeExpr` shapes.
This commit excludes `ShapeExpr` nodes from de-duplication.
---
src/relax/transform/eliminate_common_subexpr.cc | 2 ++
tests/python/relax/test_transform_cse.py | 15 +++++++++++++++
2 files changed, 17 insertions(+)
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 3452b6352b..8bbb05f327 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -82,9 +82,11 @@ class SubexprCounter : public ExprVisitor {
// 3. PrimValue nodes (not much benefit from binding to a var)
// 4. StringImm nodes (not much benefit from binding to a var)
// 5. Scalar constants (not much benefit from binding to a var)
+ // 6. Shape expressions (exist to hold several PrimValue objects)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
+ e->IsInstance<ShapeExprNode>() ||
(e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 5062f86cbc..cf66ae3c1c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -261,5 +261,20 @@ def test_do_not_eliminate_impure():
verify(Before, Expected)
+def test_do_not_eliminate_shape_expr():  # regression: a repeated shape argument must not be de-duplicated by CSE
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3),
dtype="float32")):
+ x = R.reshape(x, [6])  # [6] repeats on the next line; presumably parsed as a ShapeExpr
+ y = R.reshape(y, [6])
+ z = R.add(x, y)
+ return z
+
+ Expected = Before  # the pass is expected to leave Before unchanged
+
+ verify(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()