This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b6ac0721a0 [DataType] Update to use explicit Bool Type Aligning with 
DLPack (#18453)
b6ac0721a0 is described below

commit b6ac0721a0a393e30a11d30d86b8caa65c59a263
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Nov 14 20:47:42 2025 -0500

    [DataType] Update to use explicit Bool Type Aligning with DLPack (#18453)
    
    This PR updates the project to use an explicit bool type, which helps us
    align with DLPack. It also streamlines the explicit use of bool types.
---
 3rdparty/tvm-ffi                                   |  2 +-
 include/tvm/runtime/data_type.h                    | 11 ++--
 include/tvm/tir/op.h                               |  6 +--
 python/tvm/script/parser/tir/operation.py          |  2 +
 python/tvm/tir/ir_builder.py                       |  2 +-
 src/arith/const_fold.h                             | 26 ++++-----
 src/arith/const_int_bound.cc                       |  5 +-
 src/ir/expr.cc                                     |  7 +--
 src/relax/transform/utils.h                        |  2 +-
 src/runtime/vm/builtin.cc                          |  2 +-
 src/target/llvm/codegen_llvm.cc                    |  7 ++-
 src/target/llvm/codegen_llvm.h                     |  1 +
 src/target/source/codegen_opencl.cc                |  6 +++
 src/target/source/codegen_source_base.cc           |  5 ++
 src/target/spirv/codegen_spirv.cc                  |  4 +-
 src/target/spirv/ir_builder.cc                     | 61 +++++++++++-----------
 src/tir/ir/expr.cc                                 |  2 +-
 src/tir/ir/stmt.cc                                 |  5 +-
 src/tir/op/op.cc                                   | 55 ++++++++++++-------
 src/tir/transforms/arg_binder.cc                   |  2 +-
 src/tir/transforms/inject_ptx_ldg32.cc             |  2 +-
 src/tir/transforms/lower_tvm_builtin.cc            |  4 +-
 tests/cpp/tir_scalable_datatype.cc                 |  4 +-
 tests/python/arith/test_arith_rewrite_simplify.py  | 22 ++++----
 tests/python/relax/test_op_nn.py                   |  2 -
 tests/python/tir-base/test_tir_constructor.py      | 12 ++---
 tests/python/tir-base/test_tir_nodes.py            |  2 +-
 tests/python/tir-base/test_tir_ops.py              | 14 ++---
 .../tvmscript/test_tvmscript_ir_builder_tir.py     |  2 +-
 .../python/tvmscript/test_tvmscript_printer_tir.py |  4 +-
 30 files changed, 159 insertions(+), 122 deletions(-)

diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index f703a0cf93..ae346ec92a 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit f703a0cf9358fa30d8faee719f905c58d8ca6ee3
+Subproject commit ae346ec92a3c386f1376064ae086aae72947c329
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 0af3022bbd..0c698334ac 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -60,6 +60,7 @@ class DataType {
     kFloat = kDLFloat,
     kHandle = kDLOpaqueHandle,
     kBFloat = kDLBfloat,
+    kBool = kDLBool,
     kFloat8_e3m4 = kDLFloat8_e3m4,
     kFloat8_e4m3 = kDLFloat8_e4m3,
     kFloat8_e4m3b11fnuz = kDLFloat8_e4m3b11fnuz,
@@ -137,8 +138,10 @@ class DataType {
   }
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
-  /*! \return whether type is a scalar type. */
-  bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
+  /*! \return whether type is a bool type. */
+  bool is_bool() const { return code() == DataType::kBool; }
+  /*! \return whether type can be used in a predicate expression. */
+  bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() 
== 1); }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a bfloat type. */
@@ -204,7 +207,7 @@ class DataType {
   /*! \return whether type is a vector type. */
   bool is_vector() const { return lanes() > 1; }
   /*! \return whether type is a bool vector type. */
-  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && 
bits() == 1; }
+  bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && 
is_bool(); }
   /*! \return whether type is a Void type. */
   bool is_void() const {
     return code() == DataType::kHandle && bits() == 0 && 
static_cast<int16_t>(data_.lanes) == 0;
@@ -381,7 +384,7 @@ class DataType {
    * \return The constructed data type.
    */
   static DataType Bool(int lanes = 1, bool is_scalable = false) {
-    return DataType::UInt(1, lanes, is_scalable);
+    return DataType(kDLBool, 8, lanes, is_scalable);
   }
   /*!
    * \brief Construct a handle type.
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index 6a0f427b80..57f8681514 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -816,7 +816,7 @@ inline PrimExpr make_zero(DataType t, Span span = Span());
  * \return The result expression.
  */
 inline PrimExpr const_true(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 1);
+  return make_const(DataType::Bool(lanes), 1);
 }
 /*!
  * \brief Make a constant false expression.
@@ -825,7 +825,7 @@ inline PrimExpr const_true(int lanes = 1, Span span = 
Span()) {
  * \return The result expression.
  */
 inline PrimExpr const_false(int lanes = 1, Span span = Span()) {
-  return make_const(DataType::UInt(1, lanes), 0);
+  return make_const(DataType::Bool(lanes), 0);
 }
 /*!
  * \brief Get x as constant int expression.
@@ -957,7 +957,7 @@ inline bool is_no_op(const tir::Stmt& stmt) {
 
 template <typename ValueType>
 inline PrimExpr MakeConstScalar(DataType t, ValueType value, Span span = 
Span()) {
-  if (t.is_int()) return IntImm(t, static_cast<int64_t>(value), span);
+  if (t.is_int() || t.is_bool()) return IntImm(t, static_cast<int64_t>(value), 
span);
   if (t.is_uint()) {
     // Use IntImm if it is a small integer
     uint64_t uval = static_cast<uint64_t>(value);
diff --git a/python/tvm/script/parser/tir/operation.py 
b/python/tvm/script/parser/tir/operation.py
index 22f996a456..b22b0a7335 100644
--- a/python/tvm/script/parser/tir/operation.py
+++ b/python/tvm/script/parser/tir/operation.py
@@ -61,6 +61,7 @@ def _register_expr_op(ty: Type):  # pylint: 
disable=invalid-name
                 if (
                     DataType(b.dtype).type_code == DataTypeCode.INT
                     or DataType(b.dtype).type_code == DataTypeCode.UINT
+                    or DataType(b.dtype).type_code == DataTypeCode.BOOL
                 ):
                     a = IntImm(_get_type_str(b.dtype), a)
                 elif DataType(b.dtype).type_code == DataTypeCode.FLOAT:
@@ -80,6 +81,7 @@ def _register_expr_op(ty: Type):  # pylint: 
disable=invalid-name
             if (
                 DataType(a.dtype).type_code == DataTypeCode.INT
                 or DataType(a.dtype).type_code == DataTypeCode.UINT
+                or DataType(a.dtype).type_code == DataTypeCode.BOOL
             ):
                 b = IntImm(_get_type_str(a.dtype), b)
             elif DataType(a.dtype).type_code == DataTypeCode.FLOAT:
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0922..a6313ae3bc 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@ class IRBuilder(object):
         )
 
         buffer_var = buffer.data
-        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, 
dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, 
dtype="bool"), x))
         return BufferVar(self, buffer, dtype)
 
     def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index dda7f67465..5118204db6 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -349,8 +349,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
   });
   return std::nullopt;
 }
@@ -358,8 +358,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
   });
   return std::nullopt;
 }
@@ -367,8 +367,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
   });
   return std::nullopt;
 }
@@ -376,8 +376,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
   });
   return std::nullopt;
 }
@@ -385,8 +385,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
   });
   return std::nullopt;
 }
@@ -394,8 +394,8 @@ inline ffi::Optional<PrimExpr> 
TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
 template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
   TVM_ARITH_CONST_PROPAGATION({
-    if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
-    if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+    if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
+    if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
   });
   return std::nullopt;
 }
@@ -426,7 +426,7 @@ template <>
 inline ffi::Optional<PrimExpr> TryConstFold<tir::Not>(PrimExpr a) {
   const IntImmNode* pa = a.as<IntImmNode>();
   if (pa) {
-    return IntImm(DataType::UInt(1), !(pa->value));
+    return IntImm(DataType::Bool(), !(pa->value));
   }
   return std::nullopt;
 }
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 7e1d8fb3fb..d8296bafd9 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -735,9 +735,12 @@ class ConstIntBoundAnalyzer::Impl
    * \return Bound that represent everything dtype can represent.
    */
   static Entry Everything(DataType dtype) {
-    if (!dtype.is_int() && !dtype.is_uint()) {
+    if (!dtype.is_int() && !dtype.is_uint() && !dtype.is_bool()) {
       return MakeBound(kNegInf, kPosInf);
     }
+    if (dtype.is_bool()) {
+      return MakeBound(0, 1);
+    }
     Entry ret;
     int64_t vbits = dtype.bits() - static_cast<int>(dtype.is_int());
     if (dtype.is_uint()) {
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index 6c0065c29c..b856854a5d 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -53,8 +53,9 @@ PrimExpr PrimExpr::ConvertFallbackValue(ffi::String value) { 
return tir::StringI
 IntImm::IntImm(DataType dtype, int64_t value, Span span) {
   ICHECK(dtype.is_scalar()) << "ValueError: IntImm can only take scalar, but " 
<< dtype
                             << " was supplied.";
-  ICHECK(dtype.is_int() || dtype.is_uint())
-      << "ValueError: IntImm supports only int or uint type, but " << dtype << 
" was supplied.";
+  ICHECK(dtype.is_int() || dtype.is_uint() || dtype.is_bool())
+      << "ValueError: IntImm supports only int or uint or bool type, but " << 
dtype
+      << " was supplied.";
   if (dtype.is_uint()) {
     ICHECK_GE(value, 0U) << "ValueError: Literal value " << value
                          << " is negative for unsigned integer type " << dtype;
@@ -62,7 +63,7 @@ IntImm::IntImm(DataType dtype, int64_t value, Span span) {
       ICHECK_LT(value, 1LL << dtype.bits())
           << "ValueError: Literal value " << value << " exceeds maximum of " 
<< dtype;
     }
-  } else if (dtype.bits() == 1) {
+  } else if (dtype.bits() == 1 || dtype.is_bool()) {
     // int(1)
     ICHECK(value == 0 || value == 1) << "ValueError: " << value << " exceeds 
range of " << dtype;
   } else if (dtype.bits() < 64) {
diff --git a/src/relax/transform/utils.h b/src/relax/transform/utils.h
index ff8596cd79..5bcb5f2199 100644
--- a/src/relax/transform/utils.h
+++ b/src/relax/transform/utils.h
@@ -328,7 +328,7 @@ inline Constant MakeConstantScalar(T value, DataType dtype) 
{
     *static_cast<int32_t*>(arr->data) = static_cast<int32_t>(value);
   } else if (dtype == DataType::Int(64)) {
     *static_cast<int64_t*>(arr->data) = static_cast<int64_t>(value);
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     *static_cast<bool*>(arr->data) = static_cast<bool>(value);
   } else if (dtype == DataType::UInt(8)) {
     *static_cast<uint8_t*>(arr->data) = static_cast<uint8_t>(value);
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a158f..1bd3084c21 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@ bool ReadIfCond(ffi::AnyView cond) {
   if (arr->device.device_type != kDLCPU) {
     arr = arr.CopyTo(DLDevice{kDLCPU, 0});
   }
-  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || 
arr->dtype.code == kDLBool);
   int64_t result;
   switch (arr->dtype.bits) {
     case 1: {
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bdb0c6b738..5f8b599a3b 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -148,6 +148,7 @@ void CodeGenLLVM::Init(const std::string& module_name, 
LLVMTarget* llvm_target,
   // types
   t_void_ = llvm::Type::getVoidTy(*ctx);
   t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx), 
GetGlobalAddressSpace());
+  t_int1_ = llvm::Type::getInt1Ty(*ctx);
   t_int_ = llvm::Type::getInt32Ty(*ctx);
   t_char_ = llvm::Type::getInt8Ty(*ctx);
   t_int8_ = llvm::Type::getInt8Ty(*ctx);
@@ -576,6 +577,8 @@ llvm::Type* CodeGenLLVM::DTypeToLLVMType(const DataType& 
dtype) const {
   llvm::LLVMContext* ctx = llvm_target_->GetContext();
   if (dtype.is_int() || dtype.is_uint()) {
     etype = llvm::Type::getIntNTy(*ctx, dtype.bits());
+  } else if (dtype.is_bool()) {
+    etype = t_int1_;
   } else if (dtype.is_float()) {
     switch (dtype.bits()) {
       case 16:
@@ -922,7 +925,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, 
DataType to, llvm::Value* va
 
   if (to.is_handle()) {
     return builder_->CreateBitCast(value, target);
-  } else if (to.is_uint() && to.bits() == 1) {
+  } else if (to.is_bool()) {
     if (from.is_float()) {
       llvm::Constant* zero = llvm::ConstantFP::get(DTypeToLLVMType(from), 0.);
       return builder_->CreateFCmpONE(value, zero);
@@ -943,7 +946,7 @@ llvm::Value* CodeGenLLVM::CreateCast(DataType from, 
DataType to, llvm::Value* va
     }
   } else if (from.is_int() && to.is_float()) {
     return builder_->CreateSIToFP(value, target);
-  } else if (from.is_uint() && to.is_float()) {
+  } else if ((from.is_uint() || from.is_bool()) && to.is_float()) {
     return builder_->CreateUIToFP(value, target);
   } else {
     ICHECK(from.is_float() && to.is_float());
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 5cf053cf71..efec7ad6ad 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -536,6 +536,7 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const 
PrimExpr&)>,
   llvm::Type* t_void_{nullptr};
   llvm::PointerType* t_void_p_{nullptr};
   llvm::Type* t_int_{nullptr};
+  llvm::Type* t_int1_{nullptr};
   llvm::Type* t_char_{nullptr};
   llvm::Type* t_int8_{nullptr};
   llvm::Type* t_int16_{nullptr};
diff --git a/src/target/source/codegen_opencl.cc 
b/src/target/source/codegen_opencl.cc
index 769401c4bc..8ea55b8ff5 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -230,6 +230,12 @@ void CodeGenOpenCL::PrintType(DataType t, std::ostream& 
os) {  // NOLINT(*)
       os << lanes;
       return;
     }
+  } else if (t.is_bool()) {
+    os << "uint";
+    if (!fail && ((lanes >= 2 && lanes <= 4) || lanes == 8 || lanes == 16)) {
+      os << lanes;
+      return;
+    }
   } else if (t.is_uint() || t.is_int()) {
     if (t.is_uint()) {
       os << 'u';
diff --git a/src/target/source/codegen_source_base.cc 
b/src/target/source/codegen_source_base.cc
index 60fa786d52..917036b8e2 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -109,6 +109,11 @@ void CodeGenSourceBase::PrintType(DataType type, 
std::ostream& os) {  // NOLINT(
     os << "void";
     return;
   }
+  // print bool as int by default (plain C may lack a bool type); subclasses can handle bool themselves
+  if (type.is_bool()) {
+    os << "int";
+    return;
+  }
   if (type.is_float()) {
     if (type.bits() == 32) {
       os << "float";
diff --git a/src/target/spirv/codegen_spirv.cc 
b/src/target/spirv/codegen_spirv.cc
index ddbc22d88a..c062926cc2 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -430,7 +430,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
     spirv::Value dst_ptr =
         builder_->StructArrayAccess(dst_ptr_type, var_map_[buffer_node], 
MakeValue(dst_index));
     spirv::Value src_ptr = VisitExpr(op->args[5]);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     spirv::Value loaded =
@@ -492,7 +492,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
         builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], 
MakeValue(index));
     uint32_t mask = spv::MemoryAccessMaskNone;
     spirv::Value loaded = builder_->MakeValue(spv::OpLoad, fragment_type, ptr, 
mask);
-    spirv::SType type_bool = builder_->GetSType(DataType::UInt(1));
+    spirv::SType type_bool = builder_->GetSType(DataType::Bool());
     spirv::Value t_val = builder_->UIntImm(type_bool, 1);
     spirv::Value f_val = builder_->UIntImm(type_bool, 0);
     builder_->MakeInst(spv::OpCooperativeMatrixStoreNV, dst_ptr, loaded, 
stride_val,
diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc
index 545e677af9..bac66a3aac 100644
--- a/src/target/spirv/ir_builder.cc
+++ b/src/target/spirv/ir_builder.cc
@@ -76,7 +76,7 @@ void IRBuilder::InitPreDefs() {
   ext_glsl450_ = ExtInstImport("GLSL.std.450");
   t_int32_ = DeclareType(DataType::Int(32));
   t_uint32_ = DeclareType(DataType::UInt(32));
-  t_bool_ = DeclareType(DataType::UInt(1));
+  t_bool_ = DeclareType(DataType::Bool());
   t_fp32_ = DeclareType(DataType::Float(32));
   const_i32_zero_ = IntImm(t_int32_, 0);
 
@@ -115,7 +115,7 @@ std::vector<uint32_t> IRBuilder::Finalize() {
 SType IRBuilder::GetSType(const DataType& dtype, uint32_t row, uint32_t col) {
   if (dtype == DataType::Int(32)) {
     return t_int32_;
-  } else if (dtype == DataType::UInt(1)) {
+  } else if (dtype == DataType::Bool()) {
     return t_bool_;
   } else if (dtype == DataType::Float(32)) {
     return t_fp32_;
@@ -467,7 +467,7 @@ Value IRBuilder::GetConst_(const SType& dtype, const 
uint64_t* pvalue) {
   }
   ICHECK_LE(dtype.type.bits(), 64);
   Value ret = NewValue(dtype, kConstant);
-  if (dtype.type == DataType::UInt(1)) {
+  if (dtype.type == DataType::Bool()) {
     // bool types.
     if (*pvalue) {
       ib_.Begin(spv::OpConstantTrue).AddSeq(dtype, ret);
@@ -501,8 +501,7 @@ SType IRBuilder::DeclareType(const DataType& dtype, 
uint32_t row, uint32_t col)
     SType t;
     t.id = id_counter_++;
     t.type = dtype;
-    if (dtype.bits() == 1) {
-      ICHECK(dtype.is_uint());
+    if (dtype.is_bool()) {
       ib_.Begin(spv::OpTypeBool).Add(t).Commit(&global_);
     } else if (dtype.is_int()) {
       ib_.Begin(spv::OpTypeInt).AddSeq(t, dtype.bits(), 1).Commit(&global_);
@@ -584,7 +583,7 @@ void IRBuilder::AddCapabilityFor(const DataType& dtype) {
   // future.  Requiring StorageBuffer8BitAccess in order to declare an
   // Int8 prevents use of an 8-bit loop iterator on a device that
   // supports Int8 but doesn't support 8-bit buffer access.
-  if (dtype.bits() == 8) {
+  if (dtype.bits() == 8 && !dtype.is_bool()) {
     ICHECK(spirv_support_.supports_storage_buffer_8bit_access)
         << "Vulkan target does not support StorageBuffer8BitAccess.  "
         << "If your device supports 8-bit buffer access, "
@@ -822,19 +821,19 @@ Value IRBuilder::Mod(Value a, Value b) {
   }
 }
 
-#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                    
                 \
-  Value IRBuilder::_OpName(Value a, Value b) {                                 
                 \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                 \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                 \
-    const auto& bool_type = 
this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int()) {                                               
                 \
-      return MakeValue(spv::OpS##_Op, bool_type, a, b);                        
                 \
-    } else if (a.stype.type.is_uint()) {                                       
                 \
-      return MakeValue(spv::OpU##_Op, bool_type, a, b);                        
                 \
-    } else {                                                                   
                 \
-      ICHECK(a.stype.type.is_float());                                         
                 \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                 \
-    }                                                                          
                 \
+#define DEFINE_BUILDER_CMP_OP(_OpName, _Op)                                    
                \
+  Value IRBuilder::_OpName(Value a, Value b) {                                 
                \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                \
+    const auto& bool_type = 
this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int()) {                                               
                \
+      return MakeValue(spv::OpS##_Op, bool_type, a, b);                        
                \
+    } else if (a.stype.type.is_uint()) {                                       
                \
+      return MakeValue(spv::OpU##_Op, bool_type, a, b);                        
                \
+    } else {                                                                   
                \
+      ICHECK(a.stype.type.is_float());                                         
                \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                \
+    }                                                                          
                \
   }
 
 DEFINE_BUILDER_CMP_OP(LT, LessThan);
@@ -842,17 +841,17 @@ DEFINE_BUILDER_CMP_OP(LE, LessThanEqual);
 DEFINE_BUILDER_CMP_OP(GT, GreaterThan);
 DEFINE_BUILDER_CMP_OP(GE, GreaterThanEqual);
 
-#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                   
                 \
-  Value IRBuilder::_OpName(Value a, Value b) {                                 
                 \
-    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                 \
-    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                 \
-    const auto& bool_type = 
this->GetSType(DataType::UInt(1).with_lanes(a.stype.type.lanes())); \
-    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                     
                 \
-      return MakeValue(spv::OpI##_Op, bool_type, a, b);                        
                 \
-    } else {                                                                   
                 \
-      ICHECK(a.stype.type.is_float());                                         
                 \
-      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                 \
-    }                                                                          
                 \
+#define DEFINE_BUILDER_CMP_UOP(_OpName, _Op)                                   
                \
+  Value IRBuilder::_OpName(Value a, Value b) {                                 
                \
+    ICHECK_EQ(a.stype.id, b.stype.id);                                         
                \
+    ICHECK_EQ(a.stype.type.lanes(), b.stype.type.lanes());                     
                \
+    const auto& bool_type = 
this->GetSType(DataType::Bool().with_lanes(a.stype.type.lanes())); \
+    if (a.stype.type.is_int() || a.stype.type.is_uint()) {                     
                \
+      return MakeValue(spv::OpI##_Op, bool_type, a, b);                        
                \
+    } else {                                                                   
                \
+      ICHECK(a.stype.type.is_float());                                         
                \
+      return MakeValue(spv::OpFOrd##_Op, bool_type, a, b);                     
                \
+    }                                                                          
                \
   }
 
 DEFINE_BUILDER_CMP_UOP(EQ, Equal);
@@ -860,7 +859,7 @@ DEFINE_BUILDER_CMP_UOP(NE, NotEqual);
 
 Value IRBuilder::Select(Value cond, Value a, Value b) {
   ICHECK_EQ(a.stype.id, b.stype.id);
-  ICHECK_EQ(cond.stype.type.element_of(), DataType::UInt(1));
+  ICHECK_EQ(cond.stype.type.element_of(), DataType::Bool());
   return MakeValue(spv::OpSelect, a.stype, cond, a, b);
 }
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b8693a7..5eee4ffd8b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array<PrimExpr> 
indices,
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << 
predicate_element_dtype
         << ".";
   }
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a01340b..47622757e5 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, 
ffi::Array<PrimExpr> ind
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << 
predicate_element_dtype
         << ".";
   }
@@ -687,7 +687,8 @@ BlockRealize::BlockRealize(ffi::Array<PrimExpr> values, 
PrimExpr predicate, Bloc
                            Span span) {
   CHECK_EQ(block->iter_vars.size(), values.size())
       << "ValueError: BlockRealize needs to have the same number of iter_vars 
and binding values";
-  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to 
be a bool expression";
+  CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1))
+      << "TypeError: Expect Block.predicate to be a bool expression";
   ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
   node->iter_values = std::move(values);
   node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f9928a5..51c0b64ed2 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span 
span) {  // NOLINT(*)
   } else if (ltype.is_float4() && !rtype.is_float4()) {
     // Cast int->float4 for rhs when lhs is a float4
     rhs = cast(ltype, rhs);
+  } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+    // Cast bool to int for lhs when rhs is a int or uint
+    lhs = cast(rtype, lhs);
+  } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+    // Cast bool to int for rhs when lhs is a int or uint
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && 
rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -621,7 +627,7 @@ PrimExpr max(PrimExpr a, PrimExpr b, Span span) {
 
 // if_then_else
 PrimExpr if_then_else(PrimExpr cond, PrimExpr true_value, PrimExpr 
false_value, Span span) {
-  ICHECK(cond.dtype() == DataType::Bool(1))
+  ICHECK(cond.dtype() == DataType::Bool())
       << "if_then_else only accept the condition to be boolean type.";
   BinaryOpMatchTypes(true_value, false_value, span);
   if (const IntImmNode* op = cond.as<IntImmNode>()) {
@@ -698,10 +704,10 @@ void type_check_boolean_args(const PrimExpr& lhs, const 
PrimExpr& rhs, const cha
                                 << rhs << " of type " << rhs.dtype();
 }
 
-void type_check_integer_args(const PrimExpr& arg, const char* op) {
-  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint())
-      << "Expected integer argument for " << op << ", but received " << arg << 
" of type "
-      << arg.dtype();
+void type_check_int_or_bool_args(const PrimExpr& arg, const char* op) {
+  ICHECK(arg.dtype().is_int() || arg.dtype().is_uint() || 
arg.dtype().is_bool())
+      << "Expected integer or boolean argument for " << op << ", but received 
" << arg
+      << " of type " << arg.dtype();
 }
 
 void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const 
char* op) {
@@ -712,6 +718,15 @@ void type_check_integer_args(const PrimExpr& lhs, const 
PrimExpr& rhs, const cha
       << "Expected integer argument as RHS of " << op << ", but received " << 
rhs << " of type "
       << rhs.dtype();
 }
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+  ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool())
+      << "Expected integer or boolean argument as LHS of " << op << ", but received " << lhs
+      << " of type " << lhs.dtype();
+  ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool())
+      << "Expected integer or boolean argument as RHS of " << op << ", but received " << rhs
+      << " of type " << rhs.dtype();
+}
 }  // namespace
 
 PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -781,7 +796,7 @@ PrimExpr left_shift(PrimExpr a, PrimExpr b, Span span) {
 // bitwise and
 PrimExpr operator&(PrimExpr a, PrimExpr b) { return bitwise_and(a, b); }
 PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "& operator (bitwise AND)");
+  type_check_int_or_bool_args(a, b, "& operator (bitwise AND)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -793,7 +808,7 @@ PrimExpr bitwise_and(PrimExpr a, PrimExpr b, Span span) {
 // bitwise_or
 PrimExpr operator|(PrimExpr a, PrimExpr b) { return bitwise_or(a, b); }
 PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "| operator (bitwise OR)");
+  type_check_int_or_bool_args(a, b, "| operator (bitwise OR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -805,7 +820,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
 // bitwise_xor
 PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
 PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+  type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
@@ -818,7 +833,7 @@ PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
 PrimExpr operator~(PrimExpr a) { return bitwise_neg(a); }
 
 PrimExpr bitwise_neg(PrimExpr a, Span span) {
-  type_check_integer_args(a, "~ operator (bitwise NOT)");
+  type_check_int_or_bool_args(a, "~ operator (bitwise NOT)");
   return tir::Call(a.dtype(), tir::builtin::bitwise_not(), {a}, span);
 }
 
@@ -935,7 +950,7 @@ PrimExpr sum(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
   PrimExpr result = tir::Add(x, y, span);
   PrimExpr identity_element = make_zero(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 PrimExpr all(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
@@ -944,7 +959,7 @@ PrimExpr all(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
   PrimExpr result = tir::And(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), true, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 PrimExpr any(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
@@ -953,7 +968,7 @@ PrimExpr any(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
   PrimExpr result = tir::Or(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), false, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 PrimExpr max(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
@@ -961,7 +976,7 @@ PrimExpr max(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
   PrimExpr result = tir::Max(x, y, span);
   PrimExpr identity_element = min_value(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
@@ -969,7 +984,7 @@ PrimExpr min(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> ini
   PrimExpr result = tir::Min(x, y, span);
   PrimExpr identity_element = max_value(source.dtype(), span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, ffi::Array<PrimExpr> 
init, Span span) {
@@ -977,7 +992,7 @@ PrimExpr prod(PrimExpr source, ffi::Array<IterVar> rdom, 
ffi::Array<PrimExpr> in
   PrimExpr result = tir::Mul(x, y, span);
   PrimExpr identity_element = make_const(source.dtype(), 1, span);
   tir::CommReducer combiner = tir::CommReducer({x}, {y}, {result}, 
{identity_element}, span);
-  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(1), 
true), 0, init, span);
+  return tir::Reduce(combiner, {source}, rdom, make_const(DataType::Bool(), 
true), 0, init, span);
 }
 
 // fmod
@@ -992,7 +1007,7 @@ TVM_TIR_REGISTER_PURE_UNARY_OP("fmod");
 
 // floor
 PrimExpr floor(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1006,7 +1021,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("floor").set_attr<TVectorizable>("TVectorizable",
 
 // ceil
 PrimExpr ceil(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1020,7 +1035,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("ceil").set_attr<TVectorizable>("TVectorizable",
 
 // round
 PrimExpr round(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1034,7 +1049,7 @@ 
TVM_TIR_REGISTER_PURE_UNARY_OP("round").set_attr<TVectorizable>("TVectorizable",
 
 // nearbyint
 PrimExpr nearbyint(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
@@ -1048,7 +1063,7 @@ TVM_TIR_REGISTER_PURE_UNARY_OP("nearbyint");
 
 // trunc
 PrimExpr trunc(PrimExpr x, Span span) {
-  if (x.dtype().is_int() || x.dtype().is_uint()) {
+  if (x.dtype().is_int() || x.dtype().is_uint() || x.dtype().is_bool()) {
     return x;
   }
   using tir::FloatImmNode;
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index 8a5d39ec35..1b85d7d211 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -218,7 +218,7 @@ void ArgBinder::BindDLTensor(const Buffer& buffer, const 
PrimExpr& device_type,
   init_nest_.emplace_back(LetStmt(
       buf_strides->data, TVMArrayGet(DataType::Handle(), handle, 
builtin::kArrStrides), nop));
   init_nest_.emplace_back(DeclBuffer(buf_strides, nop));
-  PrimExpr v_strides_is_null = Call(DataType::Bool(1), builtin::isnullptr(), 
{buf_strides->data});
+  PrimExpr v_strides_is_null = Call(DataType::Bool(), builtin::isnullptr(), 
{buf_strides->data});
   if (buffer->strides.size() == 0) {
     // Assert the buffer is compact
     DataType stype = buffer->DefaultIndexType();
diff --git a/src/tir/transforms/inject_ptx_ldg32.cc 
b/src/tir/transforms/inject_ptx_ldg32.cc
index 1b4bd7b410..8cdef1be44 100644
--- a/src/tir/transforms/inject_ptx_ldg32.cc
+++ b/src/tir/transforms/inject_ptx_ldg32.cc
@@ -41,7 +41,7 @@ class PTXRewriter : public StmtMutator {
       // addr[0] -> global_addr /  addr[1] -> local_addr
       addr_buffer = decl_buffer({IntImm(DataType::Int(32), 2)}, 
DataType::Int(32), "addr", "local");
       predicate_buffer =
-          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(1), 
"predicate", "local");
+          decl_buffer({IntImm(DataType::Int(32), 1)}, DataType::Bool(), 
"predicate", "local");
     }
     Stmt result = StmtMutator::VisitStmt_(allocate);
     if (!has_buffer_2) {
diff --git a/src/tir/transforms/lower_tvm_builtin.cc 
b/src/tir/transforms/lower_tvm_builtin.cc
index f6df6c877d..66e13791f3 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -256,7 +256,7 @@ class BuiltinLower : public StmtExprMutator {
     Stmt throw_last_error = Evaluate(Call(DataType::Int(32), 
builtin::tvm_throw_last_error(), {}));
 
     Stmt alloc_nullptr_check = IfThenElse(
-        Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), 
throw_last_error);
+        Call(DataType::Bool(), builtin::isnullptr(), {op->buffer_var}), 
throw_last_error);
     PrimExpr free_op = Call(DataType::Int(32), 
Op::Get("tir.TVMBackendFreeWorkspace"),
                             {cast(DataType::Int(32), device_type_.value()),
                              cast(DataType::Int(32), device_id_.value()), 
op->buffer_var});
@@ -617,7 +617,7 @@ class BuiltinLower : public StmtExprMutator {
     Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), 
throw_last_error);
 
     Stmt body = SeqStmt(
-        {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), 
throw_last_error),
+        {IfThenElse(Call(DataType::Bool(), builtin::isnullptr(), {let->var}), 
throw_last_error),
          let->body, free_stmt});
 
     DataType dtype =
diff --git a/tests/cpp/tir_scalable_datatype.cc 
b/tests/cpp/tir_scalable_datatype.cc
index 6c42972d94..6ae6deb50d 100644
--- a/tests/cpp/tir_scalable_datatype.cc
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -167,8 +167,8 @@ TEST(ScalableDataType, 
TestScalableDataTypeInvalidLanesAccess) {
 
 TEST(ScalableDataType, TestScalableBool) {
   tvm::DataType scalable_type = tvm::DataType::Bool(4, true);
-  ASSERT_EQ(scalable_type.code(), kDLUInt);
-  ASSERT_EQ(scalable_type.bits(), 1);
+  ASSERT_EQ(scalable_type.code(), kDLBool);
+  ASSERT_EQ(scalable_type.bits(), 8);
   ASSERT_EQ(scalable_type.vscale_factor(), 4);
   ASSERT_TRUE(scalable_type.is_scalable_vector());
 }
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index 6954cf4e1d..5eaaac68f0 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -93,7 +93,7 @@ class TestVector(BaseCompare):
     x, y, z = te.var("x"), te.var("y"), te.var("z")
     x64 = te.var("x", dtype="int64")
     vx = te.var("vx", dtype="int32x2")
-    vc = te.var("vc", dtype="uint1")
+    vc = te.var("vc", dtype="bool")
     test_case = tvm.testing.parameter(
         # Add rules
         TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x 
+ y, 3, 4)),
@@ -285,22 +285,22 @@ class TestVector(BaseCompare):
             tvm.te.max(vx, tvm.te.max(y, x).astype("int32x2")),
         ),
         ## Logical rules
-        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), 
(y.equal(x)).astype("uint1x2")),
+        TestCase(y.astype("int32x2").equal(x.astype("int32x2")), 
(y.equal(x)).astype("boolx2")),
         TestCase(
             tvm.tir.NE(y.astype("int32x2"), (x.astype("int32x2"))),
-            (tvm.tir.NE(y, x)).astype("uint1x2"),
+            (tvm.tir.NE(y, x)).astype("boolx2"),
         ),
-        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < 
y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= 
y).astype("uint1x2")),
-        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < 
x).astype("uint1x2")),
-        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= 
x).astype("uint1x2")),
+        TestCase(y.astype("int32x2") > x.astype("int32x2"), (x < 
y).astype("boolx2")),
+        TestCase(y.astype("int32x2") >= x.astype("int32x2"), (x <= 
y).astype("boolx2")),
+        TestCase(y.astype("int32x2") < x.astype("int32x2"), (y < 
x).astype("boolx2")),
+        TestCase(y.astype("int32x2") <= x.astype("int32x2"), (y <= 
x).astype("boolx2")),
         TestCase(
-            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("uint1x2")),
-            (tvm.tir.And(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.And(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("boolx2")),
+            (tvm.tir.And(y <= x, vc)).astype("boolx2"),
         ),
         TestCase(
-            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("uint1x2")),
-            (tvm.tir.Or(y <= x, vc)).astype("uint1x2"),
+            tvm.tir.Or(y.astype("int32x2") <= x.astype("int32x2"), 
vc.astype("boolx2")),
+            (tvm.tir.Or(y <= x, vc)).astype("boolx2"),
         ),
     )
 
diff --git a/tests/python/relax/test_op_nn.py b/tests/python/relax/test_op_nn.py
index a0ff507ef8..b076827dc4 100644
--- a/tests/python/relax/test_op_nn.py
+++ b/tests/python/relax/test_op_nn.py
@@ -1721,7 +1721,6 @@ def test_nll_loss_infer_struct_info_targets_dtype():
     w = relax.Var("w", R.Tensor((5,), "float32"))
     targets0 = relax.Var("targets", R.Tensor((3, 10, 10), "float32"))
     targets1 = relax.Var("targets", R.Tensor((3, 10, 10), "float64"))
-    targets2 = relax.Var("targets", R.Tensor((3, 10, 10), "bool"))
     targets3 = relax.Var("targets", R.Tensor((3, 10, 10), "int32"))
     targets4 = relax.Var("targets", R.Tensor((3, 10, 10), "int64"))
     targets5 = relax.Var("targets", R.Tensor((3, 10, 10), "uint32"))
@@ -1733,7 +1732,6 @@ def test_nll_loss_infer_struct_info_targets_dtype():
         bb.normalize(relax.op.nn.nll_loss(x, targets1, w))
 
     # correct cases
-    bb.normalize(relax.op.nn.nll_loss(x, targets2, w))  # bool is uint1
     bb.normalize(relax.op.nn.nll_loss(x, targets3, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets4, w))
     bb.normalize(relax.op.nn.nll_loss(x, targets5, w))
diff --git a/tests/python/tir-base/test_tir_constructor.py 
b/tests/python/tir-base/test_tir_constructor.py
index 42c2998e27..4076070557 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), 
tvm.runtime.convert("hellow"), nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), 
tvm.runtime.convert("hellow"), nop)
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
@@ -150,8 +150,8 @@ def test_stmt_constructor():
     assert x.extent.value == 10
     assert x.body == nop
 
-    buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("uint1")))
-    buffer = tvm.tir.decl_buffer([16], "uint1", data=buffer_var)
+    buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("bool")))
+    buffer = tvm.tir.decl_buffer([16], "bool", data=buffer_var)
     x = tvm.tir.BufferStore(buffer, tvm.tir.IntImm("bool", 1), [10])
     assert isinstance(x, tvm.tir.BufferStore)
     assert x.buffer == buffer
@@ -160,7 +160,7 @@ def test_stmt_constructor():
     assert x.value.value == 1
 
     buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("float32")))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@ def test_stmt_constructor():
 
     storage_scope = "global.texture"
     buffer_var = tvm.tir.Var("buf", 
tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, 
"bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@ def test_stmt_constructor():
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), 
nop)
+    x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
     assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_nodes.py 
b/tests/python/tir-base/test_tir_nodes.py
index 5e1d25e48b..bc7cfeae17 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -302,7 +302,7 @@ def test_isnan():
     z = te.var("z", "int32")
     assert str(tvm.tir.isnan(z)) == "T.bool(False)"
     k = te.var("k", "int8x2")
-    assert str(tvm.tir.isnan(k).dtype) == "uint1x2"
+    assert str(tvm.tir.isnan(k).dtype) == "boolx2"
 
 
 def test_equality():
diff --git a/tests/python/tir-base/test_tir_ops.py 
b/tests/python/tir-base/test_tir_ops.py
index dfa5cbab80..cb7d8c597a 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@ def test_const_fold3():
     x = te.var("x")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
-            check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
-            check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+            check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+            check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
 
     # Test const folding when both arguments are const
     for tvm_func, py_func in [
@@ -80,13 +80,13 @@ def test_const_fold3():
         for v1 in [0, 1]:
             for v2 in [0, 1]:
                 tvm.ir.assert_structural_equal(
-                    tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, 
"uint1")),
-                    tvm.tir.const(py_func(v1, v2), "uint1"),
+                    tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, 
"bool")),
+                    tvm.tir.const(py_func(v1, v2), "bool"),
                 )
 
-    x = te.var("x", "uint1")
-    true = tvm.tir.const(1, "uint1")
-    false = tvm.tir.const(0, "uint1")
+    x = te.var("x", "bool")
+    true = tvm.tir.const(1, "bool")
+    false = tvm.tir.const(0, "bool")
 
     assert tvm.tir.all(x, true).same_as(x)
     assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py 
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba47f..8352b11644 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@ def test_ir_builder_tir_allocate():
     # the expected allocate
     buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), 
"local"))
     ir_expected = tir.Allocate(
-        buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+        buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
     )
 
     # Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py 
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deacd98..e4af158074 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@ def test_predicated_buffer_load_store():
     buffer_load = tir.BufferLoad(
         buffer=buffer_map[b],
         indices=[0, tir.Ramp(0, 4, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     body = tir.BufferStore(
         buffer=buffer_map[a],
         value=buffer_load,
         indices=[0, tir.Ramp(0, 2, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     func = tir.PrimFunc(
         params=[a, b],

Reply via email to