This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch refactor-s2
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 9753b471f9b1f3c40fb9601c8d49c0303f90504b
Author: tqchen <[email protected]>
AuthorDate: Thu Apr 24 08:58:07 2025 -0400

    [FFI] bool compat
---
 ffi/include/tvm/ffi/container/ndarray.h | 10 ++++++++--
 ffi/tests/cpp/test_ndarray.cc           |  2 +-
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h
index 642825655f..66bdbb6d70 100644
--- a/ffi/include/tvm/ffi/container/ndarray.h
+++ b/ffi/include/tvm/ffi/container/ndarray.h
@@ -83,7 +83,13 @@ inline bool IsAligned(const DLTensor& arr, size_t alignment) {
  * \param dtype the data type of the array
  * \return the total number bytes needs to store packed data
  */
-inline size_t GetPackedDataSize(int64_t numel, DLDataType dtype) {
+inline size_t GetDataSize(int64_t numel, DLDataType dtype) {
+  // Compat: scalar uint1 (bool) is usually stored as one uint8_t per element, not bit-packed.
+  // TODO: revisit and switch bool handling to kDLBool.
+  if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) {
+    return numel;
+  }
+  // Everything else, incl. other sub-byte types and multi-lane uint1, is bit-packed (rounded up to bytes).
   return (numel * dtype.bits * dtype.lanes + 7) / 8;
 }
 
@@ -98,7 +104,7 @@ inline size_t GetDataSize(const DLTensor& arr) {
   for (int i = 0; i < arr.ndim; ++i) {
     size *= static_cast<size_t>(arr.shape[i]);
   }
-  return GetPackedDataSize(size, arr.dtype);
+  return GetDataSize(size, arr.dtype);
 }
 
 /*! \brief An object representing an NDArray. */
diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc
index 0284ceb818..811227f073 100644
--- a/ffi/tests/cpp/test_ndarray.cc
+++ b/ffi/tests/cpp/test_ndarray.cc
@@ -83,7 +83,7 @@ TEST(NDArray, DLPack) {
 
 TEST(NDArray, DLPackVersioned) {
   DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1});
-  EXPECT_EQ(GetPackedDataSize(2, dtype), 2 * 4 / 8);
+  EXPECT_EQ(GetDataSize(2, dtype), 2 * 4 / 8);
   NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0}));
   DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned();
   EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION);

Reply via email to