This is an automated email from the ASF dual-hosted git repository.

sslyu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0628bdba93 [Relax][Pass] Skip data type node for CSE pass (#16493)
0628bdba93 is described below

commit 0628bdba93d3520fc4f6a77afb3474392b545b0f
Author: Abhikrant Sharma <[email protected]>
AuthorDate: Thu Feb 1 03:25:41 2024 +0530

    [Relax][Pass] Skip data type node for CSE pass (#16493)
    
    * [Relax][Pass] Skip data type node for CSE pass
    - The problem is seen when an arg of relax op is dtype
    
    * Add comments to code
---
 src/relax/transform/eliminate_common_subexpr.cc |  3 ++-
 tests/python/relax/test_transform_cse.py        | 24 ++++++++++++++++++++++++
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 095274b0f8..7931d73b7b 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -91,11 +91,12 @@ class SubexprCounter : public ExprVisitor {
     // 4. StringImm nodes (not much benefit from binding to a var)
     // 5. Scalar constants (not much benefit from binding to a var)
     // 6. Shape expressions (exist to hold several PrimValue objects)
+    // 7. DataType nodes (no need to modify dtype nodes)
     if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
           e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
           e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
           e->IsInstance<ShapeExprNode>() || e->IsInstance<ExternFuncNode>() ||
-          e->IsInstance<ConstantNode>())) {
+          e->IsInstance<ConstantNode>() || e->IsInstance<DataTypeImmNode>())) {
       // also if e has an impure subexpression, we will not deduplicate it
       if (!impurity_detector_.Detect(e)) {
         int count = 0;
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 3a57afb22c..2a247c342c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -339,5 +339,29 @@ def test_call_tir_tuple_arg():
     tvm.ir.assert_structural_equal(Expected, After)
 
 
+def test_do_not_eliminate_dtype():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo() -> R.Tensor((32, 64), "int32"):
+            obj: R.Object = R.vm.alloc_storage(
+                R.shape([24576]), runtime_device_index=0, dtype="uint8"
+            )
+            a: R.Tensor([32, 64], dtype="int32") = R.vm.alloc_tensor(
+                obj, offset=0, shape=R.shape([32, 64]), dtype="int32"
+            )
+            ret_val: R.Tensor([32, 64], dtype="int32") = R.builtin.alloc_tensor(
+                R.shape([32, 64]), R.dtype("int32"), R.prim_value(0)
+            )
+            _t1: R.Tuple = R.vm.kill_object(a)
+            _t3: R.Tuple = R.vm.kill_object(obj)
+            lv: R.Tensor([32, 64], dtype="int32") = ret_val
+            return lv
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to