This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a156181ee3 [Relax] Fix EliminiateCommonSubexpr removing alloc tensor (#16852)
a156181ee3 is described below

commit a156181ee3242407aa3c0e1565c18896b9d2f06b
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Apr 6 05:45:26 2024 -0700

    [Relax] Fix EliminiateCommonSubexpr removing alloc tensor (#16852)
---
 src/relax/op/op.cc                              | 15 ++++++++----
 src/relax/transform/eliminate_common_subexpr.cc | 15 ++++++++++++
 tests/python/relax/test_transform_cse.py        | 32 +++++++++++++++++++++++++
 3 files changed, 57 insertions(+), 5 deletions(-)

diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 7eb499f102..77cf4a2c6f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -851,7 +851,8 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
                   "The storage scope of the storage to allocate. Default is global.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAllocateTensor)
     // memory allocation isn't considered a "visible effect" as far as purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<Bool>("TAllocator", Bool(true));

 Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue runtime_device_index,
                      StringImm storage_scope) {
@@ -875,7 +876,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage")
     .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
     // memory allocation isn't considered a "visible effect" as far as purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<Bool>("TAllocator", Bool(true));

 Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm storage_scope,
                       DataTypeImm dtype) {
@@ -906,7 +908,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor")
     .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMemAllocTensor)
     // memory allocation isn't considered a "visible effect" as far as purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<Bool>("TAllocator", Bool(true));

 Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
   static const Op& op = Op::Get("relax.memory.alloc_tensor");
@@ -960,7 +963,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_storage")
                   "The storage scope of the storage to allocate. Default is global.")
     .set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
     // memory allocation isn't considered a "visible effect" as far as purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<Bool>("TAllocator", Bool(true));

 Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm dtype,
                         StringImm storage_scope) {
@@ -998,7 +1002,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_tensor")
     .add_argument("dtype", "DataTypeImm", "The dtype of the tensor to allocate.")
     .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoVMAllocTensor)
     // memory allocation isn't considered a "visible effect" as far as purity is concerned
-    .set_attr<Bool>("FPurity", Bool(true));
+    .set_attr<Bool>("FPurity", Bool(true))
+    .set_attr<Bool>("TAllocator", Bool(true));

 Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm dtype) {
   static const Op& op = Op::Get("relax.vm.alloc_tensor");
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 5804b1c5bb..2b61174bcb 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -126,6 +126,8 @@ class CommonSubexprEliminator : public ExprMutator {
     } else if (ContainsImpureCall(bound_value)) {
       VLOG(1) << "Since the expression is impure, cannot de-duplicate " << bound_value;

+    } else if (IsAllocatorCall(bound_value)) {
+      VLOG(1) << "Skip allocator calls";
     } else if (auto it = expr_replacements_.find(lookup_key);
                it != expr_replacements_.end() && it->second.size()) {
       VLOG(1) << "Value " << bound_value << " has previously been bound as " << it->second[0]
@@ -186,6 +188,19 @@ class CommonSubexprEliminator : public ExprMutator {
     return clean_mutator.VisitExpr(expr);
   }

+  bool IsAllocatorCall(const Expr& expr) {
+    static const auto& allocator_attr_map = Op::GetAttrMap<Bool>("TAllocator");
+    if (const auto* call = expr.as<CallNode>()) {
+      if (const auto* op = call->op.as<OpNode>()) {
+        bool is_allocator = allocator_attr_map.get(GetRef<Op>(op), Bool(false))->value;
+        if (is_allocator) {
+          return true;
+        }
+      }
+    }
+    return false;
+  }
+
   bool call_only_{false};
   std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
 };
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 0998fb67c0..bb10704acb 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -627,5 +627,37 @@ def test_keep_duplicate_after_branch():
     verify(Before, Expected)


+def test_keep_alloc_tensor():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32")):
+            tmp_buf1 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
+            tmp_buf2 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"), R.prim_value(0))
+            out = R.add(tmp_buf1, tmp_buf2)
+            return out
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
+def test_keep_alloc_storage():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32")):
+            tmp_storage1 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
+            tmp_buf1 = R.vm.alloc_tensor(tmp_storage1, offset=0, shape=R.shape([64]), dtype="int32")
+            tmp_storage2 = R.vm.alloc_storage(R.shape([64]), runtime_device_index=0, dtype="uint8")
+            tmp_buf2 = R.vm.alloc_tensor(tmp_storage2, offset=0, shape=R.shape([64]), dtype="int32")
+            out = R.add(tmp_buf1, tmp_buf2)
+            return out
+
+    Expected = Before
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to