This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch tvm-ffi-bool
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/tvm-ffi-bool by this push:
     new 11f324d9ea Update
11f324d9ea is described below

commit 11f324d9ea19dbe63c9aedacc1eec66f58fd3d28
Author: tqchen <[email protected]>
AuthorDate: Wed Nov 12 20:04:28 2025 -0500

    Update
---
 3rdparty/tvm-ffi                                        |  2 +-
 include/tvm/runtime/data_type.h                         |  2 ++
 python/tvm/tir/ir_builder.py                            |  2 +-
 src/runtime/vm/builtin.cc                               |  2 +-
 src/tir/ir/expr.cc                                      |  2 +-
 src/tir/ir/stmt.cc                                      |  4 ++--
 src/tir/op/op.cc                                        | 17 ++++++++++++++++-
 tests/python/tir-base/test_tir_constructor.py           |  8 ++++----
 tests/python/tir-base/test_tir_ops.py                   | 14 +++++++-------
 tests/python/tvmscript/test_tvmscript_ir_builder_tir.py |  2 +-
 tests/python/tvmscript/test_tvmscript_printer_tir.py    |  4 ++--
 11 files changed, 38 insertions(+), 21 deletions(-)

diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index 60f45ac017..5fcdf8597f 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit 60f45ac017964caf2252b3c74a6e10a4422a1835
+Subproject commit 5fcdf8597f1ecb1a76e2eb0578bded73de91ace0
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index da355bd7ce..3a91d4777b 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -140,6 +140,8 @@ class DataType {
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
   bool is_bool() const { return code() == DataType::kBool; }
+  /*! \return whether type can be used in a predicate expression. */
+  bool is_predicate_dtype() const { return is_bool() || (is_uint() && bits() == 1); }
   /*! \return whether type is a float type. */
   bool is_float() const { return code() == DataType::kFloat; }
   /*! \return whether type is a bfloat type. */
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index d6466b0922..a6313ae3bc 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -448,7 +448,7 @@ class IRBuilder(object):
         )
 
         buffer_var = buffer.data
-        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
+        self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="bool"), x))
         return BufferVar(self, buffer, dtype)
 
     def pointer(self, content_type, name="ptr", scope=""):
