https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/180397
Discussion: https://discourse.llvm.org/t/denseelementsattr-i1-element-type/62525 Depends on #179122. >From 907f66d32fcd7caf572f05bf86fae1f5b138bb8e Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Sun, 8 Feb 2026 09:07:31 +0000 Subject: [PATCH] [mlir][IR] `DenseElementsAttr`: Remove `i1` dense packing special case --- mlir/lib/IR/AttributeDetail.h | 52 +---------------------- mlir/lib/IR/BuiltinAttributes.cpp | 60 +++------------------------ mlir/test/IR/attribute-roundtrip.mlir | 10 ----- mlir/test/IR/parse-literal.mlir | 8 ++-- mlir/unittests/IR/AttributeTest.cpp | 14 ------- 5 files changed, 11 insertions(+), 133 deletions(-) delete mode 100644 mlir/test/IR/attribute-roundtrip.mlir diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 7af5c8cd9191d..c60886bc061ce 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -33,10 +33,6 @@ namespace detail { /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { - // i1 is stored as a single bit (bit-packed storage). - if (eltType.isInteger(1)) - return 1; - // Check for DenseElementTypeInterface. if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) return denseEltType.getDenseElementBitSize(); llvm_unreachable("unsupported element type"); @@ -92,10 +88,7 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage { // If the data is already known to be a splat, the key hash value is // directly the data buffer. - bool isBoolData = ty.getElementType().isInteger(1); if (isKnownSplat) { - if (isBoolData) - return getKeyForSplatBoolData(ty, data[0] != 0); return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat); } @@ -105,12 +98,8 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage { size_t numElements = ty.getNumElements(); assert(numElements != 1 && "splat of 1 element should already be detected"); - // Handle boolean values directly as they are packed to 1-bit. - if (isBoolData) - return getKeyForBoolData(ty, data, numElements); - size_t elementWidth = getDenseElementBitWidth(ty.getElementType()); - // Non 1-bit dense elements are padded to 8-bits. + // Dense elements are padded to 8-bits. size_t storageSize = llvm::divideCeil(elementWidth, CHAR_BIT); assert(((data.size() / storageSize) == numElements) && "data does not hold expected number of elements"); @@ -129,45 +118,6 @@ struct DenseIntOrFPElementsAttrStorage : public DenseElementsAttributeStorage { return KeyTy(ty, firstElt, hashVal, /*isSplat=*/true); } - /// Construct a key with a set of boolean data. - static KeyTy getKeyForBoolData(ShapedType ty, ArrayRef<char> data, - size_t numElements) { - ArrayRef<char> splatData = data; - bool splatValue = splatData.front() & 1; - - // Check the simple case where the data matches the known splat value. - if (splatData == ArrayRef<char>(splatValue ? kSplatTrue : kSplatFalse)) - return getKeyForSplatBoolData(ty, splatValue); - - // Handle the case where the potential splat value is 1 and the number of - // elements is non 8-bit aligned. - size_t numOddElements = numElements % CHAR_BIT; - if (splatValue && numOddElements != 0) { - // Check that all bits are set in the last value. - char lastElt = splatData.back(); - if (lastElt != llvm::maskTrailingOnes<unsigned char>(numOddElements)) - return KeyTy(ty, data, llvm::hash_value(data)); - - // If this is the only element, the data is known to be a splat. - if (splatData.size() == 1) - return getKeyForSplatBoolData(ty, splatValue); - splatData = splatData.drop_back(); - } - - // Check that the data buffer corresponds to a splat of the proper mask. - char mask = splatValue ? ~0 : 0; - return llvm::all_of(splatData, [mask](char c) { return c == mask; }) - ? getKeyForSplatBoolData(ty, splatValue) - : KeyTy(ty, data, llvm::hash_value(data)); - } - - /// Return a key to use for a boolean splat of the given value. - static KeyTy getKeyForSplatBoolData(ShapedType type, bool splatValue) { - const char &splatData = splatValue ? kSplatTrue : kSplatFalse; - return KeyTy(type, splatData, llvm::hash_value(splatData), - /*isSplat=*/true); - } - /// Hash the key for the storage. static llvm::hash_code hashKey(const KeyTy &key) { return llvm::hash_combine(key.type, key.hashCode); diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index d9c5fd9acb811..b2f5853269f0a 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -460,7 +460,7 @@ const char DenseIntOrFPElementsAttrStorage::kSplatFalse = 0; /// Get the bitwidth of a dense element type within the buffer. /// DenseElementsAttr requires bitwidths greater than 1 to be aligned by 8. static size_t getDenseElementStorageWidth(size_t origWidth) { - return origWidth == 1 ? origWidth : llvm::alignTo<8>(origWidth); + return llvm::alignTo<8>(origWidth); } static size_t getDenseElementStorageWidth(Type elementType) { return getDenseElementStorageWidth(getDenseElementBitWidth(elementType)); @@ -622,12 +622,6 @@ Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base)); Type eltTy = owner.getElementType(); - // Handle i1 (boolean) specially - it's bit-packed and doesn't use interface. - if (eltTy.isInteger(1)) { - bool value = *BoolElementIterator(owner, index); - return IntegerAttr::get(eltTy, APInt(1, value)); - } - // Handle strings specially. if (llvm::isa<DenseStringElementsAttr>(owner)) { ArrayRef<StringRef> vals = owner.getRawStringData(); @@ -654,7 +648,7 @@ DenseElementsAttr::BoolElementIterator::BoolElementIterator( attr.getRawData().data(), attr.isSplat(), dataIndex) {} bool DenseElementsAttr::BoolElementIterator::operator*() const { - return getBit(getData(), getDataIndex()); + return static_cast<bool>(getData()[getDataIndex()]); } //===----------------------------------------------------------------------===// @@ -900,18 +894,8 @@ bool DenseElementsAttr::classof(Attribute attr) { DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<Attribute> values) { assert(hasSameNumElementsOrSplat(type, values)); - Type eltType = type.getElementType(); - // Handle i1 (boolean) specially - it's bit-packed. - if (eltType.isInteger(1)) { - SmallVector<bool> boolValues; - boolValues.reserve(values.size()); - for (Attribute attr : values) - boolValues.push_back(llvm::cast<IntegerAttr>(attr).getValue().isOne()); - return get(type, boolValues); - } - // Handle strings specially. if (!llvm::isa<DenseElementType>(eltType)) { SmallVector<StringRef, 8> stringValues; @@ -941,25 +925,9 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, ArrayRef<bool> values) { assert(hasSameNumElementsOrSplat(type, values)); assert(type.getElementType().isInteger(1)); - - SmallVector<char> buff(llvm::divideCeil(values.size(), CHAR_BIT)); - - if (!values.empty()) { - bool isSplat = true; - bool firstValue = values[0]; - for (int i = 0, e = values.size(); i != e; ++i) { - isSplat &= values[i] == firstValue; - setBit(buff.data(), i, values[i]); - } - - // Splat of bool is encoded as a byte with all-ones in it. - if (isSplat) { - buff.resize(1); - buff[0] = values[0] ? -1 : 0; - } - } - - return DenseIntOrFPElementsAttr::getRaw(type, buff); + return DenseIntOrFPElementsAttr::getRaw( + type, ArrayRef<char>(reinterpret_cast<const char *>(values.data()), + values.size())); } DenseElementsAttr DenseElementsAttr::get(ShapedType type, @@ -1030,23 +998,7 @@ bool DenseElementsAttr::isValidRawBuffer(ShapedType type, // The initializer is always a splat if the result type has a single element. detectedSplat = numElements == 1; - // Storage width of 1 is special as it is packed by the bit. - if (storageWidth == 1) { - // Check for a splat, or a buffer equal to the number of elements which - // consists of either all 0's or all 1's. - if (rawBuffer.size() == 1) { - auto rawByte = static_cast<uint8_t>(rawBuffer[0]); - if (rawByte == 0 || rawByte == 0xff) { - detectedSplat = true; - return true; - } - } - - // This is a valid non-splat buffer if it has the right size. - return rawBufferWidth == llvm::alignTo<8>(numElements); - } - - // All other types are 8-bit aligned, so we can just check the buffer width + // All types are 8-bit aligned, so we can just check the buffer width // to know if only a single initializer element was passed in. if (rawBufferWidth == storageWidth) { detectedSplat = true; diff --git a/mlir/test/IR/attribute-roundtrip.mlir b/mlir/test/IR/attribute-roundtrip.mlir deleted file mode 100644 index 974dbcae6cf0a..0000000000000 --- a/mlir/test/IR/attribute-roundtrip.mlir +++ /dev/null @@ -1,10 +0,0 @@ -// RUN: mlir-opt -canonicalize %s | mlir-opt | FileCheck %s - -// CHECK-LABEL: @large_i1_tensor_roundtrip -func.func @large_i1_tensor_roundtrip() -> tensor<160xi1> { - %cst_0 = arith.constant dense<"0xFFF00000FF000000FF000000FF000000FF000000"> : tensor<160xi1> - %cst_1 = arith.constant dense<"0xFF000000FF000000FF000000FF000000FF0000F0"> : tensor<160xi1> - // CHECK: dense<"0xFF000000FF000000FF000000FF000000FF000000"> - %0 = arith.andi %cst_0, %cst_1 : tensor<160xi1> - return %0 : tensor<160xi1> -} diff --git a/mlir/test/IR/parse-literal.mlir b/mlir/test/IR/parse-literal.mlir index 71b25e1d86480..36867c56075d0 100644 --- a/mlir/test/IR/parse-literal.mlir +++ b/mlir/test/IR/parse-literal.mlir @@ -36,8 +36,8 @@ func.func @parse_i4_tensor() -> tensor<32xi4> { } // CHECK-LABEL: @parse_i1_tensor -func.func @parse_i1_tensor() -> tensor<256xi1> { - // CHECK: dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : tensor<256xi1> - %0 = arith.constant dense<"0x0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F0F"> : tensor<256xi1> - return %0 : tensor<256xi1> +func.func @parse_i1_tensor() -> tensor<32xi1> { + // CHECK: dense<[true, false, true, false, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, true, false, false, false, false, false, false, true]> : tensor<32xi1> + %0 = arith.constant dense<"0x0100010001010101010101010101010101010101010101010100000000000001"> : tensor<32xi1> + return %0 : tensor<32xi1> } diff --git a/mlir/unittests/IR/AttributeTest.cpp b/mlir/unittests/IR/AttributeTest.cpp index fd40404bf3008..404aa8c0dcf3d 100644 --- a/mlir/unittests/IR/AttributeTest.cpp +++ b/mlir/unittests/IR/AttributeTest.cpp @@ -76,20 +76,6 @@ TEST(DenseSplatTest, BoolSplatRawRoundtrip) { EXPECT_EQ(trueSplat, trueSplatFromRaw); } -TEST(DenseSplatTest, BoolSplatSmall) { - MLIRContext context; - Builder builder(&context); - - // Check that splats that don't fill entire byte are handled properly. - auto tensorType = RankedTensorType::get({4}, builder.getI1Type()); - std::vector<char> data{0b00001111}; - auto trueSplatFromRaw = - DenseIntOrFPElementsAttr::getFromRawBuffer(tensorType, data); - EXPECT_TRUE(trueSplatFromRaw.isSplat()); - DenseElementsAttr trueSplat = DenseElementsAttr::get(tensorType, true); - EXPECT_EQ(trueSplat, trueSplatFromRaw); -} - TEST(DenseSplatTest, LargeBoolSplat) { constexpr int64_t boolCount = 56; _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
