https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124713
>From e7ae9837e983d62c4b6bff04e3b193915c80d8af Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <krzysdrewn...@gmail.com> Date: Sat, 18 Jan 2025 00:01:35 -0800 Subject: [PATCH 1/3] [mlir][ODS] Add a collective builder that takes the Properties struct This commit adds builders of the form ``` static void build(..., [TypeRange resultTypes], ValueRange operands, const Properties &properties, ArrayRef<NamedAttribute> discardableAttributes = {}, [unsigned numRegions]); ``` to go alongside the existing result/operands/[inherent + discardable attribute list] collective builders. This change is intended to support a refactor to the declarative rewrite engine to make it populate the `Properties` struct instead of creating a `DictionaryAttr`, thus enabling rewrite rules to handle non-`Attribute` properties. More generally, this means that generic code that would previously call `getAttrs()` to blend together inherent and discardable attributes can now use `getProperties()` and `getDiscardableAttrs()` separately, thus removing the need to serialize everything into a temporary `DictionaryAttr`. --- mlir/docs/DeclarativeRewrites.md | 4 +- mlir/docs/DefiningDialects/Operations.md | 27 +++- mlir/include/mlir/IR/OpDefinition.h | 5 +- mlir/include/mlir/IR/OperationSupport.h | 18 +++ mlir/test/lib/Dialect/Test/TestOps.td | 7 + mlir/test/mlir-tblgen/op-attribute.td | 12 ++ mlir/test/mlir-tblgen/op-decl-and-defs.td | 8 + mlir/test/mlir-tblgen/op-result.td | 11 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 162 +++++++++++++++----- 9 files changed, 205 insertions(+), 49 deletions(-) diff --git a/mlir/docs/DeclarativeRewrites.md b/mlir/docs/DeclarativeRewrites.md index 888ce57fa3b53..fd566a2393b63 100644 --- a/mlir/docs/DeclarativeRewrites.md +++ b/mlir/docs/DeclarativeRewrites.md @@ -237,9 +237,9 @@ In the above, we are using `BOp`'s result for building `COp`. Given that `COp` was specified with table-driven op definition, there will be several `build()` methods generated for it. 
One of them has aggregated -parameters for result types, operands, and attributes in the signature: `void +parameters for result types, operands, and properties in the signature: `void COp::build(..., ArrayRef<Type> resultTypes, Array<Value> operands, -ArrayRef<NamedAttribute> attr)`. The pattern in the above calls this `build()` +const COp::Properties& properties)`. The pattern in the above calls this `build()` method for constructing the `COp`. In general, arguments in the result pattern will be passed directly to the diff --git a/mlir/docs/DefiningDialects/Operations.md b/mlir/docs/DefiningDialects/Operations.md index 8ff60ac21424c..528070cd3ebff 100644 --- a/mlir/docs/DefiningDialects/Operations.md +++ b/mlir/docs/DefiningDialects/Operations.md @@ -465,7 +465,18 @@ def MyOp : ... { The following builders are generated: ```c++ +// All result-types/operands/properties/discardable attributes have one +// aggregate parameter. `Properties` is the properties structure of +// `MyOp`. +static void build(OpBuilder &odsBuilder, OperationState &odsState, + TypeRange resultTypes, + ValueRange operands, + const Properties &properties, + ArrayRef<NamedAttribute> discardableAttributes = {}); + // All result-types/operands/attributes have one aggregate parameter. +// Inherent properties and discardable attributes are mixed together in the +// `attributes` dictionary. static void build(OpBuilder &odsBuilder, OperationState &odsState, TypeRange resultTypes, ValueRange operands, @@ -498,20 +509,28 @@ static void build(OpBuilder &odsBuilder, OperationState &odsState, // All operands/attributes have aggregate parameters. // Generated if return type can be inferred. +static void build(OpBuilder &odsBuilder, OperationState &odsState, + ValueRange operands, + const Properties &properties, + ArrayRef<NamedAttribute> discardableAttributes = {}); + +// All operands/attributes have aggregate parameters. +// Generated if return type can be inferred. Uses the legacy merged attribute +// dictionary.
static void build(OpBuilder &odsBuilder, OperationState &odsState, ValueRange operands, ArrayRef<NamedAttribute> attributes); // (And manually specified builders depending on the specific op.) ``` -The first form provides basic uniformity so that we can create ops using the -same form regardless of the exact op. This is particularly useful for +The first two forms provide basic uniformity so that we can create ops using +the same form regardless of the exact op. This is particularly useful for implementing declarative pattern rewrites. -The second and third forms are good for use in manually written code, given that +The third and fourth forms are good for use in manually written code, given that they provide better guarantee via signatures. -The third form will be generated if any of the op's attribute has different +The fourth form will be generated if any of the op's attribute has different `Attr.returnType` from `Attr.storageType` and we know how to build an attribute from an unwrapped value (i.e., `Attr.constBuilderCall` is defined.) Additionally, for the third form, if an attribute appearing later in the diff --git a/mlir/include/mlir/IR/OpDefinition.h b/mlir/include/mlir/IR/OpDefinition.h index d91c573c03efe..4fad61580b31a 100644 --- a/mlir/include/mlir/IR/OpDefinition.h +++ b/mlir/include/mlir/IR/OpDefinition.h @@ -74,7 +74,10 @@ void ensureRegionTerminator( /// Structure used by default as a "marker" when no "Properties" are set on an /// Operation. -struct EmptyProperties {}; +struct EmptyProperties { + bool operator==(const EmptyProperties &) const { return true; } + bool operator!=(const EmptyProperties &) const { return false; } +}; /// Traits to detect whether an Operation defined a `Properties` type, otherwise /// it'll default to `EmptyProperties`. 
diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h index d4035d14ab746..4ad1e5ff789a9 100644 --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -1017,6 +1017,24 @@ struct OperationState { setProperties(Operation *op, function_ref<InFlightDiagnostic()> emitError) const; + // Make `newProperties` the source of the properties that will be copied into + // the operation. The memory referenced by `newProperties` must remain live + // until after the `Operation` is created, at which time it may be + // deallocated. Calls to `getOrAddProperties<>()` will return references to + // this memory. + template <typename T> + void useProperties(T &newProperties) { + assert(!properties && + "Can't provide a properties struct when one has been allocated"); + properties = &newProperties; + propertiesDeleter = [](OpaqueProperties) {}; + propertiesSetter = [](OpaqueProperties newProp, + const OpaqueProperties prop) { + *newProp.as<T *>() = *prop.as<const T *>(); + }; + propertiesId = TypeID::get<T>(); + } + void addOperands(ValueRange newOperands); void addTypes(ArrayRef<Type> newTypes) { diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 2aa0658ab0e5d..7d46c04440909 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2485,6 +2485,13 @@ def TableGenBuildOp6 : TEST_Op<"tblgen_build_6", [AttrSizedOperandSegments]> { let results = (outs F32:$result); } +// An inherent attribute. Test collective builders, both those that take properties as +// properties structs and those that take an attribute dictionary.
+def TableGenBuildOp7 : TEST_Op<"tblgen_build_7", []> { + let arguments = (ins BoolAttr:$attr0); + let results = (outs); +} + //===----------------------------------------------------------------------===// // Test BufferPlacement //===----------------------------------------------------------------------===// diff --git a/mlir/test/mlir-tblgen/op-attribute.td b/mlir/test/mlir-tblgen/op-attribute.td index 55382a5bd3f8c..549830e06042f 100644 --- a/mlir/test/mlir-tblgen/op-attribute.td +++ b/mlir/test/mlir-tblgen/op-attribute.td @@ -165,6 +165,12 @@ def AOp : NS_Op<"a_op", []> { // DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes // DEF: odsState.addAttributes(attributes); +// DEF: void AOp::build( +// DEF-SAME: const Properties &properties, +// DEF-SAME: ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes +// DEF: odsState.useProperties(const_cast<Properties&>(properties)); +// DEF: odsState.addAttributes(discardableAttributes); + // DEF: void AOp::populateDefaultProperties // Test the above but with prefix. @@ -279,6 +285,12 @@ def AgetOp : Op<Test2_Dialect, "a_get_op", []> { // DEF: ::llvm::ArrayRef<::mlir::NamedAttribute> attributes // DEF: odsState.addAttributes(attributes); +// DEF: void AgetOp::build( +// DEF-SAME: const Properties &properties +// DEF-SAME: ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes +// DEF: odsState.useProperties(const_cast<Properties&>(properties)); +// DEF: odsState.addAttributes(discardableAttributes); + // Test the above but using properties. 
def ApropOp : NS_Op<"a_prop_op", []> { let arguments = (ins diff --git a/mlir/test/mlir-tblgen/op-decl-and-defs.td b/mlir/test/mlir-tblgen/op-decl-and-defs.td index ee800a2952bac..a3dce9b31f8d2 100644 --- a/mlir/test/mlir-tblgen/op-decl-and-defs.td +++ b/mlir/test/mlir-tblgen/op-decl-and-defs.td @@ -119,6 +119,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> { // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type r, ::mlir::TypeRange s, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr some_attr2, unsigned someRegionsCount) // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::ValueRange b, uint32_t attr1, /*optional*/::mlir::FloatAttr some_attr2, unsigned someRegionsCount); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes, unsigned numRegions) // CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result); // CHECK: void print(::mlir::OpAsmPrinter &p); // CHECK: ::llvm::LogicalResult verifyInvariants(); @@ -231,6 +232,7 @@ def NS_HCollectiveParamsOp : NS_Op<"op_collective_params", []> { // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::Type b, ::mlir::Value a); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState 
&odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}) +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}) // Check suppression of "separate arg, separate result" build method for an op // with single variadic arg and single variadic result (since it will be @@ -281,6 +283,8 @@ def NS_IOp : NS_Op<"op_with_same_operands_and_result_types_trait", [SameOperands // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); // Check default value of `attributes` for the `genInferredTypeCollectiveParamBuilder` builder def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterfaceMethods<InferTypeOpInterface>]> { @@ -293,6 +297,8 @@ def NS_JOp : NS_Op<"op_with_InferTypeOpInterface_interface", [DeclareOpInterface // CHECK: static void build(::mlir::OpBuilder 
&odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); // Test usage of TraitList getting flattened during emission. 
def NS_KOp : NS_Op<"k_op", [IsolatedFromAbove, @@ -329,6 +335,8 @@ def NS_LOp : NS_Op<"op_with_same_operands_and_result_types_unwrapped_attr", [Sam // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::Value a, ::mlir::Value b, uint32_t attr1); // CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); // CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes = {}); +// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); +// CHECK: static void build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes = {}); def NS_MOp : NS_Op<"op_with_single_result_and_fold_adaptor_fold", []> { let results = (outs AnyType:$res); diff --git a/mlir/test/mlir-tblgen/op-result.td b/mlir/test/mlir-tblgen/op-result.td index 212d3189cf571..a4f7af6dbcf1c 100644 --- a/mlir/test/mlir-tblgen/op-result.td +++ b/mlir/test/mlir-tblgen/op-result.td @@ -57,7 +57,9 @@ def OpD : NS_Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { // CHECK-LABEL: OpD definitions // CHECK: void OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) -// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()}); +// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypeAttr>(typeAttr).getValue()}); +// CHECK: void 
OpD::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes) +// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypeAttr>(typeAttr).getValue()}); def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { let arguments = (ins I32:$x, F32Attr:$attr); @@ -66,7 +68,10 @@ def OpE : NS_Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { // CHECK-LABEL: OpE definitions // CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) -// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()}); +// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypedAttr>(typeAttr).getType()}); +// CHECK: void OpE::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes) +// CHECK: ::mlir::Attribute typeAttr = properties.getAttr(); +// CHECK: odsState.addTypes({::llvm::cast<::mlir::TypedAttr>(typeAttr).getType()}); def OpF : NS_Op<"one_variadic_result_op", []> { let results = (outs Variadic<I32>:$x); @@ -118,6 +123,8 @@ def OpK : NS_Op<"only_input_is_variadic_with_same_value_type_op", [SameOperandsA // CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes) // CHECK: odsState.addTypes({operands[0].getType()}); +// CHECK-LABEL: OpK::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ::mlir::ValueRange operands, const Properties &properties, ::llvm::ArrayRef<::mlir::NamedAttribute> discardableAttributes) +// CHECK: odsState.addTypes({operands[0].getType()}); // Test with inferred shapes and interleaved with 
operands/attributes. // diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 1970647a80115..b0396ad876071 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -411,6 +411,15 @@ class OpOrAdaptorHelper { return true; if (!op.getDialect().usePropertiesForAttributes()) return false; + return true; + } + + /// Returns whether the operation will have a non-empty `Properties` struct. + bool hasNonEmptyPropertiesStruct() const { + if (!op.getProperties().empty()) + return true; + if (!hasProperties()) + return false; if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments") || op.getTrait("::mlir::OpTrait::AttrSizedResultSegments")) return true; @@ -661,24 +670,33 @@ class OpEmitter { // type as all results' types. void genUseOperandAsResultTypeSeparateParamBuilder(); + // The kind of collective builder to generate + enum class CollectiveBuilderKind { + PropStruct, // Inherent attributes/properties are passed by `const + // Properties&` + AttrDict, // Inherent attributes/properties are passed by attribute + // dictionary + }; + // Generates the build() method that takes all operands/attributes // collectively as one parameter. The generated build() method uses first // operand's type as all results' types. - void genUseOperandAsResultTypeCollectiveParamBuilder(); + void + genUseOperandAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind); // Generates the build() method that takes aggregate operands/attributes // parameters. This build() method uses inferred types as result types. // Requires: The type needs to be inferable via InferTypeOpInterface. - void genInferredTypeCollectiveParamBuilder(); + void genInferredTypeCollectiveParamBuilder(CollectiveBuilderKind kind); - // Generates the build() method that takes each operand/attribute as a - // stand-alone parameter. 
The generated build() method uses first attribute's + // Generates the build() method that takes aggregate operands/attributes as + // parameters. The generated build() method uses first attribute's // type as all result's types. - void genUseAttrAsResultTypeBuilder(); + void genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind); // Generates the build() method that takes all result types collectively as // one parameter. Similarly for operands and attributes. - void genCollectiveParamBuilder(); + void genCollectiveParamBuilder(CollectiveBuilderKind kind); // The kind of parameter to generate for result types in builders. enum class TypeParamKind { @@ -1363,8 +1381,6 @@ void OpEmitter::genPropertiesSupport() { attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); if (emitHelper.getResultSegmentsSize()) attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); - if (attrOrProperties.empty()) - return; auto &setPropMethod = opClass .addStaticMethod( @@ -1728,6 +1744,9 @@ void OpEmitter::genPropertiesSupport() { void OpEmitter::genPropertiesSupportForBytecode( ArrayRef<ConstArgument> attrOrProperties) { + if (attrOrProperties.empty()) + return; + if (op.useCustomPropertiesEncoding()) { opClass.declareStaticMethod( "::llvm::LogicalResult", "readProperties", @@ -2644,7 +2663,8 @@ void OpEmitter::genSeparateArgParamBuilder() { } } -void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { +void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder( + CollectiveBuilderKind kind) { int numResults = op.getNumResults(); // Signature @@ -2652,10 +2672,15 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); + if (kind == CollectiveBuilderKind::PropStruct) + paramList.emplace_back("const Properties &", "
"properties"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + StringRef attributesName = kind == CollectiveBuilderKind::PropStruct + ? "discardableAttributes" + : "attributes"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); + attributesName, attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); @@ -2668,8 +2693,12 @@ void OpEmitter::genUseOperandAsResultTypeCollectiveParamBuilder() { // Operands body << " " << builderOpState << ".addOperands(operands);\n"; + if (kind == CollectiveBuilderKind::PropStruct) + body << " " << builderOpState + << ".useProperties(const_cast<Properties&>(properties));\n"; // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; + body << " " << builderOpState << ".addAttributes(" << attributesName + << ");\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { @@ -2752,14 +2781,20 @@ void OpEmitter::genPopulateDefaultAttributes() { } } -void OpEmitter::genInferredTypeCollectiveParamBuilder() { +void OpEmitter::genInferredTypeCollectiveParamBuilder( + CollectiveBuilderKind kind) { SmallVector<MethodParameter> paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); + if (kind == CollectiveBuilderKind::PropStruct) + paramList.emplace_back("const Properties &", "properties"); StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + StringRef attributesName = kind == CollectiveBuilderKind::PropStruct + ? 
"discardableAttributes" + : "attributes"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); + attributesName, attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); @@ -2784,7 +2819,11 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() { << numNonVariadicOperands << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << ".addOperands(operands);\n"; - body << " " << builderOpState << ".addAttributes(attributes);\n"; + if (kind == CollectiveBuilderKind::PropStruct) + body << " " << builderOpState + << ".useProperties(const_cast<Properties &>(properties));\n"; + body << " " << builderOpState << ".addAttributes(" << attributesName + << ");\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { @@ -2795,7 +2834,8 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() { } // Result types - if (emitHelper.hasProperties()) { + if (emitHelper.hasNonEmptyPropertiesStruct() && + kind == CollectiveBuilderKind::AttrDict) { // Initialize the properties from Attributes before invoking the infer // function. body << formatv(R"( @@ -2867,13 +2907,18 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { emit(AttrParamKind::UnwrappedValue); } -void OpEmitter::genUseAttrAsResultTypeBuilder() { +void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind) { SmallVector<MethodParameter> paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::ValueRange", "operands"); + if (kind == CollectiveBuilderKind::PropStruct) + paramList.emplace_back("const Properties &", "properties"); + StringRef attributesName = kind == CollectiveBuilderKind::PropStruct + ? 
"discardableAttributes" + : "attributes"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", "{}"); + attributesName, "{}"); auto *m = opClass.addStaticMethod("void", "build", std::move(paramList)); // If the builder is redundant, skip generating the method if (!m) @@ -2885,28 +2930,41 @@ void OpEmitter::genUseAttrAsResultTypeBuilder() { std::string resultType; const auto &namedAttr = op.getAttribute(0); - body << " auto attrName = " << op.getGetterName(namedAttr.name) - << "AttrName(" << builderOpState - << ".name);\n" - " for (auto attr : attributes) {\n" - " if (attr.getName() != attrName) continue;\n"; if (namedAttr.attr.isTypeAttr()) { - resultType = "::llvm::cast<::mlir::TypeAttr>(attr.getValue()).getValue()"; + resultType = "::llvm::cast<::mlir::TypeAttr>(typeAttr).getValue()"; + } else { + resultType = "::llvm::cast<::mlir::TypedAttr>(typeAttr).getType()"; + } + + if (kind == CollectiveBuilderKind::PropStruct) { + body << " ::mlir::Attribute typeAttr = properties." 
<< op.getGetterName(namedAttr.name) << "();\n"; } else { - resultType = "::llvm::cast<::mlir::TypedAttr>(attr.getValue()).getType()"; + body << " ::mlir::Attribute typeAttr;\n" + << " auto attrName = " << op.getGetterName(namedAttr.name) + << "AttrName(" << builderOpState + << ".name);\n" + " for (auto attr : attributes) {\n" + " if (attr.getName() == attrName) {\n" + " typeAttr = attr.getValue();\n" + " break;\n" + " }\n" + " }\n"; } // Operands body << " " << builderOpState << ".addOperands(operands);\n"; + // Properties + if (kind == CollectiveBuilderKind::PropStruct) + body << " " << builderOpState << ".useProperties(const_cast<Properties&>(properties));\n"; + // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; + body << " " << builderOpState << ".addAttributes(" << attributesName << ");\n"; // Result types SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); body << " " << builderOpState << ".addTypes({" << llvm::join(resultTypes, ", ") << "});\n"; - body << " }\n"; } /// Returns a signature of the builder. Updates the context `fctx` to enable @@ -2973,22 +3031,30 @@ void OpEmitter::genBuilder() { // 1. one having a stand-alone parameter for each operand / attribute, and genSeparateArgParamBuilder(); // 2. one having an aggregated parameter for all result types / operands / - // attributes, and - genCollectiveParamBuilder(); + // [properties / discardable] attributes, and + genCollectiveParamBuilder(CollectiveBuilderKind::AttrDict); + if (emitHelper.hasProperties()) + genCollectiveParamBuilder(CollectiveBuilderKind::PropStruct); // 3. one having a stand-alone parameter for each operand and attribute, // use the first operand or attribute's type as all result types // to facilitate different call patterns. 
if (op.getNumVariableLengthResults() == 0) { if (op.getTrait("::mlir::OpTrait::SameOperandsAndResultType")) { genUseOperandAsResultTypeSeparateParamBuilder(); - genUseOperandAsResultTypeCollectiveParamBuilder(); + genUseOperandAsResultTypeCollectiveParamBuilder( + CollectiveBuilderKind::AttrDict); + if (emitHelper.hasProperties()) + genUseOperandAsResultTypeCollectiveParamBuilder( + CollectiveBuilderKind::PropStruct); + } + if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) { + genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind::AttrDict); + genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind::PropStruct); } - if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) - genUseAttrAsResultTypeBuilder(); } } -void OpEmitter::genCollectiveParamBuilder() { +void OpEmitter::genCollectiveParamBuilder(CollectiveBuilderKind kind) { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariableLengthResults(); int numNonVariadicResults = numResults - numVariadicResults; @@ -3002,10 +3068,15 @@ void OpEmitter::genCollectiveParamBuilder() { paramList.emplace_back("::mlir::OperationState &", builderOpState); paramList.emplace_back("::mlir::TypeRange", "resultTypes"); paramList.emplace_back("::mlir::ValueRange", "operands"); + if (kind == CollectiveBuilderKind::PropStruct) + paramList.emplace_back("const Properties &", "properties"); // Provide default value for `attributes` when its the last parameter StringRef attributesDefaultValue = op.getNumVariadicRegions() ? "" : "{}"; + StringRef attributesName = kind == CollectiveBuilderKind::PropStruct + ? 
"discardableAttributes" + : "attributes"; paramList.emplace_back("::llvm::ArrayRef<::mlir::NamedAttribute>", - "attributes", attributesDefaultValue); + attributesName, attributesDefaultValue); if (op.getNumVariadicRegions()) paramList.emplace_back("unsigned", "numRegions"); @@ -3023,8 +3094,14 @@ void OpEmitter::genCollectiveParamBuilder() { << "u && \"mismatched number of parameters\");\n"; body << " " << builderOpState << ".addOperands(operands);\n"; + // Properties + if (kind == CollectiveBuilderKind::PropStruct) + body << " " << builderOpState + << ".useProperties(const_cast<Properties&>(properties));\n"; + // Attributes - body << " " << builderOpState << ".addAttributes(attributes);\n"; + body << " " << builderOpState << ".addAttributes(" << attributesName + << ");\n"; // Create the correct number of regions if (int numRegions = op.getNumRegions()) { @@ -3041,7 +3118,8 @@ void OpEmitter::genCollectiveParamBuilder() { << "u && \"mismatched number of return types\");\n"; body << " " << builderOpState << ".addTypes(resultTypes);\n"; - if (emitHelper.hasProperties()) { + if (emitHelper.hasNonEmptyPropertiesStruct() && + kind == CollectiveBuilderKind::AttrDict) { // Initialize the properties from Attributes before invoking the infer // function. body << formatv(R"( @@ -3060,7 +3138,7 @@ void OpEmitter::genCollectiveParamBuilder() { // Generate builder that infers type too. // TODO: Expand to handle successors. if (canInferType(op) && op.getNumSuccessors() == 0) - genInferredTypeCollectiveParamBuilder(); + genInferredTypeCollectiveParamBuilder(kind); } void OpEmitter::buildParamList(SmallVectorImpl<MethodParameter> ¶mList, @@ -4061,7 +4139,7 @@ void OpEmitter::genTraits() { // native/interface traits and after all the traits with `StructuralOpTrait`. 
opClass.addTrait("::mlir::OpTrait::OpInvariants"); - if (emitHelper.hasProperties()) + if (emitHelper.hasNonEmptyPropertiesStruct()) opClass.addTrait("::mlir::BytecodeOpInterface::Trait"); // Add the native and interface traits. @@ -4201,7 +4279,6 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( attrOrProperties.push_back(&emitHelper.getOperandSegmentsSize().value()); if (emitHelper.getResultSegmentsSize()) attrOrProperties.push_back(&emitHelper.getResultSegmentsSize().value()); - assert(!attrOrProperties.empty()); std::string declarations = " struct Properties {\n"; llvm::raw_string_ostream os(declarations); std::string comparator = @@ -4274,7 +4351,7 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( // Emit accessors using the interface type. if (attr) { const char *accessorFmt = R"decl( - auto get{0}() { + auto get{0}() const { auto &propStorage = this->{1}; return ::llvm::{2}<{3}>(propStorage); } @@ -4296,7 +4373,12 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter( os << comparator; os << " };\n"; - genericAdaptorBase.declare<ExtraClassDeclaration>(std::move(declarations)); + if (attrOrProperties.empty()) + genericAdaptorBase.declare<UsingDeclaration>("Properties", + "::mlir::EmptyProperties"); + else + genericAdaptorBase.declare<ExtraClassDeclaration>( + std::move(declarations)); } genericAdaptorBase.declare<VisibilityDeclaration>(Visibility::Protected); genericAdaptorBase.declare<Field>("::mlir::DictionaryAttr", "odsAttrs"); >From 058b5ec37eb6e5f43142ae1dcf22612a025b21c0 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <krzysdrewn...@gmail.com> Date: Mon, 27 Jan 2025 23:22:58 -0800 Subject: [PATCH 2/3] Clang-format --- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index b0396ad876071..b957c8ee9f8ab 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ 
b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -2907,7 +2907,8 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() { emit(AttrParamKind::UnwrappedValue); } -void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind kind) { +void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder( + CollectiveBuilderKind kind) { SmallVector<MethodParameter> paramList; paramList.emplace_back("::mlir::OpBuilder &", "odsBuilder"); paramList.emplace_back("::mlir::OperationState &", builderOpState); @@ -2937,7 +2938,8 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKi } if (kind == CollectiveBuilderKind::PropStruct) { - body << " ::mlir::Attribute typeAttr = properties." << op.getGetterName(namedAttr.name) << "();\n"; + body << " ::mlir::Attribute typeAttr = properties." + << op.getGetterName(namedAttr.name) << "();\n"; } else { body << " ::mlir::Attribute typeAttr;\n" << " auto attrName = " << op.getGetterName(namedAttr.name) @@ -2956,10 +2958,12 @@ void OpEmitter::genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKi // Properties if (kind == CollectiveBuilderKind::PropStruct) - body << " " << builderOpState << ".useProperties(const_cast<Properties&>(properties));\n"; + body << " " << builderOpState + << ".useProperties(const_cast<Properties&>(properties));\n"; // Attributes - body << " " << builderOpState << ".addAttributes(" << attributesName << ");\n"; + body << " " << builderOpState << ".addAttributes(" << attributesName + << ");\n"; // Result types SmallVector<std::string, 2> resultTypes(op.getNumResults(), resultType); @@ -3048,8 +3052,10 @@ void OpEmitter::genBuilder() { CollectiveBuilderKind::PropStruct); } if (op.getTrait("::mlir::OpTrait::FirstAttrDerivedResultType")) { - genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind::AttrDict); - genUseAttrAsResultTypeCollectiveParamBuilder(CollectiveBuilderKind::PropStruct); + genUseAttrAsResultTypeCollectiveParamBuilder( + 
CollectiveBuilderKind::AttrDict); + genUseAttrAsResultTypeCollectiveParamBuilder( + CollectiveBuilderKind::PropStruct); } } } >From 4e5cfedb79eed66bf6001f883112841617e294ec Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <krzysdrewn...@gmail.com> Date: Tue, 28 Jan 2025 20:43:25 -0800 Subject: [PATCH 3/3] Right, there was supposed to be a unit test here --- mlir/unittests/TableGen/OpBuildGen.cpp | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/mlir/unittests/TableGen/OpBuildGen.cpp b/mlir/unittests/TableGen/OpBuildGen.cpp index 94fbfa28803c4..b4a5c316d63d6 100644 --- a/mlir/unittests/TableGen/OpBuildGen.cpp +++ b/mlir/unittests/TableGen/OpBuildGen.cpp @@ -291,4 +291,20 @@ TEST_F(OpBuildGenTest, BuildMethodsVariadicProperties) { verifyOp(std::move(op), {f32Ty}, {*cstI32}, {*cstI32}, attrs); } +TEST_F(OpBuildGenTest, BuildMethodsInherentDiscardableAttrs) { + test::TableGenBuildOp7::Properties props; + props.attr0 = cast<BoolAttr>(attrs[0].getValue()); + ArrayRef<NamedAttribute> discardableAttrs = attrs.drop_front(); + auto op7 = builder.create<test::TableGenBuildOp7>( + loc, TypeRange{}, ValueRange{}, props, discardableAttrs); + verifyOp(op7, {}, {}, attrs); + + // Check that the old-style builder where all the attributes go in the same + // place works. + auto op7b = builder.create<test::TableGenBuildOp7>(loc, TypeRange{}, + ValueRange{}, attrs); + verifyOp(op7b, {}, {}, attrs); + ASSERT_EQ(op7b.getProperties().getAttr0(), attrs[0].getValue()); +} + } // namespace mlir _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits