This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new ae346ec  [DTYPE] Align bool parsing to align with DLPack (#262)
ae346ec is described below

commit ae346ec92a3c386f1376064ae086aae72947c329
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 14 18:40:40 2025 -0500

    [DTYPE] Align bool parsing to align with DLPack (#262)
    
    This PR aligns bool dtype parsing and printing with DLPack: "bool" now maps to type code kDLBool (6) with 8 bits instead of uint1.
---
 python/tvm_ffi/_dtype.py   |  1 +
 src/ffi/dtype.cc           | 13 ++++++-------
 tests/python/test_dtype.py | 13 +++++++++++++
 3 files changed, 20 insertions(+), 7 deletions(-)

diff --git a/python/tvm_ffi/_dtype.py b/python/tvm_ffi/_dtype.py
index 8e36bb4..216af55 100644
--- a/python/tvm_ffi/_dtype.py
+++ b/python/tvm_ffi/_dtype.py
@@ -33,6 +33,7 @@ class DataTypeCode(IntEnum):
     FLOAT = 2
     HANDLE = 3
     BFLOAT = 4
+    BOOL = 6
     Float8E3M4 = 7
     Float8E4M3 = 8
     Float8E4M3B11FNUZ = 9
diff --git a/src/ffi/dtype.cc b/src/ffi/dtype.cc
index 0b875e7..74c9eeb 100644
--- a/src/ffi/dtype.cc
+++ b/src/ffi/dtype.cc
@@ -160,7 +160,7 @@ inline void PrintDLDataTypeCodeAsStr(std::ostream& os, DLDataTypeCode type_code)
  *  \return The output stream.
  */
 inline std::string DLDataTypeToString_(DLDataType dtype) {  // NOLINT(*)
-  if (dtype.bits == 1 && dtype.lanes == 1 && dtype.code == kDLUInt) {
+  if (dtype.bits == 8 && dtype.lanes == 1 && dtype.code == kDLBool) {
     return "bool";
   }
   // specially handle void
@@ -177,7 +177,7 @@ inline std::string DLDataTypeToString_(DLDataType dtype) {  // NOLINT(*)
   }
   if (dtype.code == kDLOpaqueHandle) return os.str();
   int16_t lanes = static_cast<int16_t>(dtype.lanes);
-  if (dtype.code < kDLFloat8_e3m4) {
+  if (dtype.code < kDLFloat8_e3m4 && (dtype.code != kDLBool || dtype.bits != 8)) {
     os << static_cast<int>(dtype.bits);
   }
   if (lanes > 1) {
@@ -228,6 +228,10 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
   } else if (str.compare(0, 4, "uint") == 0) {
     dtype.code = kDLUInt;
     scan = str.data() + 4;
+  } else if (str.compare(0, 4, "bool") == 0) {
+    dtype.code = kDLBool;
+    dtype.bits = 8;
+    scan = str.data() + 4;
   } else if (str.compare(0, 5, "float") == 0) {
     if (str.compare(5, 2, "8_") == 0) {
       if (str.compare(7, 4, "e3m4") == 0) {
@@ -279,11 +283,6 @@ inline DLDataType StringViewToDLDataType_(std::string_view str) {
     dtype.code = kDLOpaqueHandle;
     dtype.bits = 64;  // handle uses 64 bit by default.
     scan = str.data() + 6;
-  } else if (str == "bool") {
-    dtype.code = kDLUInt;
-    dtype.bits = 1;
-    dtype.lanes = 1;
-    return dtype;
   } else if (str.compare(0, 6, "bfloat") == 0) {
     dtype.code = kDLBfloat;
     dtype.bits = 16;
diff --git a/tests/python/test_dtype.py b/tests/python/test_dtype.py
index cae5f84..2897faa 100644
--- a/tests/python/test_dtype.py
+++ b/tests/python/test_dtype.py
@@ -169,3 +169,16 @@ def test_dtype_from_dlpack_data_type() -> None:
     assert dtype.type_code == 0
     assert dtype.bits == 8
     assert dtype.lanes == 1
+
+
+def test_dtype_bool() -> None:
+    dtype = tvm_ffi.dtype("bool")
+    assert dtype.type_code == 6
+    assert dtype.bits == 8
+    assert dtype.lanes == 1
+
+    dtype_with_lanes = dtype.with_lanes(4)
+    assert dtype_with_lanes.type_code == 6
+    assert dtype_with_lanes.bits == 8
+    assert dtype_with_lanes.lanes == 4
+    assert dtype_with_lanes == "boolx4"

Reply via email to