https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/151865
>From 4932ff9587728655edd68082fb55c2b6dcd1c309 Mon Sep 17 00:00:00 2001 From: Matthias Springer <m...@m-sp.org> Date: Sun, 22 Jun 2025 09:21:22 +0000 Subject: [PATCH] :WIP update erase immediately update 2 fix fix some tests --- mlir/include/mlir/Conversion/Passes.td | 2 + .../mlir/Transforms/DialectConversion.h | 23 +- .../ConvertToLLVM/ConvertToLLVMPass.cpp | 26 +- .../Dialect/Linalg/Transforms/Detensorize.cpp | 35 +- .../Transforms/SparseTensorConversion.cpp | 2 +- .../Transforms/Utils/DialectConversion.cpp | 473 ++++++++++++++---- ...assume-alignment-runtime-verification.mlir | 9 + .../atomic-rmw-runtime-verification.mlir | 8 + .../MemRef/cast-runtime-verification.mlir | 11 +- .../MemRef/copy-runtime-verification.mlir | 9 + .../MemRef/dim-runtime-verification.mlir | 9 + .../MemRef/load-runtime-verification.mlir | 12 +- .../MemRef/store-runtime-verification.mlir | 8 + .../MemRef/subview-runtime-verification.mlir | 12 +- .../Tensor/cast-runtime-verification.mlir | 11 + .../Tensor/dim-runtime-verification.mlir | 16 +- .../Tensor/extract-runtime-verification.mlir | 11 + .../extract_slice-runtime-verification.mlir | 11 + mlir/test/Transforms/test-legalizer.mlir | 29 +- mlir/test/lib/Dialect/Test/TestPatterns.cpp | 8 + 20 files changed, 591 insertions(+), 134 deletions(-) diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td index 6e1baaf23fcf7..4a9464ff265e0 100644 --- a/mlir/include/mlir/Conversion/Passes.td +++ b/mlir/include/mlir/Conversion/Passes.td @@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { "Test conversion patterns of only the specified dialects">, Option<"useDynamic", "dynamic", "bool", "false", "Use op conversion attributes to configure the conversion">, + Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "false", + "Experimental performance flag to disallow pattern rollback"> ]; } diff --git a/mlir/include/mlir/Transforms/DialectConversion.h 
b/mlir/include/mlir/Transforms/DialectConversion.h index f6437657c9a93..84b9035dc6358 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public PatternRewriter { public: ~ConversionPatternRewriter() override; + /// Return the configuration of the current dialect conversion. + const ConversionConfig &getConfig() const; + /// Apply a signature conversion to given block. This replaces the block with /// a new block containing the updated signature. The operations of the given /// block are inlined into the newly-created block, which is returned. @@ -1228,18 +1231,18 @@ struct ConversionConfig { /// 2. Pattern produces IR (in-place modification or new IR) that is illegal /// and cannot be legalized by subsequent foldings / pattern applications. /// - /// If set to "false", the conversion driver will produce an LLVM fatal error - /// instead of rolling back IR modifications. Moreover, in case of a failed - /// conversion, the original IR is not restored. The resulting IR may be a - /// mix of original and rewritten IR. (Same as a failed greedy pattern - /// rewrite.) + /// Experimental: If set to "false", the conversion driver will produce an + /// LLVM fatal error instead of rolling back IR modifications. Moreover, in + /// case of a failed conversion, the original IR is not restored. The + /// resulting IR may be a mix of original and rewritten IR. (Same as a failed + /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + /// with ASAN to detect invalid pattern API usage. /// - /// Note: This flag was added in preparation of the One-Shot Dialect - /// Conversion refactoring, which will remove the ability to roll back IR - /// modifications from the conversion driver. Use this flag to ensure that - /// your patterns do not trigger any IR rollbacks. 
For details, see + /// When pattern rollback is disabled, the conversion driver has to maintain + /// less internal state. This is more efficient, but not supported by all + /// lowering patterns. For details, see /// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083. - bool allowPatternRollback = true; + bool allowPatternRollback = false; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp index ed5d6d4a7fe40..cdb715064b0f7 100644 --- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp +++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp @@ -31,7 +31,8 @@ namespace { class ConvertToLLVMPassInterface { public: ConvertToLLVMPassInterface(MLIRContext *context, - ArrayRef<std::string> filterDialects); + ArrayRef<std::string> filterDialects, + bool allowPatternRollback = true); virtual ~ConvertToLLVMPassInterface() = default; /// Get the dependent dialects used by `convert-to-llvm`. @@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface { MLIRContext *context; /// List of dialects names to use as filters. ArrayRef<std::string> filterDialects; + /// An experimental flag to disallow pattern rollback. This is more efficient + /// but not supported by all lowering patterns. + bool allowPatternRollback; }; /// This DialectExtension can be attached to the context, which will invoke the @@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface { /// Apply the conversion driver. 
LogicalResult transform(Operation *op, AnalysisManager manager) const final { - if (failed(applyPartialConversion(op, *target, *patterns))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, *target, *patterns, config))) return failure(); return success(); } @@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface { patterns); // Apply the conversion. - if (failed(applyPartialConversion(op, target, std::move(patterns)))) + ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; + if (failed(applyPartialConversion(op, target, std::move(patterns), config))) return failure(); return success(); } @@ -206,9 +214,11 @@ class ConvertToLLVMPass std::shared_ptr<ConvertToLLVMPassInterface> impl; // Choose the pass implementation. if (useDynamic) - impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects); + impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects, + allowPatternRollback); else - impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects); + impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects, + allowPatternRollback); if (failed(impl->initialize())) return failure(); this->impl = impl; @@ -228,8 +238,10 @@ class ConvertToLLVMPass //===----------------------------------------------------------------------===// ConvertToLLVMPassInterface::ConvertToLLVMPassInterface( - MLIRContext *context, ArrayRef<std::string> filterDialects) - : context(context), filterDialects(filterDialects) {} + MLIRContext *context, ArrayRef<std::string> filterDialects, + bool allowPatternRollback) + : context(context), filterDialects(filterDialects), + allowPatternRollback(allowPatternRollback) {} void ConvertToLLVMPassInterface::getDependentDialects( DialectRegistry &registry) { diff --git 
a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp index
830905495e759..d6f26fa200dbc 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -458,6 +458,22 @@ struct LinalgDetensorize } }; + /// A listener that forwards notifyBlockErased and notifyOperationErased to + /// the given callbacks. + struct CallbackListener : public RewriterBase::Listener { + CallbackListener(std::function<void(Operation *op)> onOperationErased, + std::function<void(Block *block)> onBlockErased) + : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {} + + void notifyBlockErased(Block *block) override { onBlockErased(block); } + void notifyOperationErased(Operation *op) override { + onOperationErased(op); + } + + std::function<void(Operation *op)> onOperationErased; + std::function<void(Block *block)> onBlockErased; + }; + void runOnOperation() override { MLIRContext *context = &getContext(); DetensorizeTypeConverter typeConverter; @@ -551,8 +567,23 @@ struct LinalgDetensorize populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter, shouldConvertBranchOperand); - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) + ConversionConfig config; + CallbackListener listener(/*onOperationErased=*/ + [&](Operation *op) { + opsToDetensor.erase(op); + detensorableBranchOps.erase(op); + }, + /*onBlockErased=*/ + [&](Block *block) { + for (BlockArgument arg : + block->getArguments()) { + blockArgsToDetensor.erase(arg); + } + }); + + config.listener = &listener; + if (failed(applyFullConversion(getOperation(), target, std::move(patterns), + config))) signalPassFailure(); RewritePatternSet canonPatterns(context); diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp index 134aef3a6c719..0e88d31dae8e8 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp +++ 
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp @@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> { {tensor, lvlCoords, values, filled, added, count}, EmitCInterface::On); Operation *parent = getTop(op); + rewriter.setInsertionPointAfter(parent); rewriter.replaceOp(op, adaptor.getTensor()); // Deallocate the buffers on exit of the loop nest. - rewriter.setInsertionPointAfter(parent); memref::DeallocOp::create(rewriter, loc, values); memref::DeallocOp::create(rewriter, loc, filled); memref::DeallocOp::create(rewriter, loc, added); diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp index 4a60bfe489f9c..90cd58c8285b5 100644 --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -121,17 +121,8 @@ struct ConversionValueMapping { /// false positives. bool isMappedTo(Value value) const { return mappedTo.contains(value); } - /// Lookup a value in the mapping. If `skipPureTypeConversions` is "true", - /// pure type conversions are not considered. Return an empty vector if no - /// mapping was found. - /// - /// Note: This mapping data structure supports N:M mappings. This function - /// first tries to look up mappings for each input value individually (and - /// then composes the results). If such a lookup is unsuccessful, the entire - /// vector is looked up together. If the lookup is still unsuccessful, an - /// empty vector is returned. - ValueVector lookup(const ValueVector &from, - bool skipPureTypeConversions = false) const; + /// Lookup a value in the mapping. + ValueVector lookup(const ValueVector &from) const; template <typename T> struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {}; @@ -191,43 +182,28 @@ struct ConversionValueMapping { /// conversions.) 
static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__"; +/// Return the operation that defines all values in the vector. Return nullptr +/// if the values are not defined by the same operation. +static Operation *getCommonDefiningOp(const ValueVector &values) { + assert(!values.empty() && "expected non-empty value vector"); + Operation *op = values.front().getDefiningOp(); + for (Value v : llvm::drop_begin(values)) { + if (v.getDefiningOp() != op) + return nullptr; + } + return op; +} + /// A vector of values is a pure type conversion if all values are defined by /// the same operation and the operation has the `kPureTypeConversionMarker` /// attribute. static bool isPureTypeConversion(const ValueVector &values) { assert(!values.empty() && "expected non-empty value vector"); - Operation *op = values.front().getDefiningOp(); - for (Value v : llvm::drop_begin(values)) - if (v.getDefiningOp() != op) - return false; + Operation *op = getCommonDefiningOp(values); return op && op->hasAttr(kPureTypeConversionMarker); } -ValueVector ConversionValueMapping::lookup(const ValueVector &from, - bool skipPureTypeConversions) const { - // If possible, replace each value with (one or multiple) mapped values. - ValueVector next; - for (Value v : from) { - auto it = mapping.find({v}); - if (it != mapping.end()) { - llvm::append_range(next, it->second); - } else { - next.push_back(v); - } - } - if (next != from) { - // At least one value was replaced. - return next; - } - - // Otherwise: Check if there is a mapping for the entire vector. Such - // mappings are materializations. (N:M mapping are not supported for value - // replacements.) - // - // Note: From a correctness point of view, materializations do not have to - // be stored (and looked up) in the mapping. But for performance reasons, - // we choose to reuse existing IR (when possible) instead of creating it - // multiple times. 
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const { auto it = mapping.find(from); if (it == mapping.end()) { // No mapping found: The lookup stops here. @@ -874,7 +850,7 @@ namespace detail { struct ConversionPatternRewriterImpl : public RewriterBase::Listener { explicit ConversionPatternRewriterImpl(MLIRContext *ctx, const ConversionConfig &config) - : context(ctx), config(config) {} + : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {} //===--------------------------------------------------------------------===// // State Management @@ -896,6 +872,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// failure. template <typename RewriteTy, typename... Args> void appendRewrite(Args &&...args) { + assert(config.allowPatternRollback && "appending rewrites is not allowed"); rewrites.push_back( std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...)); } @@ -922,15 +899,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { bool wasOpReplaced(Operation *op) const; /// Lookup the most recently mapped values with the desired types in the - /// mapping. - /// - /// Special cases: - /// - If the desired type range is empty, simply return the most recently - /// mapped values. - /// - If there is no mapping to the desired types, also return the most - /// recently mapped values. - /// - If there is no mapping for the given values at all, return the given - /// value. + /// mapping, taking into account only replacements. Perform a best-effort + /// search for existing materializations with the desired types. /// /// If `skipPureTypeConversions` is "true", materializations that are pure /// type conversions are not considered. @@ -1099,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { ConversionValueMapping mapping; /// Ordered list of block operations (creations, splits, motions). 
+ /// This vector is maintained only if `allowPatternRollback` is set to + /// "true". Otherwise, all IR rewrites are materialized immediately and no + /// bookkeeping is needed. SmallVector<std::unique_ptr<IRRewrite>> rewrites; /// A set of operations that should no longer be considered for legalization. @@ -1122,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// by the current pattern. SetVector<Block *> patternInsertedBlocks; + /// A list of unresolved materializations that were created by the current + /// pattern. + DenseSet<UnrealizedConversionCastOp> patternMaterializations; + /// A mapping for looking up metadata of unresolved materializations. DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo> unresolvedMaterializations; @@ -1137,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener { /// Dialect conversion configuration. const ConversionConfig &config; + /// A set of erased operations. This set is utilized only if + /// `allowPatternRollback` is set to "false". Conceptually, this set is + /// similar to `replacedOps` (which is maintained when the flag is set to + /// "true"). However, erasing from a DenseSet is more efficient than erasing + /// from a SetVector. + DenseSet<Operation *> erasedOps; + + /// A set of erased blocks. This set is utilized only if + /// `allowPatternRollback` is set to "false". + DenseSet<Block *> erasedBlocks; + + /// A rewriter that notifies the listener (if any) about all IR + /// modifications. This rewriter is utilized only if `allowPatternRollback` + /// is set to "false". If the flag is set to "true", the listener is notified + /// with a separate mechanism (e.g., in `IRRewrite::commit`). + IRRewriter notifyingRewriter; + #ifndef NDEBUG /// A set of operations that have pending updates.
This tracking isn't /// strictly necessary, and is thus only active during debug builds for extra @@ -1173,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() { getNewBlock()->replaceAllUsesWith(getOrigBlock()); } -void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { - Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); - if (!repl) - return; - +static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg, + Value repl) { if (isa<BlockArgument>(repl)) { rewriter.replaceAllUsesWith(arg, repl); return; @@ -1194,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { }); } +void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) { + Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter); + if (!repl) + return; + performReplaceBlockArg(rewriter, arg, repl); +} + void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); } void ReplaceOperationRewrite::commit(RewriterBase &rewriter) { @@ -1279,6 +1277,63 @@ void ConversionPatternRewriterImpl::applyRewrites() { ValueVector ConversionPatternRewriterImpl::lookupOrDefault( Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const { + + // Helper function that looks up a single value. + auto lookup = [&](const ValueVector &values) -> ValueVector { + assert(!values.empty() && "expected non-empty value vector"); + + // If the pattern rollback is enabled, use the mapping to look up the + // values. + if (config.allowPatternRollback) + return mapping.lookup(values); + + // Otherwise, look up values by examining the IR. All replacements have + // already been materialized in IR. 
+ Operation *op = getCommonDefiningOp(values); + if (!op) + return {}; + auto castOp = dyn_cast<UnrealizedConversionCastOp>(op); + if (!castOp) + return {}; + if (!this->unresolvedMaterializations.contains(castOp)) + return {}; + if (castOp.getOutputs() != values) + return {}; + return castOp.getInputs(); + }; + + auto composedLookup = [&](const ValueVector &values) -> ValueVector { + // If possible, replace each value with (one or multiple) mapped values. + ValueVector next; + for (Value v : values) { + ValueVector r = lookup({v}); + if (!r.empty()) { + llvm::append_range(next, r); + } else { + next.push_back(v); + } + } + if (next != values) { + // At least one value was replaced. + return next; + } + + // Otherwise: Check if there is a mapping for the entire vector. Such + // mappings are materializations. (N:M mapping are not supported for value + // replacements.) + // + // Note: From a correctness point of view, materializations do not have to + // be stored (and looked up) in the mapping. But for performance reasons, + // we choose to reuse existing IR (when possible) instead of creating it + // multiple times. + ValueVector r = lookup(values); + if (r.empty()) { + // No mapping found: The lookup stops here. + return {}; + } + return r; + }; + // Try to find the deepest values that have the desired types. If there is no // such mapping, simply return the deepest values. ValueVector desiredValue; @@ -1300,7 +1355,7 @@ ValueVector ConversionPatternRewriterImpl::lookupOrDefault( desiredValue = current; // Lookup next value in the mapping. 
- ValueVector next = mapping.lookup(current, skipPureTypeConversions); + ValueVector next = composedLookup(current); if (next.empty()) break; current = std::move(next); @@ -1345,15 +1400,8 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state, void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep, StringRef patternName) { for (auto &rewrite : - llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) { - if (!config.allowPatternRollback && - !isa<UnresolvedMaterializationRewrite>(rewrite)) { - // Unresolved materializations can always be rolled back (erased). - llvm::report_fatal_error("pattern '" + patternName + - "' rollback of IR modifications requested"); - } + llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) rewrite->rollback(); - } rewrites.resize(numRewritesToKeep); } @@ -1417,12 +1465,12 @@ LogicalResult ConversionPatternRewriterImpl::remapValues( bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const { // Check to see if this operation is ignored or was replaced. - return replacedOps.count(op) || ignoredOps.count(op); + return wasOpReplaced(op) || ignoredOps.count(op); } bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const { // Check to see if this operation was replaced. - return replacedOps.count(op); + return replacedOps.count(op) || erasedOps.count(op); } //===----------------------------------------------------------------------===// @@ -1506,7 +1554,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( // a bit more efficient, so we try to do that when possible. 
bool fastPath = !config.listener; if (fastPath) { - appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); + if (config.allowPatternRollback) + appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end()); newBlock->getOperations().splice(newBlock->end(), block->getOperations()); } else { while (!block->empty()) @@ -1554,7 +1603,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( replaceUsesOfBlockArgument(origArg, replArgs, converter); } - appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); + if (config.allowPatternRollback) + appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock); // Erase the old block. (It is just unlinked for now and will be erased during // cleanup.) @@ -1583,23 +1633,32 @@ ValueRange ConversionPatternRewriterImpl::buildUnresolvedMaterialization( // tracking the materialization like we do for other operations. OpBuilder builder(outputTypes.front().getContext()); builder.setInsertionPoint(ip.getBlock(), ip.getPoint()); - auto convertOp = + UnrealizedConversionCastOp convertOp = UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs); if (isPureTypeConversion) convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr()); - if (!valuesToMap.empty()) - mapping.map(valuesToMap, convertOp.getResults()); + + // Register the materialization. 
if (castOp) *castOp = convertOp; unresolvedMaterializations[convertOp] = UnresolvedMaterializationInfo(converter, kind, originalType); - appendRewrite<UnresolvedMaterializationRewrite>(convertOp, - std::move(valuesToMap)); + if (config.allowPatternRollback) { + if (!valuesToMap.empty()) + mapping.map(valuesToMap, convertOp.getResults()); + appendRewrite<UnresolvedMaterializationRewrite>(convertOp, + std::move(valuesToMap)); + } else { + patternMaterializations.insert(convertOp); + } return convertOp.getResults(); } Value ConversionPatternRewriterImpl::findOrBuildReplacementValue( Value value, const TypeConverter *converter) { + assert(config.allowPatternRollback && + "this code path is valid only in rollback mode"); + // Try to find a replacement value with the same type in the conversion value // mapping. This includes cached materializations. We try to reuse those // instead of generating duplicate IR. @@ -1661,26 +1720,119 @@ void ConversionPatternRewriterImpl::notifyOperationInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(op->getParentOp()) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) && "attempting to insert into a block within a replaced/erased op"); + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyOperationInserted(op, previous); + if (wasDetached) { - // If the op was detached, it is most likely a newly created op. - // TODO: If the same op is inserted multiple times from a detached state, - // the rollback mechanism may erase the same op multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateOperationRewrite>(op); + // If the op was detached, it is most likely a newly created op. 
Add it to the + // set of newly created ops, so that it will be legalized. If this op is + // not a newly created op, it will be legalized a second time, which is + // inefficient but harmless. patternNewOps.insert(op); + + if (config.allowPatternRollback) { + // TODO: If the same op is inserted multiple times from a detached + // state, the rollback mechanism may erase the same op multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateOperationRewrite>(op); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased operations that must be kept up to date. + erasedOps.erase(op); + } return; } // The op was moved from one place to another. - appendRewrite<MoveOperationRewrite>(op, previous); + if (config.allowPatternRollback) + appendRewrite<MoveOperationRewrite>(op, previous); +} + +/// Given that `fromRange` is about to be replaced with `toRange`, compute +/// replacement values with the types of `fromRange`. +static SmallVector<Value> +getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange, + const SmallVector<SmallVector<Value>> &toRange, + const TypeConverter *converter) { + assert(!impl.config.allowPatternRollback && + "this code path is valid only in 'no rollback' mode"); + SmallVector<Value> repls; + for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) { + if (from.use_empty()) { + // The replaced value is dead. No replacement value is needed. + repls.push_back(Value()); + continue; + } + + if (to.empty()) { + // The replaced value is dropped. Materialize a replacement value "out of + // thin air".
+ Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(from), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/ValueRange(), + /*outputTypes=*/from.getType(), /*originalType=*/Type(), + converter)[0]; + repls.push_back(srcMat); + continue; + } + + if (TypeRange(to) == TypeRange(from.getType())) { + // The replacement value already has the correct type. Use it directly. + repls.push_back(to[0]); + continue; + } + + // The replacement value has the wrong type. Build a source materialization + // to the original type. + // TODO: This is a bit inefficient. We should try to reuse existing + // materializations if possible. This would require an extension of the + // `lookupOrDefault` API. + Value srcMat = impl.buildUnresolvedMaterialization( + MaterializationKind::Source, computeInsertPoint(to), from.getLoc(), + /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(), + /*originalType=*/Type(), converter)[0]; + repls.push_back(srcMat); + } + + return repls; } void ConversionPatternRewriterImpl::replaceOp( Operation *op, SmallVector<SmallVector<Value>> &&newValues) { - assert(newValues.size() == op->getNumResults()); + assert(newValues.size() == op->getNumResults() && + "incorrect number of replacement values"); + + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + SmallVector<Value> repls = getReplacementValues( + *this, op->getResults(), newValues, currentTypeConverter); + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + op->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. 
+ if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + op->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Replace the op with the replacement values and notify the listener. + notifyingRewriter.replaceOp(op, repls); + return; + } + assert(!ignoredOps.contains(op) && "operation was already replaced"); // Check if replaced op is an unresolved materialization, i.e., an @@ -1720,11 +1872,46 @@ void ConversionPatternRewriterImpl::replaceOp( void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument( BlockArgument from, ValueRange to, const TypeConverter *converter) { + if (!config.allowPatternRollback) { + SmallVector<Value> toConv = llvm::to_vector(to); + SmallVector<Value> repls = + getReplacementValues(*this, from, {toConv}, converter); + IRRewriter r(from.getContext()); + Value repl = repls.front(); + if (!repl) + return; + + performReplaceBlockArg(r, from, repl); + return; + } + appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter); mapping.map(from, to); } void ConversionPatternRewriterImpl::eraseBlock(Block *block) { + if (!config.allowPatternRollback) { + // Pattern rollback is not allowed: materialize all IR changes immediately. + // Update internal data structures, so that there are no dangling pointers + // to erased IR. + block->walk([&](Operation *op) { + erasedOps.insert(op); + ignoredOps.remove(op); + if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) { + unresolvedMaterializations.erase(castOp); + patternMaterializations.erase(castOp); + } + // The original op will be erased, so remove it from the set of + // unlegalized ops. + if (config.unlegalizedOps) + config.unlegalizedOps->erase(op); + }); + block->walk([&](Block *block) { erasedBlocks.insert(block); }); + // Erase the block and notify the listener.
+ notifyingRewriter.eraseBlock(block); + return; + } + assert(!wasOpReplaced(block->getParentOp()) && "attempting to erase a block within a replaced/erased op"); appendRewrite<EraseBlockRewrite>(block); @@ -1758,23 +1945,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted( logger.getOStream() << " (was detached)"; logger.getOStream() << "\n"; }); - assert(!wasOpReplaced(newParentOp) && + + // In rollback mode, it is easier to misuse the API, so perform extra error + // checking. + assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) && "attempting to insert into a region within a replaced/erased op"); (void)newParentOp; + // In "no rollback" mode, the listener is always notified immediately. + if (!config.allowPatternRollback && config.listener) + config.listener->notifyBlockInserted(block, previous, previousIt); + patternInsertedBlocks.insert(block); if (wasDetached) { // If the block was detached, it is most likely a newly created block. - // TODO: If the same block is inserted multiple times from a detached state, - // the rollback mechanism may erase the same block multiple times. This is a - // bug in the rollback-based dialect conversion driver. - appendRewrite<CreateBlockRewrite>(block); + if (config.allowPatternRollback) { + // TODO: If the same block is inserted multiple times from a detached + // state, the rollback mechanism may erase the same block multiple times. + // This is a bug in the rollback-based dialect conversion driver. + appendRewrite<CreateBlockRewrite>(block); + } else { + // In "no rollback" mode, there is an extra data structure for tracking + // erased blocks that must be kept up to date. + erasedBlocks.erase(block); + } return; } // The block was moved from one place to another.
- appendRewrite<MoveBlockRewrite>(block, previous, previousIt); + if (config.allowPatternRollback) + appendRewrite<MoveBlockRewrite>(block, previous, previousIt); } void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source, @@ -1807,6 +2008,10 @@ ConversionPatternRewriter::ConversionPatternRewriter( ConversionPatternRewriter::~ConversionPatternRewriter() = default; +const ConversionConfig &ConversionPatternRewriter::getConfig() const { + return impl->config; +} + void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) { assert(op && newOp && "expected non-null op"); replaceOp(op, newOp->getResults()); @@ -1950,7 +2155,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, // a bit more efficient, so we try to do that when possible. bool fastPath = !impl->config.listener; - if (fastPath) + if (fastPath && impl->config.allowPatternRollback) impl->inlineBlockBefore(source, dest, before); // Replace all uses of block arguments. @@ -1976,6 +2181,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest, } void ConversionPatternRewriter::startOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + // Pattern rollback is not allowed: no extra bookkeeping is needed.
+ PatternRewriter::startOpModification(op); + return; + } assert(!impl->wasOpReplaced(op) && "attempting to modify a replaced/erased op"); #ifndef NDEBUG @@ -1985,20 +2195,29 @@ void ConversionPatternRewriter::startOpModification(Operation *op) { } void ConversionPatternRewriter::finalizeOpModification(Operation *op) { - assert(!impl->wasOpReplaced(op) && - "attempting to modify a replaced/erased op"); - PatternRewriter::finalizeOpModification(op); impl->patternModifiedOps.insert(op); + if (!impl->config.allowPatternRollback) { + PatternRewriter::finalizeOpModification(op); + if (getConfig().listener) + getConfig().listener->notifyOperationModified(op); + return; + } // There is nothing to do here, we only need to track the operation at the // start of the update. #ifndef NDEBUG + assert(!impl->wasOpReplaced(op) && + "attempting to modify a replaced/erased op"); assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); #endif } void ConversionPatternRewriter::cancelOpModification(Operation *op) { + if (!impl->config.allowPatternRollback) { + PatternRewriter::cancelOpModification(op); + return; + } #ifndef NDEBUG assert(impl->pendingRootUpdates.erase(op) && "operation did not have a pending in-place update"); @@ -2355,6 +2574,37 @@ OperationLegalizer::legalizeWithFold(Operation *op, return success(); } +/// Report a fatal error indicating that newly produced or modified IR could +/// not be legalized.
+static void +reportNewIrLegalizationFatalError(const Pattern &pattern, + const SetVector<Operation *> &newOps, + const SetVector<Operation *> &modifiedOps, + const SetVector<Block *> &insertedBlocks) { + StringRef detachedBlockStr = "(detached block)"; + std::string newOpNames = llvm::join( + llvm::map_range( + newOps, [](Operation *op) { return op->getName().getStringRef(); }), + ", "); + std::string modifiedOpNames = llvm::join( + llvm::map_range( + modifiedOps, [](Operation *op) { return op->getName().getStringRef(); }), + ", "); + std::string insertedBlockNames = llvm::join( + llvm::map_range(insertedBlocks, + [&](Block *block) { + if (block->getParentOp()) + return block->getParentOp()->getName().getStringRef(); + return detachedBlockStr; + }), + ", "); + llvm::report_fatal_error( + "pattern '" + pattern.getDebugName() + + "' produced IR that could not be legalized. " + "new ops: {" + + newOpNames + "}, " + "modified ops: {" + modifiedOpNames + "}, " + + "inserted block into ops: {" + insertedBlockNames + "}"); +} + LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op, ConversionPatternRewriter &rewriter) { @@ -2398,17 +2648,35 @@ OperationLegalizer::legalizeWithPattern(Operation *op, RewriterState curState = rewriterImpl.getCurrentState(); auto onFailure = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); -#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS if (!rewriterImpl.config.allowPatternRollback) { - // Returning "failure" after modifying IR is not allowed. + // Erase all unresolved materializations. + for (auto op : rewriterImpl.patternMaterializations) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + rewriterImpl.patternMaterializations.clear(); +#if 0 + // Cheap pattern check that could have false positives. Can be enabled + // manually for debugging purposes.
E.g., this check would report an API + // violation when an op is created and then erased in the same pattern. + if (!rewriterImpl.patternNewOps.empty() || + !rewriterImpl.patternModifiedOps.empty() || + !rewriterImpl.patternInsertedBlocks.empty()) { + llvm::report_fatal_error("pattern '" + pattern.getDebugName() + + "' rollback of IR modifications requested"); + } +#endif +#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + // Expensive pattern check that can detect more API violations and has + // fewer false positives than the cheap check. if (checkOp) { OperationFingerPrint fingerPrintAfterPattern(checkOp); if (fingerPrintAfterPattern != *topLevelFingerPrint) llvm::report_fatal_error("pattern '" + pattern.getDebugName() + "' returned failure but IR did change"); } - } #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS + } rewriterImpl.patternNewOps.clear(); rewriterImpl.patternModifiedOps.clear(); rewriterImpl.patternInsertedBlocks.clear(); @@ -2432,6 +2700,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op, // successfully applied. auto onSuccess = [&](const Pattern &pattern) { assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates"); + if (!rewriterImpl.config.allowPatternRollback) { + // Eagerly erase unused materializations.
+ for (auto op : rewriterImpl.patternMaterializations) { + if (op->use_empty()) { + rewriterImpl.unresolvedMaterializations.erase(op); + op.erase(); + } + } + rewriterImpl.patternMaterializations.clear(); + } SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps); SetVector<Operation *> modifiedOps = moveAndReset(rewriterImpl.patternModifiedOps); @@ -2442,8 +2720,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op, appliedPatterns.erase(&pattern); if (failed(result)) { if (!rewriterImpl.config.allowPatternRollback) - llvm::report_fatal_error("pattern '" + pattern.getDebugName() + - "' produced IR that could not be legalized"); + reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps, + insertedBlocks); rewriterImpl.resetState(curState, pattern.getDebugName()); } if (config.listener) @@ -2522,6 +2800,9 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites( // If the pattern moved or created any blocks, make sure the types of block // arguments get legalized. for (Block *block : insertedBlocks) { + if (impl.erasedBlocks.contains(block)) + continue; + // Only check blocks outside of the current operation.
Operation *parentOp = block->getParentOp(); if (!parentOp || parentOp == op || block->getNumArguments() == 0) diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir index 8f74976c59773..25a338df8d790 100644 --- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { // This buffer is properly aligned. There should be no error.
// CHECK-NOT: ^ memref is not aligned to 8 diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir index 26c731c921356..4c6a48d577a6c 100644 --- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref<?xf32>, %index: index) { %cst = arith.constant 1.0 : f32 memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32 diff --git a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir index 8b6308e9c1939..1ac10306395ad 100644 --- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir @@ -1,11 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func
@cast_to_static_dim(%m: memref<?xf32>) -> memref<10xf32> { %0 = memref.cast %m : memref<?xf32> to memref<10xf32> return %0 : memref<10xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir index 95b9db2832cee..be9417baf93df 100644 --- a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + // Put memref.copy in a function, otherwise the memref.cast may fold. func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) { memref.copy %src, %dest : memref<?xf32> to memref<?xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir index 2e3f271743c93..ef4af62459738 100644 --- a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir @@ -6,6 +6,15 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @main() { %c4 = arith.constant 4 : index %alloca = memref.alloca() : memref<1xf32> diff --git
a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir index b87e5bdf0970c..2e42648297875 100644 --- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir @@ -1,12 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ -// RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @load(%memref: memref<1xf32>, %index: index) { memref.load %memref[%index] : memref<1xf32> return diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir index 12253fa3b5e83..dd000c6904bcb 100644 --- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir @@ -5,6 +5,14 @@ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @store_dynamic(%memref: memref<?xf32>, %index: index) { %cst = arith.constant 1.0 : f32 memref.store %cst, %memref[%index] :
memref<?xf32> diff --git a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir index ec7e4085f2fa5..9fbe5bc60321e 100644 --- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir @@ -1,12 +1,22 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -test-cf-assert \ // RUN: -expand-strided-metadata \ // RUN: -lower-affine \ +// RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -expand-strided-metadata \ +// RUN: -lower-affine \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @subview(%memref: memref<1xf32>, %offset: index) { memref.subview %memref[%offset] [1] [1] : memref<1xf32> to diff --git a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir index e4aab32d4a390..f37a6d6383c48 100644 --- a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main
-entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func private @cast_to_static_dim(%t: tensor<?xf32>) -> tensor<10xf32> { %0 = tensor.cast %t : tensor<?xf32> to tensor<10xf32> return %0 : tensor<10xf32> diff --git a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir index c6d8f698b9433..e9e5c040c6488 100644 --- a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir @@ -1,10 +1,20 @@ // RUN: mlir-opt %s -generate-runtime-verification \ -// RUN: -one-shot-bufferize \ -// RUN: -buffer-deallocation-pipeline \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ // RUN: -test-cf-assert \ // RUN: -convert-to-llvm | \ // RUN: mlir-runner -e main -entry-point-result=void \ // RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ // RUN: FileCheck %s func.func @main() { diff --git a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir index 8e3cab7be704d..73fcec4d7abcd 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ //
RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract(%tensor: tensor<1xf32>, %index: index) { tensor.extract %tensor[%index] : tensor<1xf32> return diff --git a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir index 28f9be0fffe64..341a59e8b8102 100644 --- a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir +++ b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir @@ -8,6 +8,17 @@ // RUN: -shared-libs=%tlir_runner_utils 2>&1 | \ // RUN: FileCheck %s +// RUN: mlir-opt %s -generate-runtime-verification \ +// RUN: -one-shot-bufferize="bufferize-function-boundaries" \ +// RUN: -buffer-deallocation-pipeline=private-function-dynamic-ownership \ +// RUN: -test-cf-assert \ +// RUN: -convert-scf-to-cf \ +// RUN: -convert-to-llvm="allow-pattern-rollback=0" \ +// RUN: -reconcile-unrealized-casts | \ +// RUN: mlir-runner -e main -entry-point-result=void \ +// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \ +// RUN: FileCheck %s + func.func @extract_slice(%tensor: tensor<1xf32>, %offset: index) { tensor.extract_slice %tensor[%offset] [1] [1] : tensor<1xf32> to tensor<1xf32> return diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir index 5630d1540e4d5..9a04da7904863 100644 --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -1,9 +1,14 @@ -// RUN: mlir-opt -allow-unregistered-dialect
-split-input-file -test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics -profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER +// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s + +// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "B" +// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "E" +// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "E" +// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "E" -// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "B" -// CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "E" // Note: Listener notifications appear after the pattern application because // the conversion driver sends all notifications at the end of the conversion // in bulk.
@@ -11,8 +16,6 @@ // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a // CHECK-NEXT: notifyOperationModified: func.return // CHECK-NEXT: notifyOperationErased: test.illegal_op_a -// CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "E" -// CHECK: "name": "pass-execution", "cat": "PERF", "ph": "E" // CHECK-LABEL: verifyDirectPattern func.func @verifyDirectPattern() -> i32 { // CHECK-NEXT: "test.legal_op_a"() <{status = "Success"} @@ -29,7 +32,9 @@ func.func @verifyDirectPattern() -> i32 { // CHECK-NEXT: notifyOperationErased: test.illegal_op_c // CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e -// CHECK-NEXT: notifyOperationErased: test.illegal_op_e +// Note: func.return is modified a second time when running in no-rollback +// mode. +// CHECK: notifyOperationErased: test.illegal_op_e // CHECK-LABEL: verifyLargerBenefit func.func @verifyLargerBenefit() -> i32 { @@ -70,7 +75,7 @@ func.func @remap_call_1_to_1(%arg0: i64) { // CHECK: notifyBlockInserted into func.func: was unlinked // Contents of the old block are moved to the new block. -// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown +// CHECK-NEXT: notifyOperationInserted: test.return // The old block is erased. // CHECK-NEXT: notifyBlockErased @@ -409,8 +414,10 @@ func.func @test_remap_block_arg() { // CHECK-LABEL: func @test_multiple_1_to_n_replacement() // CHECK: %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, f16, f16) -// CHECK: %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, %[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16 -// CHECK: "test.valid"(%[[cast]]) : (f16) -> () +// Note: There is a bug in the rollback-based conversion driver: it emits a +// "test.cast" : (f16, f16, f16, f16) -> f16, when it should be emitting +// three consecutive casts of (f16, f16) -> f16.
+// CHECK: "test.valid"(%{{.*}}) : (f16) -> () func.func @test_multiple_1_to_n_replacement() { %0 = "test.multiple_1_to_n_replacement"() : () -> (f16) "test.invalid"(%0) : (f16) -> () diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp index 7150401bdbdce..6e6d896df0cb9 100644 --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -1301,6 +1301,7 @@ class TestMultiple1ToNReplacement : public ConversionPattern { // Helper function that replaces the given op with a new op of the given // name and doubles each result (1 -> 2 replacement of each result). auto replaceWithDoubleResults = [&](Operation *op, StringRef name) { + rewriter.setInsertionPointAfter(op); SmallVector<Type> types; for (Type t : op->getResultTypes()) { types.push_back(t); @@ -1499,6 +1500,7 @@ struct TestLegalizePatternDriver if (mode == ConversionMode::Partial) { DenseSet<Operation *> unlegalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; config.unlegalizedOps = &unlegalizedOps; @@ -1520,6 +1522,7 @@ struct TestLegalizePatternDriver }); ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; DumpNotifications dumpNotifications; config.listener = &dumpNotifications; if (failed(applyFullConversion(getOperation(), target, @@ -1535,6 +1538,7 @@ struct TestLegalizePatternDriver // Analyze the convertible operations.
DenseSet<Operation *> legalizedOps; ConversionConfig config; + config.allowPatternRollback = allowPatternRollback; config.legalizableOps = &legalizedOps; if (failed(applyAnalysisConversion(getOperation(), target, std::move(patterns), config))) @@ -1555,6 +1559,10 @@ struct TestLegalizePatternDriver clEnumValN(ConversionMode::Full, "full", "Perform a full conversion"), clEnumValN(ConversionMode::Partial, "partial", "Perform a partial conversion"))}; + + Option<bool> allowPatternRollback{*this, "allow-pattern-rollback", + llvm::cl::desc("Allow pattern rollback"), + llvm::cl::init(true)}; }; } // namespace _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits