This is an automated email from the ASF dual-hosted git repository.
sslyu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0628bdba93 [Relax][Pass] Skip data type node for CSE pass (#16493)
0628bdba93 is described below
commit 0628bdba93d3520fc4f6a77afb3474392b545b0f
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Thu Feb 1 03:25:41 2024 +0530
[Relax][Pass] Skip data type node for CSE pass (#16493)
* [Relax][Pass] Skip data type node for CSE pass
- The problem occurs when an argument of a Relax op is a dtype (DataTypeImm) node
* Add comments to code
---
src/relax/transform/eliminate_common_subexpr.cc | 3 ++-
tests/python/relax/test_transform_cse.py | 24 ++++++++++++++++++++++++
2 files changed, 26 insertions(+), 1 deletion(-)
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 095274b0f8..7931d73b7b 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -91,11 +91,12 @@ class SubexprCounter : public ExprVisitor {
// 4. StringImm nodes (not much benefit from binding to a var)
// 5. Scalar constants (not much benefit from binding to a var)
// 6. Shape expressions (exist to hold several PrimValue objects)
+ // 7. DataType nodes (no need to modify dtype nodes)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
- e->IsInstance<ConstantNode>())) {
+ e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
// also if e has an impure subexpression, we will not deduplicate it
if (!impurity_detector_.Detect(e)) {
int count = 0;
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 3a57afb22c..2a247c342c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -339,5 +339,29 @@ def test_call_tir_tuple_arg():
tvm.ir.assert_structural_equal(Expected, After)
+def test_do_not_eliminate_dtype():
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo() -> R.Tensor((32, 64), "int32"):
+ obj: R.Object = R.vm.alloc_storage(
+ R.shape([24576]), runtime_device_index=0, dtype="uint8"
+ )
+ a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
+ obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
+ )
+ ret_val: R.Tensor([32, 64], dtype="int32") =
R.builtin.alloc_tensor(
+ R.shape([32, 64]), R.dtype("int32"), R.prim_value(0)
+ )
+ _t1: R.Tuple = R.vm.kill_object(a)
+ _t3: R.Tuple = R.vm.kill_object(obj)
+ lv: R.Tensor([32, 64], dtype="int32") = ret_val
+ return lv
+
+ Expected = Before
+
+ verify(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()