This is an automated email from the ASF dual-hosted git repository.

ruihangl pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 922e18c9ee [Unity][Fix] Fix block memory plan to handle bool (#14357)
922e18c9ee is described below

commit 922e18c9eea5ab3cd67df910b379d31c37c23134
Author: Prakalp Srivastava <[email protected]>
AuthorDate: Tue Mar 21 13:39:22 2023 -0400

    [Unity][Fix] Fix block memory plan to handle bool (#14357)
    
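    This PR fixes the size calculation of storage tokens in the Static Plan Block
    Memory pass. Prior to this PR, the storage for a tensor
    `R.Tensor((2, 3), dtype="bool")` would be calculated as
    `R.memory.alloc_storage(R.shape([1]), dtype="bool")` instead of
    `R.memory.alloc_storage(R.shape([6]), dtype="bool")`. The old formula,
    `(size * bits * lanes + 7) / 8`, sized storage at bit granularity, so the six
    1-bit bool elements were rounded up to a single byte; the new formula uses
    `dtype.bytes()`, which rounds each element up to a whole byte.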
---
 src/relax/transform/static_plan_block_memory.cc    |  2 +-
 .../test_transform_static_plan_block_memory.py     | 50 ++++++++++++++++++++++
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index ba5177fec0..3f0dfcf149 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -107,7 +107,7 @@ class StorageToken : public ObjectRef {
     }
 
     ObjectPtr<StorageTokenNode> n = make_object<StorageTokenNode>();
-    n->bytes = (size * dtype.bits() * dtype.lanes() + 7) / 8;
+    n->bytes = size * dtype.bytes() * dtype.lanes();
     n->dtype = dtype;
     data_ = std::move(n);
   }
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 21d18742dd..5198d9e075 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -216,6 +216,56 @@ def test_different_dtype():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_dtype_bool():
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"):
+            cls = Module
+            alloc: R.Tensor((2, 3), dtype="bool") = R.builtin.alloc_tensor(
+                R.shape([2, 3]), dtype="bool", runtime_device_index=0
+            )
+            _1: R.Tuple() = cls.add1(y, y, alloc)
+            gv1: R.Tensor((2, 3), dtype="bool") = alloc
+            return y
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def add1(
+            A: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+            B: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+            C: T.Buffer((T.int64(2), T.int64(3)), "bool"),
+        ):
+            T.evaluate(0)
+
+        @R.function
+        def main(y: R.Tensor((2, 3), dtype="bool")) -> R.Tensor((2, 3), dtype="bool"):
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(
+                R.shape([6]), virtual_device_index=0, storage_scope="global", dtype="bool"
+            )
+            alloc: R.Tensor((2, 3), dtype="bool") = R.memory.alloc_tensor(
+                storage, 0, R.shape([2, 3]), dtype="bool"
+            )
+            _2: R.Tuple() = cls.add1(y, y, alloc)
+            _3: R.Tuple() = R.memory.kill_tensor(alloc)
+            gv12: R.Tensor((2, 3), dtype="bool") = alloc
+            _4: R.Tuple() = R.memory.kill_storage(storage)
+            return y
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 def test_same_dtype():
     @tvm.script.ir_module
     class Module:

Reply via email to