This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bb750012a4 [TIRx] Add a dedicated boolean buffer lowering pass (#19873)
bb750012a4 is described below

commit bb750012a42e58eef1927452a053183bb28cf9b1
Author: Guan-Ming Chiu <[email protected]>
AuthorDate: Wed Jul 1 08:00:41 2026 +0800

    [TIRx] Add a dedicated boolean buffer lowering pass (#19873)
    
    ## Why
    
    The bool→int8 backing-array conversion was duplicated inline across
    `FlattenBuffer` and `LowerTIRxCleanup`.
    
    ## How
    
    - Add `LowerBoolBuffer` pass that rewrites `bool` buffers to `int8` and
    inserts load/store casts.
    - Remove the duplicated bool handling from `FlattenBuffer` and
    `LowerTIRxCleanup`.
    - Run it after FlattenBuffer (before VectorizeLoop) in every pipeline,
    and add structural and build-and-run tests.
    
    ---------
    
    Signed-off-by: Guan-Ming (Wesley) Chiu <[email protected]>
---
 src/target/llvm/codegen_llvm.cc                    | 31 +++++++++++++++---
 src/tirx/ir/buffer.cc                              |  2 --
 src/tirx/transform/flatten_buffer.cc               | 38 +---------------------
 src/tirx/transform/lower_tirx_cleanup.cc           | 33 +------------------
 .../test_tir_transform_flatten_buffer.py           |  8 ++---
 5 files changed, 32 insertions(+), 80 deletions(-)

diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index f091015f22..6ec68fe386 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -109,6 +109,16 @@ PrimType WithScalableVScaleFactor(const PrimType& dtype, 
int vscale_factor) {
   return PrimType::ScalableVector(dtype.code(), dtype.bits(), vscale_factor);
 }
 
+// Underlying access type for a Buffer: bool is backed by int8 so vectorized
+// accesses lower to real loads/stores instead of i1 predicate registers.
+PrimType BufferAccessType(const PrimType& dtype) {
+  if (!dtype.MatchesCode(DLDataTypeCode::kDLBool)) return dtype;
+  if (dtype.IsScalableVector()) {
+    return PrimType::ScalableVector(DLDataTypeCode::kDLInt, 8, 
dtype.VScaleFactor());
+  }
+  return PrimType::Int(8, dtype.lanes());
+}
+
 }  // namespace
 
 // CodeGenLLVM has members of type std::unique_ptr<T>. These members will be
@@ -1720,7 +1730,7 @@ void CodeGenLLVM::BufferAccessHelper(
     std::function<llvm::Instruction*(TypedPointer buffer_ptr, int subelement_i,
                                      llvm::Value* predicate, int alignment, 
bool is_volatile)>
         make_instruction) {
-  PrimType buffer_element_dtype = buffer->dtype;
+  PrimType buffer_element_dtype = BufferAccessType(buffer->dtype);
 
   TVM_FFI_ICHECK_GE(indices.size(), 1)
       << "Buffer " << buffer->name << " is accessed with no indices.  "
@@ -1825,6 +1835,7 @@ void CodeGenLLVM::BufferAccessHelper(
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const BufferLoadNode* op) {
   PrimType value_dtype(op->ty()->dtype);
+  PrimType access_dtype = BufferAccessType(value_dtype);
 
   std::vector<llvm::Value*> loads;
 
@@ -1848,17 +1859,21 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const 
BufferLoadNode* op) {
   // Pass all indices into BufferAccessHelper.  In CodeGenLLVM,
   // non-flat indices will result in an error in CreateBufferPtr, but
   // a subclass may override CreateBufferPtr.
-  BufferAccessHelper(op->buffer, op->indices, op->predicate, value_dtype, 
make_load);
+  BufferAccessHelper(op->buffer, op->indices, op->predicate, access_dtype, 
make_load);
 
+  llvm::Value* ret;
   if (loads.size() == 1) {
-    return loads[0];
+    ret = loads[0];
   } else {
-    llvm::Value* ret = llvm::UndefValue::get(DTypeToLLVMType(value_dtype));
+    ret = llvm::UndefValue::get(DTypeToLLVMType(access_dtype));
     for (size_t i = 0; i < loads.size(); i++) {
       ret = builder_->CreateInsertElement(ret, loads[i], ConstInt32(i));
     }
-    return ret;
   }
+  if (!access_dtype.same_as(value_dtype)) {
+    ret = CreateCast(access_dtype, value_dtype, ret);
+  }
+  return ret;
 }
 
 llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op) {
@@ -1977,6 +1992,12 @@ void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
 
   llvm::Value* value = MakeValue(op->value);
 
+  PrimType store_dtype = BufferAccessType(value_dtype);
+  if (!store_dtype.same_as(value_dtype)) {
+    value = CreateCast(value_dtype, store_dtype, value);
+    value_dtype = store_dtype;
+  }
+
   auto make_store = [this, value](TypedPointer buffer_ptr, int subelement_i, 
llvm::Value* predicate,
                                   int alignment, bool is_volatile) {
     llvm::Value* to_store = value;
diff --git a/src/tirx/ir/buffer.cc b/src/tirx/ir/buffer.cc
index a0bcff6695..5401a2a3af 100644
--- a/src/tirx/ir/buffer.cc
+++ b/src/tirx/ir/buffer.cc
@@ -428,7 +428,6 @@ Buffer Buffer::GetFlattenedBuffer() const {
 
 PrimExpr Buffer::vload(ffi::Array<PrimExpr> begin, PrimType value_dtype,
                        ffi::Optional<PrimExpr> predicate) const {
-  // Specially handle bool, stored as int8 in buffers.
   const BufferNode* n = operator->();
   TVM_FFI_ICHECK(n != nullptr);
   PrimType buffer_dtype(n->dtype);
@@ -454,7 +453,6 @@ PrimExpr Buffer::vload(ffi::Array<PrimExpr> begin, PrimType 
value_dtype,
 
 Stmt Buffer::vstore(ffi::Array<PrimExpr> begin, PrimExpr value,
                     ffi::Optional<PrimExpr> predicate) const {
-  // Specially handle bool, stored as int8 in buffers.
   const BufferNode* n = operator->();
   TVM_FFI_ICHECK(n != nullptr);
   PrimType value_dtype = value.ty();
diff --git a/src/tirx/transform/flatten_buffer.cc 
b/src/tirx/transform/flatten_buffer.cc
index d959fb0ad8..a3890e3377 100644
--- a/src/tirx/transform/flatten_buffer.cc
+++ b/src/tirx/transform/flatten_buffer.cc
@@ -113,11 +113,6 @@ class BufferFlattener : public 
arith::IRMutatorWithAnalyzer {
     auto node = StmtExprMutator::VisitStmt_(op).as_or_throw<AllocBuffer>();
 
     auto new_buf = GetFlattenedBuffer(node->buffer);
-    // TODO(Lunderberg): Move the handling of boolean into a dedicated pass.
-    if (new_buf->dtype->dtype == DLDataType{kDLBool, 8, 1}) {
-      auto writer = new_buf.CopyOnWrite();
-      writer->dtype = PrimType::Int(8);
-    }
     if (!node->buffer.same_as(new_buf)) {
       node.CopyOnWrite()->buffer = new_buf;
     }
@@ -144,11 +139,6 @@ class BufferFlattener : public 
arith::IRMutatorWithAnalyzer {
     auto flattened = buf.GetFlattenedBuffer();
     auto writer = flattened.CopyOnWrite();
 
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (flattened->dtype->dtype == DLDataType{kDLBool, 8, 1}) {
-      writer->dtype = PrimType::Int(8);
-    }
     // canonicalize shape
     for (size_t i = 0; i < flattened->shape.size(); ++i) {
       writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i]));
@@ -161,40 +151,14 @@ class BufferFlattener : public 
arith::IRMutatorWithAnalyzer {
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = 
StmtExprMutator::VisitStmt_(op).as_or_throw<BufferStore>();
-    PrimType store_value_ty = op->value.ty();
-    bool store_returns_bool = 
store_value_ty.MatchesCode(DLDataTypeCode::kDLBool);
     store = VisitBufferAccess(store);
-
-    // Handle casts from the value's dtype to the dtype of the
-    // backing array.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (store_returns_bool) {
-      TVM_FFI_ICHECK_EQ(store->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 
1}))
-          << "Expected int8 backing array for boolean tensor";
-      auto writer = store.CopyOnWrite();
-      writer->value = tvm::cast(PrimType::Int(8), store->value);
-      return store;
-    }
     return store;
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimType load_ty = op->ty();
-    bool load_returns_bool = load_ty.MatchesCode(DLDataTypeCode::kDLBool);
     BufferLoad load = 
StmtExprMutator::VisitExpr_(op).as_or_throw<BufferLoad>();
     load = VisitBufferAccess(load);
-    // Handle casts from dtype of the backing array to value's dtype.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (load_returns_bool) {
-      TVM_FFI_ICHECK_EQ(load->buffer->dtype->dtype, (DLDataType{kDLInt, 8, 1}))
-          << "Expected int8 backing array for boolean tensor";
-      load.CopyOnWrite()->ExprNode::ty = PrimType::Int(8);
-      return tvm::cast(PrimType::Bool(), load);
-    } else {
-      return load;
-    }
+    return load;
   }
 
   ffi::Array<PrimExpr> GetSimplifiedElemOffset(const Buffer& buffer,
diff --git a/src/tirx/transform/lower_tirx_cleanup.cc 
b/src/tirx/transform/lower_tirx_cleanup.cc
index 3637e2af42..a9ecfa1c9e 100644
--- a/src/tirx/transform/lower_tirx_cleanup.cc
+++ b/src/tirx/transform/lower_tirx_cleanup.cc
@@ -169,11 +169,6 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer {
       flattened = buf.GetFlattenedBuffer();
       writer = flattened.CopyOnWrite();
     }
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (flattened->dtype.MatchesCode(DLDataTypeCode::kDLBool)) {
-      writer->dtype = PrimType::Int(8);
-    }
     // canonicalize shape
     for (size_t i = 0; i < flattened->shape.size(); ++i) {
       writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i]));
@@ -187,40 +182,14 @@ class LayoutApplier : public arith::IRMutatorWithAnalyzer 
{
 
   Stmt VisitStmt_(const BufferStoreNode* op) final {
     BufferStore store = 
StmtExprMutator::VisitStmt_(op).as_or_throw<BufferStore>();
-    PrimType store_value_ty = op->value.ty();
-    bool store_returns_bool = 
store_value_ty.MatchesCode(DLDataTypeCode::kDLBool);
     store = VisitBufferAccess(store);
-
-    // Handle casts from the value's dtype to the dtype of the
-    // backing array.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (store_returns_bool) {
-      TVM_FFI_ICHECK_EQ(store->buffer->dtype, PrimType::Int(8))
-          << "Expected int8 backing array for boolean tensor";
-      auto writer = store.CopyOnWrite();
-      writer->value = tvm::cast(PrimType::Int(8), store->value);
-      return std::move(store);
-    }
     return std::move(store);
   }
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    PrimType load_ty = op->ty();
-    bool load_returns_bool = load_ty.MatchesCode(DLDataTypeCode::kDLBool);
     BufferLoad load = 
StmtExprMutator::VisitExpr_(op).as_or_throw<BufferLoad>();
     load = VisitBufferAccess(load);
-    // Handle casts from dtype of the backing array to value's dtype.
-    // TODO(Lunderberg): Move the handling of boolean into a
-    // dedicated pass.
-    if (load_returns_bool) {
-      TVM_FFI_ICHECK_EQ(load->buffer->dtype, PrimType::Int(8))
-          << "Expected int8 backing array for boolean tensor";
-      load.CopyOnWrite()->ExprNode::ty = PrimType::Int(8);
-      return tvm::cast(PrimType::Bool(), load);
-    } else {
-      return std::move(load);
-    }
+    return std::move(load);
   }
 
   Stmt VisitStmt_(const tirx::TilePrimitiveCallNode* op) final {
diff --git a/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py 
b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py
index 9b1c171be4..a437156655 100644
--- a/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py
+++ b/tests/python/tirx-transform/test_tir_transform_flatten_buffer.py
@@ -318,7 +318,7 @@ def test_strided():
 
 
 def test_boolean():
-    """Boolean buffers should be replaced by a backing int8 array"""
+    """Boolean buffers are flattened but kept as bool (no int8 backing 
array)"""
 
     @I.ir_module(s_tir=True)
     class Before:
@@ -331,11 +331,11 @@ def test_boolean():
     class Expected:
         @T.prim_func(s_tir=True)
         def main(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) 
-> None:
-            A = T.decl_buffer(10, dtype="int8", data=input_A.data)
-            B = T.decl_buffer(10, dtype="int8", data=input_B.data)
+            A = T.decl_buffer(10, dtype="bool", data=input_A.data)
+            B = T.decl_buffer(10, dtype="bool", data=input_B.data)
             # body
             for i0 in T.serial(10):
-                B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
+                B[i0] = A[i0]
 
     After = _transform()(Before)
     tvm.ir.assert_structural_equal(After, Expected)

Reply via email to