This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new bec2b8f48d [Unity] Preserve ShapeExpr in EliminateCommonSubexpr transform (#15701)
bec2b8f48d is described below

commit bec2b8f48d46f074beedde3bc7d20a0bca3aa9f9
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Sep 7 21:20:04 2023 -0700

    [Unity] Preserve ShapeExpr in EliminateCommonSubexpr transform (#15701)
    
    Prior to this commit, `R.ShapeExpr` occurring multiple times in a
    relax function would be de-duplicated.  If these shape expressions are
    de-duplicated, later use of `emit_te` may fail as it expects
    operations to have explicit `ShapeExpr` shapes.
---
 src/relax/transform/eliminate_common_subexpr.cc |  2 ++
 tests/python/relax/test_transform_cse.py        | 15 +++++++++++++++
 2 files changed, 17 insertions(+)

diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 3452b6352b..8bbb05f327 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -82,9 +82,11 @@ class SubexprCounter : public ExprVisitor {
     // 3. PrimValue nodes (not much benefit from binding to a var)
     // 4. StringImm nodes (not much benefit from binding to a var)
     // 5. Scalar constants (not much benefit from binding to a var)
+    // 6. Shape expressions (exist to hold several PrimValue objects)
     if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
           e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
           e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
+          e->IsInstance<ShapeExprNode>() ||
           (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
       // also if e has an impure subexpression, we will not deduplicate it
       if (!impurity_detector_.Detect(e)) {
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 5062f86cbc..cf66ae3c1c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -261,5 +261,20 @@ def test_do_not_eliminate_impure():
     verify(Before, Expected)
 
 
+def test_do_not_eliminate_shape_expr():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            x = R.reshape(x, [6])
+            y = R.reshape(y, [6])
+            z = R.add(x, y)
+            return z
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to