diff --git a/src/runtime/vm/builtin.cc b/src/runtime/vm/builtin.cc
index 13446a158f..1bd3084c21 100644
--- a/src/runtime/vm/builtin.cc
+++ b/src/runtime/vm/builtin.cc
@@ -535,7 +535,7 @@ bool ReadIfCond(ffi::AnyView cond) {
   if (arr->device.device_type != kDLCPU) {
     arr = arr.CopyTo(DLDevice{kDLCPU, 0});
   }
-  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt);
+  ICHECK(arr->dtype.code == kDLInt || arr->dtype.code == kDLUInt || arr->dtype.code == kDLBool);
   int64_t result;
   switch (arr->dtype.bits) {
     case 1: {
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 252b8693a7..5eee4ffd8b 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -840,7 +840,7 @@ BufferLoad::BufferLoad(Buffer buffer, ffi::Array<PrimExpr> indices,
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
         << ".";
   }
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index d33a01340b..781fb887ff 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -485,7 +485,7 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value, ffi::Array<PrimExpr> ind
         << " lanes. The number of lanes must match.";
 
     DataType predicate_element_dtype = predicate_dtype.element_of();
-    ICHECK(predicate_element_dtype.is_bool())
+    ICHECK(predicate_element_dtype.is_predicate_dtype())
         << "Predicate mask elements must be boolean values, but got " << predicate_element_dtype
         << ".";
   }
@@ -687,7 +687,7 @@ BlockRealize::BlockRealize(ffi::Array<PrimExpr> values, PrimExpr predicate, Bloc
                            Span span) {
   CHECK_EQ(block->iter_vars.size(), values.size())
       << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
-  CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
+  CHECK(predicate.dtype().is_bool() || predicate.dtype() == DataType::UInt(1)) << "TypeError: Expect Block.predicate to be a bool expression";
   ObjectPtr<BlockRealizeNode> node = ffi::make_object<BlockRealizeNode>();
   node->iter_values = std::move(values);
   node->predicate = std::move(predicate);
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 935f9928a5..d6d68e5410 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -214,6 +214,12 @@ void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) {  // NOLINT(*)
   } else if (ltype.is_float4() && !rtype.is_float4()) {
     // Cast int->float4 for rhs when lhs is a float4
     rhs = cast(ltype, rhs);
+  } else if (ltype.is_bool() && (rtype.is_int() || rtype.is_uint())) {
+    // Cast bool to int for lhs when rhs is a int or uint
+    lhs = cast(rtype, lhs);
+  } else if ((ltype.is_int() || ltype.is_uint()) && rtype.is_bool()) {
+    // Cast bool to int for rhs when lhs is a int or uint
+    rhs = cast(ltype, rhs);
   } else if ((ltype.is_int() && rtype.is_int()) || (ltype.is_uint() && rtype.is_uint())) {
     // Promote int to higher bits e.g. int8 + int16 --> int16 + int16
     if (ltype.bits() < rtype.bits()) {
@@ -712,6 +718,15 @@ void type_check_integer_args(const PrimExpr& lhs, const PrimExpr& rhs, const cha
       << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
       << rhs.dtype();
 }
+
+void type_check_int_or_bool_args(const PrimExpr& lhs, const PrimExpr& rhs, const char* op) {
+  ICHECK(lhs.dtype().is_int() || lhs.dtype().is_uint() || lhs.dtype().is_bool())
+      << "Expected integer argument as LHS of " << op << ", but received " << lhs << " of type "
+      << lhs.dtype();
+  ICHECK(rhs.dtype().is_int() || rhs.dtype().is_uint() || rhs.dtype().is_bool())
+      << "Expected integer argument as RHS of " << op << ", but received " << rhs << " of type "
+      << rhs.dtype();
+}
 }  // namespace
 
 PrimExpr operator&&(PrimExpr a, PrimExpr b) { return logical_and(a, b); }
@@ -805,7 +820,7 @@ PrimExpr bitwise_or(PrimExpr a, PrimExpr b, Span span) {
 // bitwise_xor
 PrimExpr operator^(PrimExpr a, PrimExpr b) { return bitwise_xor(a, b); }
 PrimExpr bitwise_xor(PrimExpr a, PrimExpr b, Span span) {
-  type_check_integer_args(a, b, "^ operator (bitwise XOR)");
+  type_check_int_or_bool_args(a, b, "^ operator (bitwise XOR)");
   BinaryOpMatchTypes(a, b, span);
   TVM_INDEX_CONST_PROPAGATION({
     const DataType& rtype = a.dtype();
diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py
index 42c2998e27..fe64efa39b 100644
--- a/tests/python/tir-base/test_tir_constructor.py
+++ b/tests/python/tir-base/test_tir_constructor.py
@@ -140,7 +140,7 @@ def test_stmt_constructor():
     assert isinstance(x, tvm.tir.AttrStmt)
     assert x.value.value == 1
 
-    x = tvm.tir.AssertStmt(tvm.tir.const(1, "uint1"), tvm.runtime.convert("hellow"), nop)
+    x = tvm.tir.AssertStmt(tvm.tir.const(1, "bool"), tvm.runtime.convert("hellow"), nop)
     assert isinstance(x, tvm.tir.AssertStmt)
     assert x.body == nop
 
@@ -160,7 +160,7 @@ def test_stmt_constructor():
     assert x.value.value == 1
 
     buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32")))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -168,7 +168,7 @@ def test_stmt_constructor():
 
     storage_scope = "global.texture"
     buffer_var = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), storage_scope))
-    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), nop)
+    x = tvm.tir.Allocate(buffer_var, "float32", [10], tvm.tir.const(1, "bool"), nop)
     assert isinstance(x, tvm.tir.Allocate)
     assert x.dtype == "float32"
     assert x.buffer_var == buffer_var
@@ -181,7 +181,7 @@ def test_stmt_constructor():
     assert x.attr_key == "xyz"
     assert x.body == nop
 
-    x = tvm.tir.IfThenElse(tvm.tir.const(1, "uint1"), tvm.tir.Evaluate(11), nop)
+    x = tvm.tir.IfThenElse(tvm.tir.const(1, "bool"), tvm.tir.Evaluate(11), nop)
     assert isinstance(x, tvm.tir.IfThenElse)
     assert x.then_case.value.value == 11
     assert x.else_case == nop
diff --git a/tests/python/tir-base/test_tir_ops.py b/tests/python/tir-base/test_tir_ops.py
index dfa5cbab80..cb7d8c597a 100644
--- a/tests/python/tir-base/test_tir_ops.py
+++ b/tests/python/tir-base/test_tir_ops.py
@@ -69,8 +69,8 @@ def test_const_fold3():
     x = te.var("x")
     for val in [0, 1]:
         for func in [tvm.tir.all, tvm.tir.any]:
-            check_throws(lambda: func(tvm.tir.const(val, "uint1"), x))
-            check_throws(lambda: func(x, tvm.tir.const(val, "uint1")))
+            check_throws(lambda: func(tvm.tir.const(val, "bool"), x))
+            check_throws(lambda: func(x, tvm.tir.const(val, "bool")))
 
     # Test const folding when both arguments are const
     for tvm_func, py_func in [
@@ -80,13 +80,13 @@ def test_const_fold3():
         for v1 in [0, 1]:
             for v2 in [0, 1]:
                 tvm.ir.assert_structural_equal(
-                    tvm_func(tvm.tir.const(v1, "uint1"), tvm.tir.const(v2, "uint1")),
-                    tvm.tir.const(py_func(v1, v2), "uint1"),
+                    tvm_func(tvm.tir.const(v1, "bool"), tvm.tir.const(v2, "bool")),
+                    tvm.tir.const(py_func(v1, v2), "bool"),
                 )
 
-    x = te.var("x", "uint1")
-    true = tvm.tir.const(1, "uint1")
-    false = tvm.tir.const(0, "uint1")
+    x = te.var("x", "bool")
+    true = tvm.tir.const(1, "bool")
+    false = tvm.tir.const(0, "bool")
 
     assert tvm.tir.all(x, true).same_as(x)
     assert tvm.tir.all(true, x).same_as(x)
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index db6f4ba47f..8352b11644 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -366,7 +366,7 @@ def test_ir_builder_tir_allocate():
     # the expected allocate
     buffer_var = tir.Var("v", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "local"))
     ir_expected = tir.Allocate(
-        buffer_var, "float32", [10], tvm.tir.const(1, "uint1"), tir.Evaluate(1)
+        buffer_var, "float32", [10], tvm.tir.const(1, "bool"), tir.Evaluate(1)
     )
 
     # Check if the generated ir is expected
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index fc7deacd98..e4af158074 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -961,13 +961,13 @@ def test_predicated_buffer_load_store():
     buffer_load = tir.BufferLoad(
         buffer=buffer_map[b],
         indices=[0, tir.Ramp(0, 4, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     body = tir.BufferStore(
         buffer=buffer_map[a],
         value=buffer_load,
         indices=[0, tir.Ramp(0, 2, 4)],
-        predicate=tir.Broadcast(tir.IntImm("uint1", 0), 4),
+        predicate=tir.Broadcast(tir.IntImm("bool", 0), 4),
     )
     func = tir.PrimFunc(
         params=[a, b],

Reply via email to