Author: Matthias Springer Date: 2026-02-28T15:53:44+02:00 New Revision: be4a51db14a53ca98ff3b3db416ce3a3ff1a7abc
URL: https://github.com/llvm/llvm-project/commit/be4a51db14a53ca98ff3b3db416ce3a3ff1a7abc DIFF: https://github.com/llvm/llvm-project/commit/be4a51db14a53ca98ff3b3db416ce3a3ff1a7abc.diff LOG: Revert "[mlir][IR] Generalize `DenseElementsAttr` to custom element types (#1…" This reverts commit e655c36c16c118e3f8ae0c95854f33119218a4bf. Added: Modified: mlir/include/mlir/IR/BuiltinAttributes.td mlir/include/mlir/IR/BuiltinTypeInterfaces.h mlir/include/mlir/IR/BuiltinTypeInterfaces.td mlir/include/mlir/IR/BuiltinTypes.td mlir/lib/AsmParser/AttributeParser.cpp mlir/lib/IR/AsmPrinter.cpp mlir/lib/IR/AttributeDetail.h mlir/lib/IR/BuiltinAttributes.cpp mlir/lib/IR/BuiltinTypeInterfaces.cpp mlir/lib/IR/BuiltinTypes.cpp mlir/test/lib/Dialect/Test/TestTypeDefs.td mlir/test/lib/Dialect/Test/TestTypes.cpp mlir/test/lib/Dialect/Test/TestTypes.h Removed: mlir/test/IR/dense-elements-type-interface.mlir ################################################################################ diff --git a/mlir/include/mlir/IR/BuiltinAttributes.td b/mlir/include/mlir/IR/BuiltinAttributes.td index dced379d1f979..798d3c84f9618 100644 --- a/mlir/include/mlir/IR/BuiltinAttributes.td +++ b/mlir/include/mlir/IR/BuiltinAttributes.td @@ -239,48 +239,29 @@ def Builtin_DenseIntOrFPElementsAttr : Builtin_Attr< "DenseElementsAttr" > { let summary = "An Attribute containing a dense multi-dimensional array of " - "values"; + "integer or floating-point values"; let description = [{ - A dense elements attribute stores one or multiple elements of the same type. - The term "dense" refers to the fact that elements are not stored as - individual MLIR attributes, but in a raw buffer. The attribute provides a - covenience API to access elements in the form of MLIR attributes, but users - should avoid that API in performance-critical code and utilize APIs that - operate on raw bytes instead. - - The number of elements is determined by the `type` shaped type. (Unranked - shaped types are not supported.) The element type of the shaped type must - implement the `DenseElementType` interface. This type interface defines the - bitwidth of an element and provides a serializer/deserializer to/from MLIR - attributes. - - Storage format: Given an element bitwidth "w", element "i" starts at byte - offset "i * ceildiv(w, 8)". In other words, each element starts at a full - byte offset. - - TODO: The name `DenseIntOrFPElements` is no longer accurate. The attribute - will be renamed in the future. + Syntax: + + ``` + tensor-literal ::= integer-literal | float-literal | bool-literal | [] | [tensor-literal (, tensor-literal)* ] + dense-intorfloat-elements-attribute ::= `dense` `<` tensor-literal `>` `:` + ( tensor-type | vector-type ) + ``` + + A dense int-or-float elements attribute is an elements attribute containing + a densely packed vector or tensor of integer or floating-point values. The + element type of this attribute is required to be either an `IntegerType` or + a `FloatType`. Examples: ``` - // Literal-first syntax: A splat tensor of integer values. + // A splat tensor of integer values. dense<10> : tensor<2xi32> - - // Literal-first syntax: A tensor of 2 float32 elements. + // A tensor of 2 float32 elements. dense<[10.0, 11.0]> : tensor<2xf32> - - // Type-first syntax: A splat tensor of integer values. - dense<tensor<2xi32> : 10 : i32> - - // Type-first syntax: A tensor of 2 float32 elements. - dense<tensor<2xf32> : [10.0, 11.0]> ``` - - Note: The literal-first syntax is supported only for complex, float, index, - int element types. The parser/print have special casing for these types. - Dense element attributes with other element types must use the type-first - syntax. }]; let parameters = (ins AttributeSelfTypeParameter<"", "ShapedType">:$type, "ArrayRef<char>":$rawData); diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h index 9425d554b427c..5f14517d8dd71 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.h +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.h @@ -19,29 +19,6 @@ struct fltSemantics; namespace mlir { class FloatType; class MLIRContext; - -namespace detail { -/// Float type implementation of -/// DenseElementTypeInterface::getDenseElementBitSize. -size_t getFloatTypeDenseElementBitSize(Type type); - -/// Float type implementation of DenseElementTypeInterface::convertToAttribute. -Attribute convertFloatTypeToAttribute(Type type, llvm::ArrayRef<char> rawData); - -/// Float type implementation of -/// DenseElementTypeInterface::convertFromAttribute. -LogicalResult -convertFloatTypeFromAttribute(Type type, Attribute attr, - llvm::SmallVectorImpl<char> &result); - -/// Read `bitWidth` bits from byte-aligned position in `rawData` and return as -/// an APInt. Handles endianness correctly. -llvm::APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth); - -/// Write `value` to byte-aligned position `bitPos` in `rawData`. Handles -/// endianness correctly. -void writeBits(char *rawData, size_t bitPos, llvm::APInt value); -} // namespace detail } // namespace mlir #include "mlir/IR/BuiltinTypeInterfaces.h.inc" diff --git a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td index 93c8c0694b467..9ef08b7020b99 100644 --- a/mlir/include/mlir/IR/BuiltinTypeInterfaces.td +++ b/mlir/include/mlir/IR/BuiltinTypeInterfaces.td @@ -41,70 +41,12 @@ def VectorElementTypeInterface : TypeInterface<"VectorElementTypeInterface"> { }]; } -//===----------------------------------------------------------------------===// -// DenseElementTypeInterface -//===----------------------------------------------------------------------===// - -def DenseElementTypeInterface : TypeInterface<"DenseElementType"> { - let cppNamespace = "::mlir"; - let description = [{ - This interface allows custom types to be used as element types in - DenseElementsAttr. Types implementing this interface define: - - 1. The bit size for element storage. - 2. Helper methods for converting from/to Attribute. This assumes that there - is a corresponding attribute for each type that implements this - interface. - - The helper methods for converting from/to Attribute are utilized when - parsing/printing IR or iterating over the elements via Attribute. - }]; - - let methods = [ - InterfaceMethod< - /*desc=*/[{ - Return the number of bits required to store one element in dense - storage. - - Note: The DenseElementsAttr infrastructure will automatically align - every element to a full byte in storage. This limitation could be lifted - in the future to support dense packing of non-byte-sized elements. - }], - /*retTy=*/"size_t", - /*methodName=*/"getDenseElementBitSize", - /*args=*/(ins) - >, - InterfaceMethod< - /*desc=*/[{ - Attribute deserialization / attribute factory: Convert raw storage bytes - into an MLIR attribute. The size of `rawData` is - "ceilDiv(getDenseElementBitSize(), 8)". - }], - /*retTy=*/"::mlir::Attribute", - /*methodName=*/"convertToAttribute", - /*args=*/(ins "::llvm::ArrayRef<char>":$rawData) - >, - InterfaceMethod< - /*desc=*/[{ - Attribute serialization: Convert an MLIR attribute into raw bytes. - Implementations must append "getDenseElementBitSize() / 8" values to - `result`. Return "failure" if the attribute is incompatible with this - element type. - }], - /*retTy=*/"::llvm::LogicalResult", - /*methodName=*/"convertFromAttribute", - /*args=*/(ins "::mlir::Attribute":$attr, - "::llvm::SmallVectorImpl<char>&":$result) - >, - ]; -} - //===----------------------------------------------------------------------===// // FloatTypeInterface //===----------------------------------------------------------------------===// def FloatTypeInterface : TypeInterface<"FloatType", - [DenseElementTypeInterface, VectorElementTypeInterface]> { + [VectorElementTypeInterface]> { let cppNamespace = "::mlir"; let description = [{ This type interface should be implemented by all floating-point types. It @@ -141,21 +83,6 @@ def FloatTypeInterface : TypeInterface<"FloatType", /// The width includes the integer bit. unsigned getFPMantissaWidth(); }]; - - let extraTraitClassDeclaration = [{ - /// DenseElementTypeInterface implementations for float types. - size_t getDenseElementBitSize() const { - return ::mlir::detail::getFloatTypeDenseElementBitSize($_type); - } - ::mlir::Attribute convertToAttribute(::llvm::ArrayRef<char> rawData) const { - return ::mlir::detail::convertFloatTypeToAttribute($_type, rawData); - } - ::llvm::LogicalResult - convertFromAttribute(::mlir::Attribute attr, - ::llvm::SmallVectorImpl<char> &result) const { - return ::mlir::detail::convertFloatTypeFromAttribute($_type, attr, result); - } - }]; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/IR/BuiltinTypes.td b/mlir/include/mlir/IR/BuiltinTypes.td index e7d0a03a85e7d..806064faeda00 100644 --- a/mlir/include/mlir/IR/BuiltinTypes.td +++ b/mlir/include/mlir/IR/BuiltinTypes.td @@ -45,10 +45,7 @@ def ValueSemantics : NativeTypeTrait<"ValueSemantics"> { // ComplexType //===----------------------------------------------------------------------===// -def Builtin_Complex : Builtin_Type<"Complex", "complex", - [DeclareTypeInterfaceMethods<DenseElementTypeInterface, - ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]> - ]> { +def Builtin_Complex : Builtin_Type<"Complex", "complex"> { let summary = "Complex number with a parameterized element type"; let description = [{ Syntax: @@ -563,9 +560,7 @@ def Builtin_Graph : Builtin_FunctionLike<"Graph", "graph">; //===----------------------------------------------------------------------===// def Builtin_Index : Builtin_Type<"Index", "index", - [DeclareTypeInterfaceMethods<DenseElementTypeInterface, - ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]>, - VectorElementTypeInterface]> { + [VectorElementTypeInterface]> { let summary = "Integer-like type with unknown platform-dependent bit width"; let description = [{ Syntax: @@ -596,10 +591,7 @@ def Builtin_Index : Builtin_Type<"Index", "index", //===----------------------------------------------------------------------===// def Builtin_Integer : Builtin_Type<"Integer", "integer", - [VectorElementTypeInterface, QuantStorageTypeInterface, - DeclareTypeInterfaceMethods<DenseElementTypeInterface, [ - "getDenseElementBitSize", "convertToAttribute", - "convertFromAttribute"]>]> { + [VectorElementTypeInterface, QuantStorageTypeInterface]> { let summary = "Integer type with arbitrary precision up to a fixed limit"; let description = [{ Syntax: diff --git a/mlir/lib/AsmParser/AttributeParser.cpp b/mlir/lib/AsmParser/AttributeParser.cpp index dc9744a42b730..5978a11d06bc9 100644 --- a/mlir/lib/AsmParser/AttributeParser.cpp +++ b/mlir/lib/AsmParser/AttributeParser.cpp @@ -16,7 +16,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -954,119 +953,6 @@ Attribute Parser::parseDenseArrayAttr(Type attrType) { return eltParser.getAttr(); } -/// Try to parse a dense elements attribute with the type-first syntax. -/// Syntax: dense<TYPE : [ATTR, ATTR, ...]> -/// This syntax is used for types other than int, float, index and complex. -/// -/// Returns: -/// - "null" attribute if this is not the type-first syntax. -/// - "failure" in case of a parse error. -/// - A valid Attribute otherwise. -static FailureOr<Attribute> parseDenseElementsAttrTyped(Parser &p, SMLoc loc) { - // Skip l_paren because "parseType" would try to parse it as a tuple/function - // type, but '(' starts a complex literal like in the literal-first syntax. - if (p.getToken().is(Token::l_paren)) - return Attribute(); - - // Parse type and valdiate that it's a shaped type. - auto typeLoc = p.getToken().getLoc(); - Type type; - OptionalParseResult typeResult = p.parseOptionalType(type); - if (!typeResult.has_value()) - return Attribute(); // Not type-first syntax. - if (failed(*typeResult)) - return failure(); // Type parse error. - - auto shapedType = dyn_cast<ShapedType>(type); - if (!shapedType) { - p.emitError(typeLoc, "expected a shaped type for dense elements"); - return failure(); - } - if (!shapedType.hasStaticShape()) { - p.emitError(typeLoc, "dense elements type must have static shape"); - return failure(); - } - - // Check that the element type implements DenseElementTypeInterface. - auto denseEltType = dyn_cast<DenseElementType>(shapedType.getElementType()); - if (!denseEltType) { - p.emitError(typeLoc, - "element type must implement DenseElementTypeInterface " - "for type-first dense syntax"); - return failure(); - } - - // Parse colon. - if (p.parseToken(Token::colon, "expected ':' after type in dense attribute")) - return failure(); - - // Parse the element attributes and convert to raw bytes. - SmallVector<char> rawData; - - // Helper to parse a single element. - auto parseSingleElement = [&]() -> ParseResult { - Attribute elemAttr = p.parseAttribute(); - if (!elemAttr) - return failure(); - if (failed(denseEltType.convertFromAttribute(elemAttr, rawData))) { - p.emitError("incompatible attribute for element type"); - return failure(); - } - return success(); - }; - - // Recursively parse elements matching the expected shape. - std::function<ParseResult(ArrayRef<int64_t>)> parseElements; - parseElements = [&](ArrayRef<int64_t> remainingShape) -> ParseResult { - // Leaf: parse a single element. - if (remainingShape.empty()) - return parseSingleElement(); - - // Non-leaf: expect a list with the correct number of elements. - int64_t expectedCount = remainingShape.front(); - ArrayRef<int64_t> innerShape = remainingShape.drop_front(); - int64_t actualCount = 0; - - auto parseOne = [&]() -> ParseResult { - if (parseElements(innerShape)) - return failure(); - ++actualCount; - return success(); - }; - - if (p.parseCommaSeparatedList(Parser::Delimiter::Square, parseOne)) - return failure(); - - if (actualCount != expectedCount) { - p.emitError() << "expected " << expectedCount - << " elements in dimension, got " << actualCount; - return failure(); - } - return success(); - }; - - // Parse elements. - if (!p.getToken().is(Token::l_square)) { - // Single element - parse as splat. - if (parseSingleElement()) - return failure(); - } else if (shapedType.getShape().empty()) { - // Scalar type shouldn't have a list. - p.emitError(loc, "expected single element for scalar type, got list"); - return failure(); - } else { - // Parse structured literal matching the shape. - if (parseElements(shapedType.getShape())) - return failure(); - } - - if (p.parseToken(Token::greater, "expected '>' to close dense attribute")) - return failure(); - - // Create the attribute from raw buffer. - return DenseElementsAttr::getFromRawBuffer(shapedType, rawData); -} - /// Parse a dense elements attribute. Attribute Parser::parseDenseElementsAttr(Type attrType) { auto attribLoc = getToken().getLoc(); @@ -1074,16 +960,7 @@ Attribute Parser::parseDenseElementsAttr(Type attrType) { if (parseToken(Token::less, "expected '<' after 'dense'")) return nullptr; - // Try to parse the type-first syntax: dense<TYPE : [ATTR, ...]> - FailureOr<Attribute> typedResult = - parseDenseElementsAttrTyped(*this, attribLoc); - if (failed(typedResult)) - return nullptr; - if (*typedResult) - return *typedResult; - - // Try to parse the literal-first syntax, which is the default format for - // int, float, index and complex element types. + // Parse the literal data if necessary. TensorLiteralParser literalParser(*this); if (!consumeIf(Token::greater)) { if (literalParser.parse(/*allowHex=*/true) || diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index b3242f838fc1d..81455699421cc 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -507,18 +507,11 @@ class AsmPrinter::Impl { /// Print a dense string elements attribute. void printDenseStringElementsAttr(DenseStringElementsAttr attr); - /// Print a dense elements attribute in the literal-first syntax. If - /// 'allowHex' is true, a hex string is used instead of individual elements - /// when the elements attr is large. + /// Print a dense elements attribute. If 'allowHex' is true, a hex string is + /// used instead of individual elements when the elements attr is large. void printDenseIntOrFPElementsAttr(DenseIntOrFPElementsAttr attr, bool allowHex); - /// Print a dense elements attribute using the type-first syntax and the - /// DenseElementTypeInterface, which provides the attribute printer for each - /// element. - void printTypeFirstDenseElementsAttr(DenseElementsAttr attr, - DenseElementType denseEltType); - /// Print a dense array attribute. void printDenseArrayAttr(DenseArrayAttr attr); @@ -2514,17 +2507,7 @@ void AsmPrinter::Impl::printAttributeImpl(Attribute attr, printElidedElementsAttr(os); } else { os << "dense<"; - // Check if the element type implements DenseElementTypeInterface and is - // not a built-in type. Built-in types (int, float, index, complex) use - // the existing printing format for backwards compatibility. - Type eltType = intOrFpEltAttr.getElementType(); - if (isa<FloatType, IntegerType, IndexType, ComplexType>(eltType)) { - printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); - } else { - printTypeFirstDenseElementsAttr(intOrFpEltAttr, - cast<DenseElementType>(eltType)); - typeElision = AttrTypeElision::Must; - } + printDenseIntOrFPElementsAttr(intOrFpEltAttr, /*allowHex=*/true); os << '>'; } @@ -2722,27 +2705,6 @@ void AsmPrinter::Impl::printDenseStringElementsAttr( printDenseElementsAttrImpl(attr.isSplat(), attr.getType(), os, printFn); } -void AsmPrinter::Impl::printTypeFirstDenseElementsAttr( - DenseElementsAttr attr, DenseElementType denseEltType) { - // Print the type first: dense<TYPE : [ELEMENTS]> - printType(attr.getType()); - os << " : "; - - ArrayRef<char> rawData = attr.getRawData(); - // Storage is byte-aligned: align bit size up to next byte boundary. - size_t bitSize = denseEltType.getDenseElementBitSize(); - size_t byteSize = llvm::divideCeil(bitSize, static_cast<size_t>(CHAR_BIT)); - - // Print elements: convert raw bytes to attribute, then print attribute. - printDenseElementsAttrImpl( - attr.isSplat(), attr.getType(), os, [&](unsigned index) { - size_t offset = attr.isSplat() ? 0 : index * byteSize; - ArrayRef<char> elemData = rawData.slice(offset, byteSize); - Attribute elemAttr = denseEltType.convertToAttribute(elemData); - printAttributeImpl(elemAttr); - }); -} - void AsmPrinter::Impl::printDenseArrayAttr(DenseArrayAttr attr) { Type type = attr.getElementType(); unsigned bitwidth = type.isInteger(1) ? 8 : type.getIntOrFloatBitWidth(); diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 8505149afdd9c..1f268603cf37f 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -16,7 +16,6 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/AttributeSupport.h" #include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/MLIRContext.h" @@ -33,9 +32,12 @@ namespace detail { /// Return the bit width which DenseElementsAttr should use for this type. inline size_t getDenseElementBitWidth(Type eltType) { - if (auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType)) - return denseEltType.getDenseElementBitSize(); - llvm_unreachable("unsupported element type"); + // Align the width for complex to 8 to make storage and interpretation easier. + if (ComplexType comp = llvm::dyn_cast<ComplexType>(eltType)) + return llvm::alignTo<8>(getDenseElementBitWidth(comp.getElementType())) * 2; + if (eltType.isIndex()) + return IndexType::kInternalStorageBitWidth; + return eltType.getIntOrFloatBitWidth(); } /// An attribute representing a reference to a dense vector or tensor object. diff --git a/mlir/lib/IR/BuiltinAttributes.cpp b/mlir/lib/IR/BuiltinAttributes.cpp index bbbc9198a68ab..1a29fc534b40f 100644 --- a/mlir/lib/IR/BuiltinAttributes.cpp +++ b/mlir/lib/IR/BuiltinAttributes.cpp @@ -10,7 +10,6 @@ #include "AttributeDetail.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectResourceBlobManager.h" #include "mlir/IR/IntegerSet.h" @@ -528,7 +527,7 @@ static void copyArrayToAPIntForBEmachine(const char *inArray, size_t numBytes, } /// Writes value to the bit position `bitPos` in array `rawData`. -void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) { +static void writeBits(char *rawData, size_t bitPos, APInt value) { size_t bitWidth = value.getBitWidth(); // The bit position is guaranteed to be byte aligned. @@ -550,8 +549,7 @@ void mlir::detail::writeBits(char *rawData, size_t bitPos, APInt value) { /// Reads the next `bitWidth` bits from the bit position `bitPos` in array /// `rawData`. -APInt mlir::detail::readBits(const char *rawData, size_t bitPos, - size_t bitWidth) { +static APInt readBits(const char *rawData, size_t bitPos, size_t bitWidth) { // The bit position is guaranteed to be byte aligned. assert((bitPos % CHAR_BIT) == 0 && "expected bitPos to be 8-bit aligned"); APInt result(bitWidth, 0); @@ -597,21 +595,39 @@ DenseElementsAttr::AttributeElementIterator::AttributeElementIterator( Attribute DenseElementsAttr::AttributeElementIterator::operator*() const { auto owner = llvm::cast<DenseElementsAttr>(getFromOpaquePointer(base)); Type eltTy = owner.getElementType(); + if (llvm::dyn_cast<IntegerType>(eltTy)) + return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); + if (llvm::isa<IndexType>(eltTy)) + return IntegerAttr::get(eltTy, *IntElementIterator(owner, index)); + if (auto floatEltTy = llvm::dyn_cast<FloatType>(eltTy)) { + IntElementIterator intIt(owner, index); + FloatElementIterator floatIt(floatEltTy.getFloatSemantics(), intIt); + return FloatAttr::get(eltTy, *floatIt); + } + if (auto complexTy = llvm::dyn_cast<ComplexType>(eltTy)) { + auto complexEltTy = complexTy.getElementType(); + ComplexIntElementIterator complexIntIt(owner, index); + if (llvm::isa<IntegerType>(complexEltTy)) { + auto value = *complexIntIt; + auto real = IntegerAttr::get(complexEltTy, value.real()); + auto imag = IntegerAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef<Attribute>{real, imag}); + } - // Handle strings specially. + ComplexFloatElementIterator complexFloatIt( + llvm::cast<FloatType>(complexEltTy).getFloatSemantics(), complexIntIt); + auto value = *complexFloatIt; + auto real = FloatAttr::get(complexEltTy, value.real()); + auto imag = FloatAttr::get(complexEltTy, value.imag()); + return ArrayAttr::get(complexTy.getContext(), + ArrayRef<Attribute>{real, imag}); + } if (llvm::isa<DenseStringElementsAttr>(owner)) { ArrayRef<StringRef> vals = owner.getRawStringData(); return StringAttr::get(owner.isSplat() ? vals.front() : vals[index], eltTy); } - - // All other types should implement DenseElementTypeInterface. - auto denseEltTy = llvm::cast<DenseElementType>(eltTy); - ArrayRef<char> rawData = owner.getRawData(); - // Storage is byte-aligned: align bit size up to next byte boundary. - size_t bitSize = denseEltTy.getDenseElementBitSize(); - size_t byteSize = llvm::divideCeil(bitSize, CHAR_BIT); - size_t offset = owner.isSplat() ? 0 : index * byteSize; - return denseEltTy.convertToAttribute(rawData.slice(offset, byteSize)); + llvm_unreachable("unexpected element type"); } //===----------------------------------------------------------------------===// @@ -872,28 +888,79 @@ DenseElementsAttr DenseElementsAttr::get(ShapedType type, assert(hasSameNumElementsOrSplat(type, values)); Type eltType = type.getElementType(); - // Handle strings specially. - if (!llvm::isa<DenseElementType>(eltType)) { + // Take care complex type case first. + if (auto complexType = llvm::dyn_cast<ComplexType>(eltType)) { + if (complexType.getElementType().isIntOrIndex()) { + SmallVector<std::complex<APInt>> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast<ArrayAttr>(attr); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex<APInt>(llvm::cast<IntegerAttr>(attr0).getValue(), + llvm::cast<IntegerAttr>(attr1).getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + // Must be float. + SmallVector<std::complex<APFloat>> complexValues; + complexValues.reserve(values.size()); + for (Attribute attr : values) { + assert(llvm::isa<ArrayAttr>(attr) && "expected ArrayAttr for complex"); + auto arrayAttr = llvm::cast<ArrayAttr>(attr); + assert(arrayAttr.size() == 2 && "expected 2 element for complex"); + auto attr0 = arrayAttr[0]; + auto attr1 = arrayAttr[1]; + complexValues.push_back( + std::complex<APFloat>(llvm::cast<FloatAttr>(attr0).getValue(), + llvm::cast<FloatAttr>(attr1).getValue())); + } + return DenseElementsAttr::get(type, complexValues); + } + + // If the element type is not based on int/float/index, assume it is a string + // type. + if (!eltType.isIntOrIndexOrFloat()) { SmallVector<StringRef, 8> stringValues; stringValues.reserve(values.size()); for (Attribute attr : values) { assert(llvm::isa<StringAttr>(attr) && - "expected string value for non-DenseElementType element"); + "expected string value for non integer/index/float element"); stringValues.push_back(llvm::cast<StringAttr>(attr).getValue()); } return get(type, stringValues); } - // All other types go through DenseElementTypeInterface. - auto denseEltType = llvm::dyn_cast<DenseElementType>(eltType); - assert(denseEltType && - "attempted to get DenseElementsAttr with unsupported element type"); - SmallVector<char> data; - for (Attribute attr : values) { - LogicalResult result = denseEltType.convertFromAttribute(attr, data); - if (failed(result)) + // Otherwise, get the raw storage width to use for the allocation. + size_t bitWidth = getDenseElementBitWidth(eltType); + size_t storageBitWidth = getDenseElementStorageWidth(bitWidth); + + // Compress the attribute values into a character buffer. + SmallVector<char, 8> data( + llvm::divideCeil(storageBitWidth * values.size(), CHAR_BIT)); + APInt intVal; + for (unsigned i = 0, e = values.size(); i < e; ++i) { + if (auto floatAttr = llvm::dyn_cast<FloatAttr>(values[i])) { + assert(floatAttr.getType() == eltType && + "expected float attribute type to equal element type"); + intVal = floatAttr.getValue().bitcastToAPInt(); + } else if (auto intAttr = llvm::dyn_cast<IntegerAttr>(values[i])) { + assert(intAttr.getType() == eltType && + "expected integer attribute type to equal element type"); + intVal = intAttr.getValue(); + } else { + // Unsupported attribute type. return {}; + } + + assert(intVal.getBitWidth() == bitWidth && + "expected value to have same bitwidth as element type"); + writeBits(data.data(), i * storageBitWidth, intVal); } + return DenseIntOrFPElementsAttr::getRaw(type, data); } diff --git a/mlir/lib/IR/BuiltinTypeInterfaces.cpp b/mlir/lib/IR/BuiltinTypeInterfaces.cpp index 29303d95eb003..2f063be3e7cd0 100644 --- a/mlir/lib/IR/BuiltinTypeInterfaces.cpp +++ b/mlir/lib/IR/BuiltinTypeInterfaces.cpp @@ -6,12 +6,9 @@ // //===----------------------------------------------------------------------===// -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" #include "llvm/ADT/APFloat.h" #include "llvm/Support/CheckedArithmetic.h" -#include "llvm/Support/MathExtras.h" -#include <climits> using namespace mlir; using namespace mlir::detail; @@ -22,37 +19,6 @@ using namespace mlir::detail; #include "mlir/IR/BuiltinTypeInterfaces.cpp.inc" -//===----------------------------------------------------------------------===// -// DenseElementTypeInterface implementations for float types -//===----------------------------------------------------------------------===// - -size_t mlir::detail::getFloatTypeDenseElementBitSize(Type type) { - return cast<FloatType>(type).getWidth(); -} - -Attribute mlir::detail::convertFloatTypeToAttribute(Type type, - ArrayRef<char> rawData) { - auto floatType = cast<FloatType>(type); - APInt intVal = readBits(rawData.data(), /*bitPos=*/0, floatType.getWidth()); - APFloat floatVal(floatType.getFloatSemantics(), intVal); - return FloatAttr::get(type, floatVal); -} - -LogicalResult -mlir::detail::convertFloatTypeFromAttribute(Type type, Attribute attr, - SmallVectorImpl<char> &result) { - auto floatType = cast<FloatType>(type); - auto floatAttr = dyn_cast<FloatAttr>(attr); - if (!floatAttr || floatAttr.getType() != type) - return failure(); - size_t byteSize = - llvm::divideCeil(floatType.getWidth(), static_cast<unsigned>(CHAR_BIT)); - size_t bitPos = result.size() * CHAR_BIT; - result.resize(result.size() + byteSize); - writeBits(result.data(), bitPos, floatAttr.getValue().bitcastToAPInt()); - return success(); -} - //===----------------------------------------------------------------------===// // FloatType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp index 786c30851a071..1e198043c590a 100644 --- a/mlir/lib/IR/BuiltinTypes.cpp +++ b/mlir/lib/IR/BuiltinTypes.cpp @@ -12,17 +12,14 @@ #include "mlir/IR/AffineMap.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinDialect.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/TensorEncoding.h" #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/APFloat.h" -#include "llvm/ADT/APInt.h" #include "llvm/ADT/Sequence.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/CheckedArithmetic.h" -#include <cstring> using namespace mlir; using namespace mlir::detail; @@ -61,39 +58,6 @@ LogicalResult ComplexType::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } -size_t ComplexType::getDenseElementBitSize() const { - auto elemTy = cast<DenseElementType>(getElementType()); - return llvm::alignTo<8>(elemTy.getDenseElementBitSize()) * 2; -} - -Attribute ComplexType::convertToAttribute(ArrayRef<char> rawData) const { - auto elemTy = cast<DenseElementType>(getElementType()); - size_t singleElementBytes = - llvm::alignTo<8>(elemTy.getDenseElementBitSize()) / 8; - Attribute real = - elemTy.convertToAttribute(rawData.take_front(singleElementBytes)); - Attribute imag = - elemTy.convertToAttribute(rawData.take_back(singleElementBytes)); - return ArrayAttr::get(getContext(), {real, imag}); -} - -LogicalResult -ComplexType::convertFromAttribute(Attribute attr, - SmallVectorImpl<char> &result) const { - auto arrayAttr = dyn_cast<ArrayAttr>(attr); - if (!arrayAttr || arrayAttr.size() != 2) - return failure(); - auto elemTy = cast<DenseElementType>(getElementType()); - SmallVector<char> realData, imagData; - if (failed(elemTy.convertFromAttribute(arrayAttr[0], realData))) - return failure(); - if (failed(elemTy.convertFromAttribute(arrayAttr[1], imagData))) - return failure(); - result.append(realData); - result.append(imagData); - return success(); -} - //===----------------------------------------------------------------------===// // Integer Type //===----------------------------------------------------------------------===// @@ -121,57 +85,6 @@ IntegerType IntegerType::scaleElementBitwidth(unsigned scale) { return IntegerType::get(getContext(), scale * getWidth(), getSignedness()); } -size_t IntegerType::getDenseElementBitSize() const { - // Return the actual bit width. Storage alignment is handled separately. - return getWidth(); -} - -Attribute IntegerType::convertToAttribute(ArrayRef<char> rawData) const { - APInt value = detail::readBits(rawData.data(), /*bitPos=*/0, getWidth()); - return IntegerAttr::get(*this, value); -} - -static void writeAPIntToVector(APInt apInt, SmallVectorImpl<char> &result) { - size_t byteSize = llvm::divideCeil(apInt.getBitWidth(), CHAR_BIT); - size_t bitPos = result.size() * CHAR_BIT; - result.resize(result.size() + byteSize); - detail::writeBits(result.data(), bitPos, apInt); -} - -LogicalResult -IntegerType::convertFromAttribute(Attribute attr, - SmallVectorImpl<char> &result) const { - auto intAttr = dyn_cast<IntegerAttr>(attr); - if (!intAttr || intAttr.getType() != *this) - return failure(); - writeAPIntToVector(intAttr.getValue(), result); - return success(); -} - -//===----------------------------------------------------------------------===// -// Index Type -//===----------------------------------------------------------------------===// - -size_t IndexType::getDenseElementBitSize() const { - return kInternalStorageBitWidth; -} - -Attribute IndexType::convertToAttribute(ArrayRef<char> rawData) const { - APInt value = - detail::readBits(rawData.data(), /*bitPos=*/0, kInternalStorageBitWidth); - return IntegerAttr::get(*this, value); -} - -LogicalResult -IndexType::convertFromAttribute(Attribute attr, - SmallVectorImpl<char> &result) const { - auto intAttr = dyn_cast<IntegerAttr>(attr); - if (!intAttr || intAttr.getType() != *this) - return failure(); - writeAPIntToVector(intAttr.getValue(), result); - return success(); -} - //===----------------------------------------------------------------------===// // Float Types //===----------------------------------------------------------------------===// diff --git a/mlir/test/IR/dense-elements-type-interface.mlir b/mlir/test/IR/dense-elements-type-interface.mlir deleted file mode 100644 index 8749e562087c2..0000000000000 --- a/mlir/test/IR/dense-elements-type-interface.mlir +++ /dev/null @@ -1,83 +0,0 @@ -// RUN: mlir-opt %s -verify-diagnostics -split-input-file | FileCheck %s - -// Test dense elements attribute with custom element type using DenseElementTypeInterface. -// Uses the new type-first syntax: dense<TYPE : [ATTR, ...]> -// Note: The type is embedded in the attribute, so it's not printed again at the end. - -// CHECK-LABEL: func @dense_custom_element_type -func.func @dense_custom_element_type() { - // CHECK: "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} - "test.dummy"() {attr = dense<tensor<3x!test.dense_element> : [1 : i32, 2 : i32, 3 : i32]>} : () -> () - return -} - -// ----- - -// CHECK-LABEL: func @dense_custom_element_type_2d -func.func @dense_custom_element_type_2d() { - // CHECK: "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : {{\[}}{{\[}}1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} - "test.dummy"() {attr = dense<tensor<2x2x!test.dense_element> : [[1 : i32, 2 : i32], [3 : i32, 4 : i32]]>} : () -> () - return -} - -// ----- - -// CHECK-LABEL: func @dense_custom_element_splat -func.func @dense_custom_element_splat() { - // CHECK: "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} - "test.dummy"() {attr = dense<tensor<4x!test.dense_element> : 42 : i32>} : () -> () - return -} - -// ----- - -// CHECK-LABEL func @dense_i32_1d -func.func @dense_i32_1d() { - // The default assembly format for int, index, float, complex element types is - // the literal-first syntax. Such a dense elements attribute can be parsed - // with the type-first syntax, but it will come back with the literal-first - // syntax. - // CHECK: "test.dummy"() {attr = dense<[1, 2, 3]> : tensor<3xi32>} : () -> () - "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> () - return -} - -// ----- - -func.func @invalid_element() { - // expected-error @+1 {{expected attribute value}} - "test.dummy"() {attr = dense<tensor<3xi32> : [foo]>} : () -> () - return -} - -// ----- - -func.func @incompatible_attribute() { - // expected-error @+1 {{incompatible attribute for element type}} - "test.dummy"() {attr = dense<tensor<3xi32> : ["foo"]>} : () -> () - return -} - -// ----- - -func.func @shape_mismatch() { - // expected-error @+1 {{expected 3 elements in dimension, got 2}} - "test.dummy"() {attr = dense<tensor<3xi32> : [1 : i32, 2 : i32]>} : () -> () - return -} - -// ----- - -func.func @dynamic_shape() { - // expected-error @+1 {{dense elements type must have static shape}} - "test.dummy"() {attr = dense<tensor<?xi32> : [1 : i32, 2 : i32, 3 : i32]>} : () -> () - return -} - -// ----- - -func.func @invalid_type() { - // expected-error @+1 {{expected a shaped type for dense elements}} - "test.dummy"() {attr = dense<i32 : [1 : i32, 2 : i32, 3 : i32]>} : () -> () - return -} diff --git a/mlir/test/lib/Dialect/Test/TestTypeDefs.td b/mlir/test/lib/Dialect/Test/TestTypeDefs.td index 08600ce713a17..964792ceebc07 100644 --- a/mlir/test/lib/Dialect/Test/TestTypeDefs.td +++ b/mlir/test/lib/Dialect/Test/TestTypeDefs.td @@ -18,7 +18,6 @@ include "TestDialect.td" include "TestAttrDefs.td" include "TestInterfaces.td" include "mlir/IR/BuiltinTypes.td" -include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" @@ -513,15 +512,4 @@ def TestTypeNewlineAndIndent : Test_Type<"TestTypeNewlineAndIndent"> { let hasCustomAssemblyFormat = 1; } -def TestTypeDenseElement : Test_Type<"TestDenseElement", - [DeclareTypeInterfaceMethods<DenseElementTypeInterface, - ["getDenseElementBitSize", "convertToAttribute", "convertFromAttribute"]> - ]> { - let mnemonic = "dense_element"; - let description = [{ - A test type that implements DenseElementTypeInterface to test dense - elements with custom element types. Elements are stored as 32-bit integers. - }]; -} - #endif // TEST_TYPEDEFS diff --git a/mlir/test/lib/Dialect/Test/TestTypes.cpp b/mlir/test/lib/Dialect/Test/TestTypes.cpp index ef3396fc4f610..71dd25b0093e0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.cpp +++ b/mlir/test/lib/Dialect/Test/TestTypes.cpp @@ -15,7 +15,6 @@ #include "TestDialect.h" #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/DialectImplementation.h" #include "mlir/IR/ExtensibleDialect.h" #include "mlir/IR/Types.h" @@ -23,7 +22,6 @@ #include "llvm/ADT/SetVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/TypeSize.h" -#include <cstring> #include <optional> using namespace mlir; @@ -607,29 +605,3 @@ void TestTypeNewlineAndIndentType::print(::mlir::AsmPrinter &printer) const { printer.printNewline(); printer << ">"; } - -//===----------------------------------------------------------------------===// -// TestDenseElementType - DenseElementTypeInterface Implementation -//===----------------------------------------------------------------------===// - -// Elements are stored as 32-bit integers. -size_t TestDenseElementType::getDenseElementBitSize() const { return 32; } - -Attribute -TestDenseElementType::convertToAttribute(ArrayRef<char> rawData) const { - assert(rawData.size() == 4 && "expected 4 bytes for TestDenseElement"); - int32_t value; - std::memcpy(&value, rawData.data(), sizeof(value)); - return IntegerAttr::get(IntegerType::get(getContext(), 32), value); -} - -LogicalResult TestDenseElementType::convertFromAttribute( - Attribute attr, SmallVectorImpl<char> &result) const { - auto intAttr = dyn_cast<IntegerAttr>(attr); - if (!intAttr || intAttr.getType().getIntOrFloatBitWidth() != 32) - return failure(); - int32_t value = intAttr.getValue().getSExtValue(); - result.append(reinterpret_cast<const char *>(&value), - reinterpret_cast<const char *>(&value) + sizeof(value)); - return success(); -} diff --git a/mlir/test/lib/Dialect/Test/TestTypes.h b/mlir/test/lib/Dialect/Test/TestTypes.h index 705fb86e9e9b3..6499a96f495d0 100644 --- a/mlir/test/lib/Dialect/Test/TestTypes.h +++ b/mlir/test/lib/Dialect/Test/TestTypes.h @@ -19,7 +19,6 @@ #include "TestTraits.h" #include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/Diagnostics.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/DialectImplementation.h" _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
