This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 9753b471f9b1f3c40fb9601c8d49c0303f90504b Author: tqchen <[email protected]> AuthorDate: Thu Apr 24 08:58:07 2025 -0400 [FFI] bool compact --- ffi/include/tvm/ffi/container/ndarray.h | 10 ++++++++-- ffi/tests/cpp/test_ndarray.cc | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/ffi/include/tvm/ffi/container/ndarray.h b/ffi/include/tvm/ffi/container/ndarray.h index 642825655f..66bdbb6d70 100644 --- a/ffi/include/tvm/ffi/container/ndarray.h +++ b/ffi/include/tvm/ffi/container/ndarray.h @@ -83,7 +83,13 @@ inline bool IsAligned(const DLTensor& arr, size_t alignment) { * \param dtype the data type of the array * \return the total number bytes needs to store packed data */ -inline size_t GetPackedDataSize(int64_t numel, DLDataType dtype) { +inline size_t GetDataSize(int64_t numel, DLDataType dtype) { + // compatible handling sub-byte uint1(bool), which usually stored as uint8_t + // TODO: revisit and switch to kDLBool + if (dtype.code == kDLUInt && dtype.bits == 1 && dtype.lanes == 1) { + return numel; + } + // for other sub-byte types, packing is preferred return (numel * dtype.bits * dtype.lanes + 7) / 8; } @@ -98,7 +104,7 @@ inline size_t GetDataSize(const DLTensor& arr) { for (int i = 0; i < arr.ndim; ++i) { size *= static_cast<size_t>(arr.shape[i]); } - return GetPackedDataSize(size, arr.dtype); + return GetDataSize(size, arr.dtype); } /*! \brief An object representing an NDArray. */ diff --git a/ffi/tests/cpp/test_ndarray.cc b/ffi/tests/cpp/test_ndarray.cc index 0284ceb818..811227f073 100644 --- a/ffi/tests/cpp/test_ndarray.cc +++ b/ffi/tests/cpp/test_ndarray.cc @@ -83,7 +83,7 @@ TEST(NDArray, DLPack) { TEST(NDArray, DLPackVersioned) { DLDataType dtype = DLDataType({kDLFloat4_e2m1fn, 4, 1}); - EXPECT_EQ(GetPackedDataSize(2, dtype), 2 * 4 / 8); + EXPECT_EQ(GetDataSize(2, dtype), 2 * 4 / 8); NDArray nd = Empty({2}, dtype, DLDevice({kDLCPU, 0})); DLManagedTensorVersioned* dlpack = nd.ToDLPackVersioned(); EXPECT_EQ(dlpack->version.major, DLPACK_MAJOR_VERSION);
