This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c9de001490 [Unity] [Transform] Skip constants in CSE pass (#16125)
c9de001490 is described below

commit c9de0014905c69b67d7e36488af12466d7d9a940
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Wed Nov 15 01:38:03 2023 +0530

    [Unity] [Transform] Skip constants in CSE pass (#16125)
    
    This patch modifies the CSE pass to skip all constants, as
    [discussed here](https://discuss.tvm.apache.org/t/common-subexpr-elimination-pass-replaces-constant-args-with-vars/15971).
---
 src/relax/transform/eliminate_common_subexpr.cc | 2 +-
 tests/python/relax/test_transform_cse.py        | 8 +++++---
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 2addb60697..095274b0f8 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -95,7 +95,7 @@ class SubexprCounter : public ExprVisitor {
           e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
           e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
           e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
-          (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
+          e->IsInstance<ConstantNode>())) {
       // also if e has an impure subexpression, we will not deduplicate it
       if (!impurity_detector_.Detect(e)) {
         int count = 0;
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index d69ec61b5c..92cf4349d4 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -78,9 +78,11 @@ def test_constants():
         def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2), 
dtype="int32")):
             with R.dataflow():
                 lv0 = R.add(R.const(1, dtype="int32"), R.const(1, 
dtype="int32"))
-                lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
-                lv2 = R.add(lv1, lv1)
-                gv = (lv0, lv2)
+                lv1 = R.add(
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                    R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+                )
+                gv = (lv0, lv1)
                 R.output(gv)
             return gv
 

Reply via email to