https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/155244
Depends on #155242. >From b217bce2ba7ecaf94d1e6364cac7b75f4ffb3f41 Mon Sep 17 00:00:00 2001 From: Matthias Springer <m...@m-sp.org> Date: Sat, 23 Aug 2025 10:36:37 +0000 Subject: [PATCH] [mlir][Transforms] Add support for `ConversionPatternRewriter::replaceAllUsesWith` --- mlir/include/mlir/IR/PatternMatch.h | 2 +- .../mlir/Transforms/DialectConversion.h | 17 +- mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp | 2 +- .../Transforms/Utils/DialectConversion.cpp | 158 +++++++++++------- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 5 +- 5 files changed, 112 insertions(+), 72 deletions(-) diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 57e73c1d8c7c1..7b0b9cef9c5bd 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -633,7 +633,7 @@ class RewriterBase : public OpBuilder { /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). - void replaceAllUsesWith(Value from, Value to) { + virtual void replaceAllUsesWith(Value from, Value to) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index f23a70601fc0a..ffad78db3ca87 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -780,15 +780,18 @@ class ConversionPatternRewriter final : public PatternRewriter { Region *region, const TypeConverter &converter, TypeConverter::SignatureConversion *entryConversion = nullptr); - /// Replace all the uses of the block argument `from` with `to`. This - /// function supports both 1:1 and 1:N replacements. + /// Replace all the uses of `from` with `to`. This function supports both 1:1 + /// and 1:N replacements. /// /// Note: If `allowPatternRollback` is set to "true", this function replaces - /// all current and future uses of the block argument. This same block - /// block argument must not be replaced multiple times. Uses are not replaced - /// immediately but in a delayed fashion. Patterns may still see the original - /// uses when inspecting IR. - void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to); + /// all current and future uses of the `from` value. This same value must not + /// be replaced multiple times. Uses are not replaced immediately but in a + /// delayed fashion. Patterns may still see the original uses when inspecting + /// IR. + void replaceAllUsesWith(Value from, ValueRange to); + void replaceAllUsesWith(Value from, Value to) override { + replaceAllUsesWith(from, ValueRange{to}); + } /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp index 42c76ed475b4c..93fe2edad5274 100644 --- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp +++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp @@ -284,7 +284,7 @@ static void restoreByValRefArgumentType( cast<TypeAttr>(byValRefAttr->getValue()).getValue()); Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg); - rewriter.replaceUsesOfBlockArgument(arg, valueArg); + rewriter.replaceAllUsesWith(arg, valueArg); } } diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index e3248204d6694..ce8e314ed6f7b 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -277,13 +277,14 @@ class IRRewrite { InlineBlock, MoveBlock, BlockTypeConversion, - ReplaceBlockArg, // Operation rewrites MoveOperation, ModifyOperation, ReplaceOperation, CreateOperation, - UnresolvedMaterialization + UnresolvedMaterialization, + // Value rewrites + ReplaceValue }; virtual ~IRRewrite() = default; @@ -330,7 +331,7 @@ class BlockRewrite : public IRRewrite { static bool classof(const IRRewrite *rewrite) { return rewrite->getKind() >= Kind::CreateBlock && - rewrite->getKind() <= Kind::ReplaceBlockArg; + rewrite->getKind() <= Kind::BlockTypeConversion; } protected: @@ -342,6 +343,25 @@ class BlockRewrite : public IRRewrite { Block *block; }; +/// A value rewrite. +class ValueRewrite : public IRRewrite { +public: + /// Return the value that this rewrite operates on. + Value getValue() const { return value; } + + static bool classof(const IRRewrite *rewrite) { + return rewrite->getKind() == Kind::ReplaceValue; + } + +protected: + ValueRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl, + Value value) + : IRRewrite(kind, rewriterImpl), value(value) {} + + // The value that this rewrite operates on. + Value value; +}; + /// Creation of a block. Block creations are immediately reflected in the IR. /// There is no extra work to commit the rewrite. During rollback, the newly /// created block is erased. @@ -548,19 +568,18 @@ class BlockTypeConversionRewrite : public BlockRewrite { Block *newBlock; }; -/// Replacing a block argument. This rewrite is not immediately reflected in the +/// Replacing a value. This rewrite is not immediately reflected in the /// IR. An internal IR mapping is updated, but the actual replacement is delayed /// until the rewrite is committed. -class ReplaceBlockArgRewrite : public BlockRewrite { +class ReplaceValueRewrite : public ValueRewrite { public: - ReplaceBlockArgRewrite(ConversionPatternRewriterImpl &rewriterImpl, - Block *block, BlockArgument arg, - const TypeConverter *converter) - : BlockRewrite(Kind::ReplaceBlockArg, rewriterImpl, block), arg(arg), + ReplaceValueRewrite(ConversionPatternRewriterImpl &rewriterImpl, Value value, + const TypeConverter *converter) + : ValueRewrite(Kind::ReplaceValue, rewriterImpl, value), converter(converter) {} static bool classof(const IRRewrite *rewrite) { - return rewrite->getKind() == Kind::ReplaceBlockArg; + return rewrite->getKind() == Kind::ReplaceValue; } void commit(RewriterBase &rewriter) override; @@ -568,9 +587,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite { void rollback() override; private: - BlockArgument arg; - - /// The current type converter when the block argument was replaced. + /// The current type converter when the value was replaced. const TypeConverter *converter; }; @@ -942,10 +959,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// uses. void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues); - /// Replace the given block argument with the given values. The specified + /// Replace the uses of the given value with the given values. The specified /// converter is used to build materializations (if necessary). - void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to, - const TypeConverter *converter); + void replaceAllUsesWith(Value from, ValueRange to, + const TypeConverter *converter); /// Erase the given block and its contents. void eraseBlock(Block *block); @@ -1132,10 +1149,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { IRRewriter notifyingRewriter; #ifndef NDEBUG - /// A set of replaced block arguments. This set is for debugging purposes - /// only and it is maintained only if `allowPatternRollback` is set to - /// "true". - DenseSet<BlockArgument> replacedArgs; + /// A set of replaced values. This set is for debugging purposes only and it + /// is maintained only if `allowPatternRollback` is set to "true". + DenseSet<Value> replacedValues; /// A set of operations that have pending updates. This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1172,32 +1188,54 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, - Value repl) { +/// Replace all uses of `from` with `repl`. +static void performReplaceValue(RewriterBase &rewriter, Value from, + Value repl) { if (isa<BlockArgument>(repl)) { - rewriter.replaceAllUsesWith(arg, repl); + // `repl` is a block argument. Directly replace all uses. + rewriter.replaceAllUsesWith(from, repl); return; } - // If the replacement value is an operation, we check to make sure that we - // don't replace uses that are within the parent operation of the - // replacement value. - Operation *replOp = cast<OpResult>(repl).getOwner(); + // If the replacement value is an operation, only replace those uses that: + // - are in a different block than the replacement operation, or + // - are in the same block but after the replacement operation. + // + // Example: + // ^bb0(%arg0: i32): + // %0 = "consumer"(%arg0) : (i32) -> (i32) + // "another_consumer"(%arg0) : (i32) -> () + // + // In the above example, replaceAllUsesWith(%arg0, %0) will replace the + // use in "another_consumer" but not the use in "consumer". When using the + // normal RewriterBase API, this would typically be done with + // `replaceUsesWithIf` / `replaceAllUsesExcept`. However, that API is not + // supported by the `ConversionPatternRewriter`. Due to the mapping mechanism + // it cannot be supported efficiently with `allowPatternRollback` set to + // "true". Therefore, the conversion driver is trying to be smart and replaces + // only those uses that do not lead to a dominance violation. E.g., the + // FuncToLLVM lowering (`restoreByValRefArgumentType`) relies on this + // behavior. + // + // TODO: As we move more and more towards `allowPatternRollback` set to + // "false", we should remove this special handling, in order to align the + // `ConversionPatternRewriter` API with the normal `RewriterBase` API. + Operation *replOp = repl.getDefiningOp(); Block *replBlock = replOp->getBlock(); - rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) { + rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) { Operation *user = operand.getOwner(); return user->getBlock() != replBlock || replOp->isBeforeInBlock(user); }); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); +void ReplaceValueRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter); if (!repl) return; - performReplaceBlockArg(rewriter, arg, repl); + performReplaceValue(rewriter, value, repl); } -void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } +void ReplaceValueRewrite::rollback() { rewriterImpl.mapping.erase({value}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { auto *listener = @@ -1590,7 +1628,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( /*outputTypes=*/origArgType, /*originalType=*/Type(), converter, /*castOp=*/nullptr, /*isPureTypeConversion=*/false) .front(); - replaceUsesOfBlockArgument(origArg, mat, converter); + replaceAllUsesWith(origArg, mat, converter); continue; } @@ -1599,15 +1637,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( assert(inputMap->size == 0 && "invalid to provide a replacement value when the argument isn't " "dropped"); - replaceUsesOfBlockArgument(origArg, inputMap->replacementValues, - converter); + replaceAllUsesWith(origArg, inputMap->replacementValues, converter); continue; } // This is a 1->1+ mapping. auto replArgs = newBlock->getArguments().slice(inputMap->inputNo, inputMap->size); - replaceUsesOfBlockArgument(origArg, replArgs, converter); + replaceAllUsesWith(origArg, replArgs, converter); } if (config.allowPatternRollback) @@ -1882,8 +1919,8 @@ void ConversionPatternRewriterImpl::replaceOp( op->walk([&](Operation *op) { replacedOps.insert(op); }); } -void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( - BlockArgument from, ValueRange to, const TypeConverter *converter) { +void ConversionPatternRewriterImpl::replaceAllUsesWith( + Value from, ValueRange to, const TypeConverter *converter) { if (!config.allowPatternRollback) { SmallVector<Value> toConv = llvm::to_vector(to); SmallVector<Value> repls = @@ -1893,25 +1930,25 @@ void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( if (!repl) return; - performReplaceBlockArg(r, from, repl); + performReplaceValue(r, from, repl); return; } #ifndef NDEBUG - // Make sure that a block argument is not replaced multiple times. In - // rollback mode, `replaceUsesOfBlockArgument` replaces not only all current - // uses of the given block argument, but also all future uses that may be - // introduced by future pattern applications. Therefore, it does not make - // sense to call `replaceUsesOfBlockArgument` multiple times with the same - // block argument. Doing so would overwrite the mapping and mess with the - // internal state of the dialect conversion driver. - assert(!replacedArgs.contains(from) && - "attempting to replace a block argument that was already replaced"); - replacedArgs.insert(from); + // Make sure that a value is not replaced multiple times. In rollback mode, + // `replaceAllUsesWith` replaces not only all current uses of the given value, + // but also all future uses that may be introduced by future pattern + // applications. Therefore, it does not make sense to call + // `replaceAllUsesWith` multiple times with the same value. Doing so would + // overwrite the mapping and mess with the internal state of the dialect + // conversion driver. + assert(!replacedValues.contains(from) && + "attempting to replace a value that was already replaced"); + replacedValues.insert(from); #endif // NDEBUG - appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); + appendRewrite<ReplaceValueRewrite>(from, converter); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { @@ -2116,18 +2153,19 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes( return impl->convertRegionTypes(*this, region, converter, entryConversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - ValueRange to) { +void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) { LLVM_DEBUG({ - impl->logger.startLine() << "** Replace Argument : '" << from << "'"; - if (Operation *parentOp = from.getOwner()->getParentOp()) { - impl->logger.getOStream() << " (in region of '" << parentOp->getName() - << "' (" << parentOp << ")\n"; - } else { - impl->logger.getOStream() << " (unlinked block)\n"; + impl->logger.startLine() << "** Replace Value : '" << from << "'"; + if (auto blockArg = dyn_cast<BlockArgument>(from)) { + if (Operation *parentOp = blockArg.getOwner()->getParentOp()) { + impl->logger.getOStream() << " (in region of '" << parentOp->getName() + << "' (" << parentOp << ")\n"; + } else { + impl->logger.getOStream() << " (unlinked block)\n"; + } } }); - impl->replaceUsesOfBlockArgument(from, to, impl->currentTypeConverter); + impl->replaceAllUsesWith(from, to, impl->currentTypeConverter); } Value ConversionPatternRewriter::getRemappedValue(Value key) { @@ -2185,7 +2223,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // Replace all uses of block arguments. for (auto it : llvm::zip(source->getArguments(), argValues)) - replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it)); + replaceAllUsesWith(std::get<0>(it), std::get<1>(it)); if (fastPath) { // Move all ops at once. diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index b6f16ac1b5c48..e0a004b706be4 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -951,7 +951,7 @@ struct TestCreateIllegalBlock : public RewritePattern { } }; -/// A simple pattern that tests the "replaceUsesOfBlockArgument" API. +/// A simple pattern that tests the "replaceAllUsesWith" API. struct TestBlockArgReplace : public ConversionPattern { TestBlockArgReplace(MLIRContext *ctx, const TypeConverter &converter) : ConversionPattern(converter, "test.block_arg_replace", /*benefit=*/1, @@ -962,8 +962,7 @@ struct TestBlockArgReplace : public ConversionPattern { ConversionPatternRewriter &rewriter) const final { // Replace the first block argument with 2x the second block argument. Value repl = op->getRegion(0).getArgument(1); - rewriter.replaceUsesOfBlockArgument(op->getRegion(0).getArgument(0), - {repl, repl}); + rewriter.replaceAllUsesWith(op->getRegion(0).getArgument(0), {repl, repl}); rewriter.modifyOpInPlace(op, [&] { // If the "trigger_rollback" attribute is set, keep the op illegal, so // that a rollback is triggered. _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits