https://github.com/matthias-springer created 
https://github.com/llvm/llvm-project/pull/180397

Discussion: https://discourse.llvm.org/t/denseelementsattr-i1-element-type/62525

Depends on #179122.


>From 907f66d32fcd7caf572f05bf86fae1f5b138bb8e Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Sun, 8 Feb 2026 09:07:31 +0000
Subject: [PATCH] [mlir][IR] `DenseElementsAttr`: Remove `i1` dense packing
 special case

---
 mlir/lib/IR/AttributeDetail.h         | 52 +----------------------
 mlir/lib/IR/BuiltinAttributes.cpp     | 60 +++------------------------
 mlir/test/IR/attribute-roundtrip.mlir | 10 -----
 mlir/test/IR/parse-literal.mlir       |  8 ++--
 mlir/unittests/IR/AttributeTest.cpp   | 14 -------
 5 files changed, 11 insertions(+), 133 deletions(-)
 delete mode 100644 mlir/test/IR/attribute-roundtrip.mlir

diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h
index 7af5c8cd9191d..c60886bc061ce 100644
--- a/mlir/lib/IR/AttributeDetail.h
+++ b/mlir/lib/IR/AttributeDetail.h
@@ -33,10 +33,6 @@ namespace detail {
 
 /// Return the bit width which DenseElementsAttr should use for this type.
 inline size_t getDenseElementBitWidth(Type eltType) {
-  // i1 is stored as a single bit (bit-packed storage).
-  if (eltType.isInteger(1))
-    return 1;
-  // Check for DenseElementTypeInterface.
   if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType))
     return denseEltType.getDenseElementBitSize();
   llvm_unreachable("unsupported element type");
@@ -92,10 +88,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
 
     // If the data is already known to be a splat, the key hash value is
     // directly the data buffer.
-    bool isBoolData = ty.getElementType().isInteger(1);
     if (isKnownSplat) {
-      if (isBoolData)
-        return getKeyForSplatBoolData(ty, data[0] != 0);
       return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
     }
 
@@ -105,12 +98,8 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
     size_t numElements = ty.getNumElements();
     assert(numElements != 1 && "splat of 1 element should already be 
detected");
 
-    // Handle boolean values directly as they are packed to 1-bit.
-    if (isBoolData)
-      return getKeyForBoolData(ty, data, numElements);
-
     size_t elementWidth = getDenseElementBitWidth(ty.getElementType());
-    // Non 1-bit dense elements are padded to 8-bits.
+    // Dense elements are padded to 8-bits.
     size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT);
     assert(((data.size() / storageSize) == numElements) &&
            "data does not hold expected number of elements");
@@ -129,45 +118,6 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage {
     return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true);
   }
 
-  /// Construct a key with a set of boolean data.
-  static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data,
-                                 size_t numElements) {
-    ArrayRef<char> splatData = data;
-    bool splatValue = splatData.front() & 1;
-
-    // Check the simple case where the data matches the known splat value.
-    if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse))
-      return getKeyForSplatBoolData(ty, splatValue);
-
-    // Handle the case where the potential splat value is 1 and the number of
-    // elements is non 8-bit aligned.
-    size_t numOddElements = numElements % CHAR_BIT;
-    if (splatValue && numOddElements != 0) {
-      // Check that all bits are set in the last value.
-      char lastElt = splatData.back();
-      if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements))
-        return KeyTy(ty, data, llvm::hash_value(data));
-
-      // If this is the only element, the data is known to be a splat.
-      if (splatData.size() == 1)
-        return getKeyForSplatBoolData(ty, splatValue);
-      splatData = splatData.drop_back();
-    }
-
-    // Check that the data buffer corresponds to a splat of the proper mask.
-    char mask = splatValue ? ~0 : 0;
-    return llvm::all_of(splatData, [mask](char c) { return c == mask; })
-               ? getKeyForSplatBoolData(ty, splatValue)
-               : KeyTy(ty, data, llvm::hash_value(data));
-  }
-
-  /// Return a key to use for a boolean splat of the given value.
-  static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) {
-    const char &splatData = splatValue ? kSplatTrue : kSplatFalse;
-    return KeyTy(type, splatData, llvm::hash_value(splatData),
-                 /*isSplat=*/true);
-  }
-
   /// Hash the key for the storage.
   static llvm::hash_code hashKey(const KeyTy &key) {
     return llvm::hash_combine(key.type, key.hashCode);
diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp
index d9c5fd9acb811..b2f5853269f0a 100644
--- a/mlir/lib/IR/BuiltinAttributes.cpp
+++ b/mlir/lib/IR/BuiltinAttributes.cpp
@@ -460,7 +460,7 @@ const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0;
 /// Get the bitwidth of a dense element type within the buffer.
 /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8.
 static size_t getDenseElementStorageWidth(size_t origWidth) {
-  return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth);
+  return llvm::alignTo<8>(origWidth);
 }
 static size_t getDenseElementStorageWidth(Type elementType) {
   return getDenseElementStorageWidth(getDenseElementBitWidth(elementType));
@@ -622,12 +622,6 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const {
   auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base));
   Type eltTy = owner.getElementType();
 
-  // Handle i1 (boolean) specially - it's bit-packed and doesn't use interface.
-  if (eltTy.isInteger(1)) {
-    bool value = *BoolElementIterator(owner, index);
-    return IntegerAttr::get(eltTy, APInt(1, value));
-  }
-
   // Handle strings specially.
   if (llvm::isa<DenseStringElementsAttr>(owner)) {
     ArrayRef<StringRef> vals = owner.getRawStringData();
@@ -654,7 +648,7 @@ DenseElementsAttr::BoolElementIterator::BoolElementIterator(
           attr.getRawData().data(), attr.isSplat(), dataIndex) {}
 
 bool DenseElementsAttr::BoolElementIterator::operator*() const {
-  return getBit(getData(), getDataIndex());
+  return static_cast<bool>(getData()[getDataIndex()]);
 }
 
 
//===----------------------------------------------------------------------===//
@@ -900,18 +894,8 @@ bool DenseElementsAttr::classof(Attribute attr) {
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<Attribute> values) {
   assert(hasSameNumElementsOrSplat(type, values));
-
   Type eltType = type.getElementType();
 
-  // Handle i1 (boolean) specially - it's bit-packed.
-  if (eltType.isInteger(1)) {
-    SmallVector<bool> boolValues;
-    boolValues.reserve(values.size());
-    for (Attribute attr : values)
-      boolValues.push_back(llvm::cast<IntegerAttr>(attr).getValue().isOne());
-    return get(type, boolValues);
-  }
-
   // Handle strings specially.
   if (!llvm::isa<DenseElementType>(eltType)) {
     SmallVector<StringRef, 8> stringValues;
@@ -941,25 +925,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type,
                                          ArrayRef<bool> values) {
   assert(hasSameNumElementsOrSplat(type, values));
   assert(type.getElementType().isInteger(1));
-
-  SmallVector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT));
-
-  if (!values.empty()) {
-    bool isSplat = true;
-    bool firstValue = values[0];
-    for (int i = 0, e = values.size(); i != e; ++i) {
-      isSplat &= values[i] == firstValue;
-      setBit(buff.data(), i, values[i]);
-    }
-
-    // Splat of bool is encoded as a byte with all-ones in it.
-    if (isSplat) {
-      buff.resize(1);
-      buff[0] = values[0] ? -1 : 0;
-    }
-  }
-
-  return DenseIntOrFPElementsAttr::getRaw(type, buff);
+  return DenseIntOrFPElementsAttr::getRaw(
+      type, ArrayRef<char>(reinterpret_cast<const char *>(values.data()),
+                           values.size()));
 }
 
 DenseElementsAttr DenseElementsAttr::get(ShapedType type,
@@ -1030,23 +998,7 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type,
   // The initializer is always a splat if the result type has a single element.
   detectedSplat = numElements == 1;
 
-  // Storage width of 1 is special as it is packed by the bit.
-  if (storageWidth == 1) {
-    // Check for a splat, or a buffer equal to the number of elements which
-    // consists of either all 0's or all 1's.
-    if (rawBuffer.size() == 1) {
-      auto rawByte = static_cast<uint8_t>(rawBuffer[0]);
-      if (rawByte == 0 || rawByte == 0xff) {
-        detectedSplat = true;
-        return true;
-      }
-    }
-
-    // This is a valid non-splat buffer if it has the right size.
-    return rawBufferWidth == llvm::alignTo<8>(numElements);
-  }
-
-  // All other types are 8-bit aligned, so we can just check the buffer width
+  // All types are 8-bit aligned, so we can just check the buffer width
   // to know if only a single initializer element was passed in.
   if (rawBufferWidth == storageWidth) {
     detectedSplat = true;
diff --git a/mlir/test/IR/attribute-roundtrip.mlir b/mlir/test/IR/attribute-roundtrip.mlir
deleted file mode 100644
index 974dbcae6cf0a..0000000000000
--- a/mlir/test/IR/attribute-roundtrip.mlir
+++ /dev/null
@@ -1,10 +0,0 @@
-// RUN: mlir-opt -canonicalize %s | mlir-opt | FileCheck %s
-
-// CHECK-LABEL: @large_i1_tensor_roundtrip
-func.func @large_i1_tensor_roundtrip() -> tensor<160xi1> {
-  %cst_0 = arith.constant dense<"0xFFF00000FF000000FF000000FF000000FF000000"> 
: tensor<160xi1>
-  %cst_1 = arith.constant dense<"0xFF000000FF000000FF000000FF000000FF0000F0"> 
: tensor<160xi1>
-  // CHECK: dense<"0xFF000000FF000000FF000000FF000000FF000000">
-  %0 = arith.andi %cst_0, %cst_1 : tensor<160xi1>
-  return %0 : tensor<160xi1>
-}
diff --git a/mlir/test/IR/parse-literal.mlir b/mlir/test/IR/parse-literal.mlir
index 71b25e1d86480..36867c56075d0 100644
--- a/mlir/test/IR/parse-literal.mlir
+++ b/mlir/test/IR/parse-literal.mlir
@@ -36,8 +36,8 @@ func.func @parse_i4_tensor() -> tensor<32xi4> {
 }
 
 // CHECK-LABEL: @parse_i1_tensor
-func.func @parse_i1_tensor() -> tensor<256xi1> {
-  // CHECK: 
dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : 
tensor<256xi1>
-  %0 = arith.constant 
dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : 
tensor<256xi1>
-  return %0 : tensor<256xi1>
+func.func @parse_i1_tensor() -> tensor<32xi1> {
+  // CHECK: dense<[true, false, true, false, true, true, true, true, true, 
true, true, true, true, true, true, true, true, true, true, true, true, true, 
true, true, true, false, false, false, false, false, false, true]> : 
tensor<32xi1>
+  %0 = arith.constant 
dense<"0x0100010001010101010101010101010101010101010101010100000000000001"> : 
tensor<32xi1>
+  return %0 : tensor<32xi1>
 }
diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp
index fd40404bf3008..404aa8c0dcf3d 100644
--- a/mlir/unittests/IR/AttributeTest.cpp
+++ b/mlir/unittests/IR/AttributeTest.cpp
@@ -76,20 +76,6 @@ TEST(DenseSplatTest, BoolSplatRawRoundtrip) {
   EXPECT_EQ(trueSplat, trueSplatFromRaw);
 }
 
-TEST(DenseSplatTest, BoolSplatSmall) {
-  MLIRContext context;
-  Builder builder(&context);
-
-  // Check that splats that don't fill entire byte are handled properly.
-  auto tensorType = RankedTensorType::get({4}, builder.getI1Type());
-  std::vector<char> data{0b00001111};
-  auto trueSplatFromRaw =
-      DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data);
-  EXPECT_TRUE(trueSplatFromRaw.isSplat());
-  DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true);
-  EXPECT_EQ(trueSplat, trueSplatFromRaw);
-}
-
 TEST(DenseSplatTest, LargeBoolSplat) {
   constexpr int64_t boolCount = 56;
 

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to