This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new ae346ec [DTYPE] Align bool parsing to align with DLPack (#262)
ae346ec is described below
commit ae346ec92a3c386f1376064ae086aae72947c329
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 14 18:40:40 2025 -0500
[DTYPE] Align bool parsing to align with DLPack (#262)
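This PR aligns bool dtype parsing and printing with DLPack. The string "bool" now maps to type code kDLBool (6) with 8 bits and 1 lane, instead of kDLUInt with 1 bit. Multi-lane variants such as "boolx4" round-trip as well.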
---
python/tvm_ffi/_dtype.py | 1 +
src/ffi/dtype.cc | 13 ++++++-------
tests/python/test_dtype.py | 13 +++++++++++++
3 files changed, 20 insertions(+), 7 deletions(-)
diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 8e36bb4..216af55 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -33,6 +33,7 @@ class DataTypeCode(IntEnum):
FLOAT = 2
HANDLE = 3
BFLOAT = 4
+ BOOL = 6
Float8E3M4 = 7
Float8E4M3 = 8
Float8E4M3B11FNUZ = 9
diff --git a/src/ffi/dtype.cc b/src/ffi/dtype.cc
index 0b875e7..74c9eeb 100644
--- a/src/ffi/dtype.cc
+++ b/src/ffi/dtype.cc
@@ -160,7 +160,7 @@ inline void PrintDLDataTypeCodeAsStr(std::ostream& os,
DLDataTypeCode type_code)
* \return The output stream.
*/
inline std::string DLDataTypeToString_(DLDataType dtype) { // NOLINT(*)
- if (dtype.bits == 1 && dtype.lanes == 1 && dtype.code == kDLUInt) {
+ if (dtype.bits == 8 && dtype.lanes == 1 && dtype.code == kDLBool) {
return "bool";
}
// specially handle void
@@ -177,7 +177,7 @@ inline std::string DLDataTypeToString_(DLDataType dtype) {
// NOLINT(*)
}
if (dtype.code == kDLOpaqueHandle) return os.str();
int16_t lanes = static_cast<int16_t>(dtype.lanes);
- if (dtype.code < kDLFloat8_e3m4) {
+ if (dtype.code < kDLFloat8_e3m4 && (dtype.code != kDLBool || dtype.bits !=
8)) {
os << static_cast<int>(dtype.bits);
}
if (lanes > 1) {
@@ -228,6 +228,10 @@ inline DLDataType StringViewToDLDataType_(std::string_view
str) {
} else if (str.compare(0, 4, "uint") == 0) {
dtype.code = kDLUInt;
scan = str.data() + 4;
+ } else if (str.compare(0, 4, "bool") == 0) {
+ dtype.code = kDLBool;
+ dtype.bits = 8;
+ scan = str.data() + 4;
} else if (str.compare(0, 5, "float") == 0) {
if (str.compare(5, 2, "8_") == 0) {
if (str.compare(7, 4, "e3m4") == 0) {
@@ -279,11 +283,6 @@ inline DLDataType StringViewToDLDataType_(std::string_view
str) {
dtype.code = kDLOpaqueHandle;
dtype.bits = 64; // handle uses 64 bit by default.
scan = str.data() + 6;
- } else if (str == "bool") {
- dtype.code = kDLUInt;
- dtype.bits = 1;
- dtype.lanes = 1;
- return dtype;
} else if (str.compare(0, 6, "bfloat") == 0) {
dtype.code = kDLBfloat;
dtype.bits = 16;
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index cae5f84..2897faa 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -169,3 +169,16 @@ def test_dtype_from_dlpack_data_type() -> None:
assert dtype.type_code == 0
assert dtype.bits == 8
assert dtype.lanes == 1
+
+
+def test_dtype_bool() -> None:
+ dtype = tvm_ffi.dtype("bool")
+ assert dtype.type_code == 6
+ assert dtype.bits == 8
+ assert dtype.lanes == 1
+
+ dtype_with_lanes = dtype.with_lanes(4)
+ assert dtype_with_lanes.type_code == 6
+ assert dtype_with_lanes.bits == 8
+ assert dtype_with_lanes.lanes == 4
+ assert dtype_with_lanes == "boolx4"