This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c9de001490 [Unity] [Transform] Skip constants in CSE pass (#16125)
c9de001490 is described below
commit c9de0014905c69b67d7e36488af12466d7d9a940
Author: Anirudh Sundar Subramaniam <[email protected]>
AuthorDate: Wed Nov 15 01:38:03 2023 +0530
[Unity] [Transform] Skip constants in CSE pass (#16125)
This patch modifies the CSE pass to skip all constants (previously only
scalar constants were skipped), as
[discussed here](https://discuss.tvm.apache.org/t/common-subexpr-elimination-pass-replaces-constant-args-with-vars/15971).
---
src/relax/transform/eliminate_common_subexpr.cc | 2 +-
tests/python/relax/test_transform_cse.py | 8 +++++---
2 files changed, 6 insertions(+), 4 deletions(-)
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 2addb60697..095274b0f8 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -95,7 +95,7 @@ class SubexprCounter : public ExprVisitor {
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
- (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
+ e->IsInstance<ConstantNode>())) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
int count = 0;
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index d69ec61b5c..92cf4349d4 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -78,9 +78,11 @@ def test_constants():
def foo() -> R.Tuple(R.Tensor((), dtype="int32"), R.Tensor((2, 2),
dtype="int32")):
with R.dataflow():
lv0 = R.add(R.const(1, dtype="int32"), R.const(1,
dtype="int32"))
- lv1 = R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32")))
- lv2 = R.add(lv1, lv1)
- gv = (lv0, lv2)
+ lv1 = R.add(
+ R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+ R.const(tvm.nd.array(np.zeros((2, 2), dtype="int32"))),
+ )
+ gv = (lv0, lv1)
R.output(gv)
return gv