This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a156181ee3 [Relax] Fix EliminiateCommonSubexpr removing alloc tensor
(#16852)
a156181ee3 is described below
commit a156181ee3242407aa3c0e1565c18896b9d2f06b
Author: Wuwei Lin <[email protected]>
AuthorDate: Sat Apr 6 05:45:26 2024 -0700
[Relax] Fix EliminiateCommonSubexpr removing alloc tensor (#16852)
---
src/relax/op/op.cc | 15 ++++++++----
src/relax/transform/eliminate_common_subexpr.cc | 15 ++++++++++++
tests/python/relax/test_transform_cse.py | 32 +++++++++++++++++++++++++
3 files changed, 57 insertions(+), 5 deletions(-)
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index 7eb499f102..77cf4a2c6f 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -851,7 +851,8 @@ RELAY_REGISTER_OP("relax.builtin.alloc_tensor")
"The storage scope of the storage to allocate. Default is
global.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoAllocateTensor)
// memory allocation isn't considered a "visible effect" as far as purity
is concerned
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("TAllocator", Bool(true));
Expr MakeAllocTensor(Expr shape, DataTypeImm dtype, PrimValue
runtime_device_index,
StringImm storage_scope) {
@@ -875,7 +876,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_storage")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to
allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
// memory allocation isn't considered a "visible effect" as far as purity
is concerned
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("TAllocator", Bool(true));
Expr MakeAllocStorage(Expr size, PrimValue virtual_device_index, StringImm
storage_scope,
DataTypeImm dtype) {
@@ -906,7 +908,8 @@ RELAY_REGISTER_OP("relax.memory.alloc_tensor")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to
allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoMemAllocTensor)
// memory allocation isn't considered a "visible effect" as far as purity
is concerned
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("TAllocator", Bool(true));
Expr MakeMemAllocTensor(Expr storage, PrimValue offset, Expr shape,
DataTypeImm dtype) {
static const Op& op = Op::Get("relax.memory.alloc_tensor");
@@ -960,7 +963,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_storage")
"The storage scope of the storage to allocate. Default is
global.")
.set_attr<FInferStructInfo>("FInferStructInfo", ReturnObjectStructInfo)
// memory allocation isn't considered a "visible effect" as far as purity
is concerned
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("TAllocator", Bool(true));
Expr MakeVMAllocStorage(Expr size, PrimValue runtime_device_index, DataTypeImm
dtype,
StringImm storage_scope) {
@@ -998,7 +1002,8 @@ RELAY_REGISTER_OP("relax.vm.alloc_tensor")
.add_argument("dtype", "DataTypeImm", "The dtype of the tensor to
allocate.")
.set_attr<FInferStructInfo>("FInferStructInfo",
InferStructInfoVMAllocTensor)
// memory allocation isn't considered a "visible effect" as far as purity
is concerned
- .set_attr<Bool>("FPurity", Bool(true));
+ .set_attr<Bool>("FPurity", Bool(true))
+ .set_attr<Bool>("TAllocator", Bool(true));
Expr MakeVMAllocTensor(Expr storage, PrimValue offset, Expr shape, DataTypeImm
dtype) {
static const Op& op = Op::Get("relax.vm.alloc_tensor");
diff --git a/src/relax/transform/eliminate_common_subexpr.cc
b/src/relax/transform/eliminate_common_subexpr.cc
index 5804b1c5bb..2b61174bcb 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -126,6 +126,8 @@ class CommonSubexprEliminator : public ExprMutator {
} else if (ContainsImpureCall(bound_value)) {
VLOG(1) << "Since the expression is impure, cannot de-duplicate " <<
bound_value;
+ } else if (IsAllocatorCall(bound_value)) {
+ VLOG(1) << "Skip allocator calls";
} else if (auto it = expr_replacements_.find(lookup_key);
it != expr_replacements_.end() && it->second.size()) {
VLOG(1) << "Value " << bound_value << " has previously been bound as "
<< it->second[0]
@@ -186,6 +188,19 @@ class CommonSubexprEliminator : public ExprMutator {
return clean_mutator.VisitExpr(expr);
}
+ bool IsAllocatorCall(const Expr& expr) {  // True iff expr is a Call whose callee is an Op with the "TAllocator" attribute set to true.
+ static const auto& allocator_attr_map = Op::GetAttrMap<Bool>("TAllocator");  // Function-local static: the attribute map is fetched once and reused.
+ if (const auto* call = expr.as<CallNode>()) {  // Anything that is not a CallNode falls through to `return false`.
+ if (const auto* op = call->op.as<OpNode>()) {  // Only callees that are OpNode are looked up in the attribute map.
+ bool is_allocator = allocator_attr_map.get(GetRef<Op>(op),
Bool(false))->value;  // Ops with no "TAllocator" entry default to Bool(false).
+ if (is_allocator) {
+ return true;  // The binding visitor then takes its "Skip allocator calls" branch instead of the replacement-lookup branch.
+ }
+ }
+ }
+ return false;
+ }
+
bool call_only_{false};
std::unordered_map<ReplacementKey, std::vector<Var>> expr_replacements_;
};
diff --git a/tests/python/relax/test_transform_cse.py
b/tests/python/relax/test_transform_cse.py
index 0998fb67c0..bb10704acb 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -627,5 +627,37 @@ def test_keep_duplicate_after_branch():
verify(Before, Expected)
+def test_keep_alloc_tensor():  # Two identical R.builtin.alloc_tensor bindings must both be kept.
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32")):  # `x` is not used in the body below.
+ tmp_buf1 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"),
R.prim_value(0))
+ tmp_buf2 = R.builtin.alloc_tensor(R.shape([64]), R.dtype("int32"),
R.prim_value(0))  # Same arguments as tmp_buf1, bound to a different variable.
+ out = R.add(tmp_buf1, tmp_buf2)  # Both buffers are used, so both bindings are referenced.
+ return out
+
+ Expected = Before  # The expected module is the input module itself.
+
+ verify(Before, Expected)  # `verify` is defined elsewhere in this test file; presumably runs the CSE pass and compares -- confirm.
+
+
+def test_keep_alloc_storage():  # Identical R.vm.alloc_storage / R.vm.alloc_tensor pairs must both be kept.
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32")):  # `x` is not used in the body below.
+ tmp_storage1 = R.vm.alloc_storage(R.shape([64]),
runtime_device_index=0, dtype="uint8")
+ tmp_buf1 = R.vm.alloc_tensor(tmp_storage1, offset=0,
shape=R.shape([64]), dtype="int32")  # Tensor created from tmp_storage1.
+ tmp_storage2 = R.vm.alloc_storage(R.shape([64]),
runtime_device_index=0, dtype="uint8")  # Same arguments as tmp_storage1, bound to a different variable.
+ tmp_buf2 = R.vm.alloc_tensor(tmp_storage2, offset=0,
shape=R.shape([64]), dtype="int32")  # Same arguments as tmp_buf1 except for the storage variable.
+ out = R.add(tmp_buf1, tmp_buf2)  # Both tensors are used, so both allocation chains are referenced.
+ return out
+
+ Expected = Before  # The expected module is the input module itself.
+
+ verify(Before, Expected)  # `verify` is defined elsewhere in this test file; presumably runs the CSE pass and compares -- confirm.
+
+
if __name__ == "__main__":
tvm.testing.main()