https://github.com/krzysz00 updated https://github.com/llvm/llvm-project/pull/124876
>From 4fcff8c53af4055b6d92c5399e9f88a7fea18677 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <krzysdrewn...@gmail.com> Date: Tue, 28 Jan 2025 20:25:38 -0800 Subject: [PATCH 1/2] [mlir][ODS] Switch declarative rewrite rules to properties structs Now that we have collective builders that take `const [RelevantOp]::Properties &` arguments, we don't need to serialize all the attributes that'll be set during an output pattern into a dictionary attribute. Similarly, we can use the properties struct to get the attributes instead of needing to go through the big if statement in getAttrOfType<>(). This also enables us to have declarative rewrite rules that match non-attribute properties in a future PR. This commit also adds a basic test for the generated matchers since there didn't seem to already be one. --- .../rewriter-attributes-properties.td | 47 +++++++++++ mlir/tools/mlir-tblgen/RewriterGen.cpp | 81 +++++++++++++------ 2 files changed, 105 insertions(+), 23 deletions(-) create mode 100644 mlir/test/mlir-tblgen/rewriter-attributes-properties.td diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td new file mode 100644 index 0000000000000..77869d36cc12e --- /dev/null +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -0,0 +1,47 @@ +// RUN: mlir-tblgen -gen-rewriters -I %S/../../include %s | FileCheck %s + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" + +def Test_Dialect : Dialect { + let name = "test"; +} +class NS_Op<string mnemonic, list<Trait> traits> : + Op<Test_Dialect, mnemonic, traits>; + +def AOp : NS_Op<"a_op", []> { + let arguments = (ins + I32:$x, + I32Attr:$y + ); + + let results = (outs I32:$z); +} + +def BOp : NS_Op<"b_op", []> { + let arguments = (ins + I32Attr:$y + ); + + let results = (outs I32:$z); +} + +def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; +// CHECK-LABEL: struct test1 +// CHECK: ::llvm::LogicalResult matchAndRewrite +// CHECK: ::mlir::IntegerAttr y; +// CHECK: test::BOp x; +// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; +// CHECK: tblgen_ops.push_back(op0); +// CHECK: x = castedOp1; +// CHECK: tblgen_attr = castedOp1.getProperties().getY(); +// CHECK: if (!(tblgen_attr)) +// CHECK: y = tblgen_attr; +// CHECK: tblgen_ops.push_back(op1); + +// CHECK: test::AOp tblgen_AOp_0; +// CHECK: ::llvm::SmallVector<::mlir::Value, 4> tblgen_values; +// CHECK: test::AOp::Properties tblgen_props; +// CHECK: tblgen_values.push_back((*x.getODSResults(0).begin())); +// CHECK: tblgen_props.y = ::llvm::dyn_cast_if_present<decltype(tblgen_props.y)>(y); +// CHECK: tblgen_AOp_0 = rewriter.create<test::AOp>(odsLoc, tblgen_types, tblgen_values, tblgen_props); diff --git a/mlir/tools/mlir-tblgen/RewriterGen.cpp b/mlir/tools/mlir-tblgen/RewriterGen.cpp index f6eb5bdfe568e..5dd4f87a6d0ce 100644 --- a/mlir/tools/mlir-tblgen/RewriterGen.cpp +++ b/mlir/tools/mlir-tblgen/RewriterGen.cpp @@ -122,7 +122,7 @@ class PatternEmitter { // Emits C++ statements for matching the `argIndex`-th argument of the given // DAG `tree` as an attribute. - void emitAttributeMatch(DagNode tree, StringRef opName, int argIndex, + void emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, int depth); // Emits C++ for checking a match with a corresponding match failure @@ -664,7 +664,7 @@ void PatternEmitter::emitOpMatch(DagNode tree, StringRef opName, int depth) { /*variadicSubIndex=*/std::nullopt); ++nextOperand; } else if (isa<NamedAttribute *>(opArg)) { - emitAttributeMatch(tree, opName, opArgIdx, depth); + emitAttributeMatch(tree, castedName, opArgIdx, depth); } else { PrintFatalError(loc, "unhandled case when matching op"); } @@ -864,16 +864,22 @@ void PatternEmitter::emitVariadicOperandMatch(DagNode tree, os.unindent() << "}\n"; } -void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, +void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef castedName, int argIndex, int depth) { Operator &op = tree.getDialectOp(opMap); auto *namedAttr = cast<NamedAttribute *>(op.getArg(argIndex)); const auto &attr = namedAttr->attr; os << "{\n"; - os.indent() << formatv("auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" - "(void)tblgen_attr;\n", - opName, attr.getStorageType(), namedAttr->name); + if (op.getDialect().usePropertiesForAttributes()) { + os.indent() << formatv("auto tblgen_attr = {0}.getProperties().{1}();\n", + castedName, op.getGetterName(namedAttr->name)); + } else { + os.indent() << formatv( + "auto tblgen_attr = {0}->getAttrOfType<{1}>(\"{2}\");" + "(void)tblgen_attr;\n", + castedName, attr.getStorageType(), namedAttr->name); + } // TODO: This should use getter method to avoid duplication. if (attr.hasDefaultValue()) { @@ -887,7 +893,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, // That is precisely what getDiscardableAttr() returns on missing // attributes. } else { - emitMatchCheck(opName, tgfmt("tblgen_attr", &fmtCtx), + emitMatchCheck(castedName, tgfmt("tblgen_attr", &fmtCtx), formatv("\"expected op '{0}' to have attribute '{1}' " "of type '{2}'\"", op.getOperationName(), namedAttr->name, @@ -918,7 +924,7 @@ void PatternEmitter::emitAttributeMatch(DagNode tree, StringRef opName, } } emitStaticVerifierCall( - verifier, opName, "tblgen_attr", + verifier, castedName, "tblgen_attr", formatv("\"op '{0}' attribute '{1}' failed to satisfy constraint: " "'{2}'\"", op.getOperationName(), namedAttr->name, @@ -1532,6 +1538,7 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, LLVM_DEBUG(llvm::dbgs() << '\n'); Operator &resultOp = tree.getDialectOp(opMap); + bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); auto numOpArgs = resultOp.getNumArgs(); auto numPatArgs = tree.getNumArgs(); @@ -1623,9 +1630,10 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, createAggregateLocalVarsForOpArgs(tree, childNodeNames, depth); // Then create the op. - os.scope("", "\n}\n").os << formatv( - "{0} = rewriter.create<{1}>({2}, tblgen_values, tblgen_attrs);", - valuePackName, resultOp.getQualCppClassName(), locToUse); + os.scope("", "\n}\n").os + << formatv("{0} = rewriter.create<{1}>({2}, tblgen_values, {3});", + valuePackName, resultOp.getQualCppClassName(), locToUse, + useProperties ? "tblgen_props" : "tblgen_attrs"); return resultValue; } @@ -1682,8 +1690,9 @@ std::string PatternEmitter::handleOpCreation(DagNode tree, int resultIndex, } } os << formatv("{0} = rewriter.create<{1}>({2}, tblgen_types, " - "tblgen_values, tblgen_attrs);\n", - valuePackName, resultOp.getQualCppClassName(), locToUse); + "tblgen_values, {3});\n", + valuePackName, resultOp.getQualCppClassName(), locToUse, + useProperties ? "tblgen_props" : "tblgen_attrs"); os.unindent() << "}\n"; return resultValue; } @@ -1791,12 +1800,21 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( DagNode node, const ChildNodeIndexNameMap &childNodeNames, int depth) { Operator &resultOp = node.getDialectOp(opMap); + bool useProperties = resultOp.getDialect().usePropertiesForAttributes(); auto scope = os.scope(); os << formatv("::llvm::SmallVector<::mlir::Value, 4> " "tblgen_values; (void)tblgen_values;\n"); - os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " - "tblgen_attrs; (void)tblgen_attrs;\n"); + if (useProperties) { + os << formatv("{0}::Properties tblgen_props; (void)tblgen_props;\n", + resultOp.getQualCppClassName()); + } else { + os << formatv("::llvm::SmallVector<::mlir::NamedAttribute, 4> " + "tblgen_attrs; (void)tblgen_attrs;\n"); + } + const char *setPropCmd = + "tblgen_props.{0} = " + "::llvm::dyn_cast_if_present<decltype(tblgen_props.{0})>({1});\n"; const char *addAttrCmd = "if (auto tmpAttr = {1}) {\n" " tblgen_attrs.emplace_back(rewriter.getStringAttr(\"{0}\"), " @@ -1814,13 +1832,23 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( if (!subTree.isNativeCodeCall()) PrintFatalError(loc, "only NativeCodeCall allowed in nested dag node " "for creating attribute"); - os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); + + if (useProperties) { + os << formatv(setPropCmd, opArgName, childNodeNames.lookup(argIndex)); + } else { + os << formatv(addAttrCmd, opArgName, childNodeNames.lookup(argIndex)); + } } else { auto leaf = node.getArgAsLeaf(argIndex); // The argument in the result DAG pattern. auto patArgName = node.getArgName(argIndex); - os << formatv(addAttrCmd, opArgName, - handleOpArgument(leaf, patArgName)); + if (useProperties) { + os << formatv(setPropCmd, opArgName, + handleOpArgument(leaf, patArgName)); + } else { + os << formatv(addAttrCmd, opArgName, + handleOpArgument(leaf, patArgName)); + } } continue; } @@ -1876,11 +1904,18 @@ void PatternEmitter::createAggregateLocalVarsForOpArgs( const auto *sameVariadicSize = resultOp.getTrait("::mlir::OpTrait::SameVariadicOperandSize"); if (!sameVariadicSize) { - const char *setSizes = R"( - tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), - rewriter.getDenseI32ArrayAttr({{ {0} })); - )"; - os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + if (useProperties) { + const char *setSizes = R"( + tblgen_props.operandSegmentSizes = {{ {0} }; + )"; + os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + } else { + const char *setSizes = R"( + tblgen_attrs.emplace_back(rewriter.getStringAttr("operandSegmentSizes"), + rewriter.getDenseI32ArrayAttr({{ {0} })); + )"; + os.printReindented(formatv(setSizes, llvm::join(sizes, ", ")).str()); + } } } } >From d059f73171310d8d59e23282bbee73c64184989b Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak <krzysdrewn...@gmail.com> Date: Tue, 28 Jan 2025 21:10:01 -0800 Subject: [PATCH 2/2] Test fails on Windows, try to loosen it up --- mlir/test/mlir-tblgen/rewriter-attributes-properties.td | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td index 77869d36cc12e..fc36a51789ec2 100644 --- a/mlir/test/mlir-tblgen/rewriter-attributes-properties.td +++ b/mlir/test/mlir-tblgen/rewriter-attributes-properties.td @@ -29,9 +29,9 @@ def BOp : NS_Op<"b_op", []> { def test1 : Pat<(AOp (BOp:$x $y), $_), (AOp $x, $y)>; // CHECK-LABEL: struct test1 // CHECK: ::llvm::LogicalResult matchAndRewrite -// CHECK: ::mlir::IntegerAttr y; -// CHECK: test::BOp x; -// CHECK: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; +// CHECK-DAG: ::mlir::IntegerAttr y; +// CHECK-DAG: test::BOp x; +// CHECK-DAG: ::llvm::SmallVector<::mlir::Operation *, 4> tblgen_ops; // CHECK: tblgen_ops.push_back(op0); // CHECK: x = castedOp1; // CHECK: tblgen_attr = castedOp1.getProperties().getY(); _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits