https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/151865

>From e789428b751e4c8037fd90c114777a23c33973f6 Mon Sep 17 00:00:00 2001
From: Matthias Springer <m...@m-sp.org>
Date: Sun, 22 Jun 2025 09:21:22 +0000
Subject: [PATCH] :WIP

update

erase immediately

update 2

fix

fix some tests
---
 mlir/include/mlir/Conversion/Passes.td        |   2 +
 .../mlir/Transforms/DialectConversion.h       |  23 +-
 .../ConvertToLLVM/ConvertToLLVMPass.cpp       |  26 +-
 .../Dialect/Linalg/Transforms/Detensorize.cpp |  35 +-
 .../Transforms/SparseTensorConversion.cpp     |   2 +-
 .../Transforms/Utils/DialectConversion.cpp    | 473 ++++++++++++++----
 ...assume-alignment-runtime-verification.mlir |   9 +
 .../atomic-rmw-runtime-verification.mlir      |   8 +
 .../MemRef/cast-runtime-verification.mlir     |  11 +-
 .../MemRef/copy-runtime-verification.mlir     |   9 +
 .../MemRef/dim-runtime-verification.mlir      |   9 +
 .../MemRef/load-runtime-verification.mlir     |  12 +-
 .../MemRef/store-runtime-verification.mlir    |   8 +
 .../MemRef/subview-runtime-verification.mlir  |  12 +-
 .../Tensor/cast-runtime-verification.mlir     |  11 +
 .../Tensor/dim-runtime-verification.mlir      |  16 +-
 .../Tensor/extract-runtime-verification.mlir  |  11 +
 .../extract_slice-runtime-verification.mlir   |  11 +
 mlir/test/Transforms/test-legalizer.mlir      |  29 +-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  10 +-
 20 files changed, 592 insertions(+), 135 deletions(-)

diff --git a/mlir/include/mlir/Conversion/Passes.td 
b/mlir/include/mlir/Conversion/Passes.td
index 6e1baaf23fcf7..4a9464ff265e0 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
                "Test conversion patterns of only the specified dialects">,
     Option<"useDynamic", "dynamic", "bool", "false",
            "Use op conversion attributes to configure the conversion">,
+    Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "false",
+           "Experimental performance flag to disallow pattern rollback">
   ];
 }
 
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h 
b/mlir/include/mlir/Transforms/DialectConversion.h
index f6437657c9a93..84b9035dc6358 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -728,6 +728,9 @@ class ConversionPatternRewriter final : public 
PatternRewriter {
 public:
   ~ConversionPatternRewriter() override;
 
+  /// Return the configuration of the current dialect conversion.
+  const ConversionConfig &getConfig() const;
+
   /// Apply a signature conversion to given block. This replaces the block with
   /// a new block containing the updated signature. The operations of the given
   /// block are inlined into the newly-created block, which is returned.
@@ -1228,18 +1231,18 @@ struct ConversionConfig {
   /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
   ///    and cannot be legalized by subsequent foldings / pattern applications.
   ///
-  /// If set to "false", the conversion driver will produce an LLVM fatal error
-  /// instead of rolling back IR modifications. Moreover, in case of a failed
-  /// conversion, the original IR is not restored. The resulting IR may be a
-  /// mix of original and rewritten IR. (Same as a failed greedy pattern
-  /// rewrite.)
+  /// Experimental: If set to "false", the conversion driver will produce an
+  /// LLVM fatal error instead of rolling back IR modifications. Moreover, in
+  /// case of a failed conversion, the original IR is not restored. The
+  /// resulting IR may be a mix of original and rewritten IR. (Same as a failed
+  /// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+  /// with ASAN to detect invalid pattern API usage.
   ///
-  /// Note: This flag was added in preparation of the One-Shot Dialect
-  /// Conversion refactoring, which will remove the ability to roll back IR
-  /// modifications from the conversion driver. Use this flag to ensure that
-  /// your patterns do not trigger any IR rollbacks. For details, see
+  /// When pattern rollback is disabled, the conversion driver has to maintain
+  /// less internal state. This is more efficient, but not supported by all
+  /// lowering patterns. For details, see
   /// 
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
-  bool allowPatternRollback = true;
+  bool allowPatternRollback = false;
 };
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp 
b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
index ed5d6d4a7fe40..cdb715064b0f7 100644
--- a/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
+++ b/mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp
@@ -31,7 +31,8 @@ namespace {
 class ConvertToLLVMPassInterface {
 public:
   ConvertToLLVMPassInterface(MLIRContext *context,
-                             ArrayRef<std::string> filterDialects);
+                             ArrayRef<std::string> filterDialects,
+                             bool allowPatternRollback = true);
   virtual ~ConvertToLLVMPassInterface() = default;
 
   /// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
   MLIRContext *context;
   /// List of dialects names to use as filters.
   ArrayRef<std::string> filterDialects;
+  /// An experimental flag to disallow pattern rollback. This is more efficient
+  /// but not supported by all lowering patterns.
+  bool allowPatternRollback;
 };
 
 /// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public 
ConvertToLLVMPassInterface {
 
   /// Apply the conversion driver.
   LogicalResult transform(Operation *op, AnalysisManager manager) const final {
-    if (failed(applyPartialConversion(op, *target, *patterns)))
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
+    if (failed(applyPartialConversion(op, *target, *patterns, config)))
       return failure();
     return success();
   }
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public 
ConvertToLLVMPassInterface {
                                               patterns);
 
     // Apply the conversion.
-    if (failed(applyPartialConversion(op, target, std::move(patterns))))
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
+    if (failed(applyPartialConversion(op, target, std::move(patterns), 
config)))
       return failure();
     return success();
   }
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
     std::shared_ptr<ConvertToLLVMPassInterface> impl;
     // Choose the pass implementation.
     if (useDynamic)
-      impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
+      impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
+                                                    allowPatternRollback);
     else
-      impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
+      impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
+                                                   allowPatternRollback);
     if (failed(impl->initialize()))
       return failure();
     this->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
 
//===----------------------------------------------------------------------===//
 
 ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
-    MLIRContext *context, ArrayRef<std::string> filterDialects)
-    : context(context), filterDialects(filterDialects) {}
+    MLIRContext *context, ArrayRef<std::string> filterDialects,
+    bool allowPatternRollback)
+    : context(context), filterDialects(filterDialects),
+      allowPatternRollback(allowPatternRollback) {}
 
 void ConvertToLLVMPassInterface::getDependentDialects(
     DialectRegistry &registry) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp 
b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
index 830905495e759..d6f26fa200dbc 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp
@@ -458,6 +458,22 @@ struct LinalgDetensorize
     }
   };
 
+  /// A listener that forwards notifyBlockErased and notifyOperationErased to
+  /// the given callbacks.
+  struct CallbackListener : public RewriterBase::Listener {
+    CallbackListener(std::function<void(Operation *op)> onOperationErased,
+                     std::function<void(Block *block)> onBlockErased)
+        : onOperationErased(onOperationErased), onBlockErased(onBlockErased) {}
+
+    void notifyBlockErased(Block *block) override { onBlockErased(block); }
+    void notifyOperationErased(Operation *op) override {
+      onOperationErased(op);
+    }
+
+    std::function<void(Operation *op)> onOperationErased;
+    std::function<void(Block *block)> onBlockErased;
+  };
+
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     DetensorizeTypeConverter typeConverter;
@@ -551,8 +567,23 @@ struct LinalgDetensorize
     populateBranchOpInterfaceTypeConversionPattern(patterns, typeConverter,
                                                    shouldConvertBranchOperand);
 
-    if (failed(
-            applyFullConversion(getOperation(), target, std::move(patterns))))
+    ConversionConfig config;
+    CallbackListener listener(/*onOperationErased=*/
+                              [&](Operation *op) {
+                                opsToDetensor.erase(op);
+                                detensorableBranchOps.erase(op);
+                              },
+                              /*onBlockErased=*/
+                              [&](Block *block) {
+                                for (BlockArgument arg :
+                                     block->getArguments()) {
+                                  blockArgsToDetensor.erase(arg);
+                                }
+                              });
+
+    config.listener = &listener;
+    if (failed(applyFullConversion(getOperation(), target, std::move(patterns),
+                                   config)))
       signalPassFailure();
 
     RewritePatternSet canonPatterns(context);
diff --git 
a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp 
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
index 134aef3a6c719..0e88d31dae8e8 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp
@@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public 
OpConversionPattern<CompressOp> {
                    {tensor, lvlCoords, values, filled, added, count},
                    EmitCInterface::On);
     Operation *parent = getTop(op);
+    rewriter.setInsertionPointAfter(parent);
     rewriter.replaceOp(op, adaptor.getTensor());
     // Deallocate the buffers on exit of the loop nest.
-    rewriter.setInsertionPointAfter(parent);
     memref::DeallocOp::create(rewriter, loc, values);
     memref::DeallocOp::create(rewriter, loc, filled);
     memref::DeallocOp::create(rewriter, loc, added);
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4a60bfe489f9c..90cd58c8285b5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -121,17 +121,8 @@ struct ConversionValueMapping {
   /// false positives.
   bool isMappedTo(Value value) const { return mappedTo.contains(value); }
 
-  /// Lookup a value in the mapping. If `skipPureTypeConversions` is "true",
-  /// pure type conversions are not considered. Return an empty vector if no
-  /// mapping was found.
-  ///
-  /// Note: This mapping data structure supports N:M mappings. This function
-  /// first tries to look up mappings for each input value individually (and
-  /// then composes the results). If such a lookup is unsuccessful, the entire
-  /// vector is looked up together. If the lookup is still unsuccessful, an
-  /// empty vector is returned.
-  ValueVector lookup(const ValueVector &from,
-                     bool skipPureTypeConversions = false) const;
+  /// Lookup a value in the mapping.
+  ValueVector lookup(const ValueVector &from) const;
 
   template <typename T>
   struct IsValueVector : std::is_same<std::decay_t<T>, ValueVector> {};
@@ -191,43 +182,28 @@ struct ConversionValueMapping {
 /// conversions.)
 static const StringRef kPureTypeConversionMarker = "__pure_type_conversion__";
 
+/// Return the operation that defines all values in the vector. Return nullptr
+/// if the values are not defined by the same operation.
+static Operation *getCommonDefiningOp(const ValueVector &values) {
+  assert(!values.empty() && "expected non-empty value vector");
+  Operation *op = values.front().getDefiningOp();
+  for (Value v : llvm::drop_begin(values)) {
+    if (v.getDefiningOp() != op)
+      return nullptr;
+  }
+  return op;
+}
+
 /// A vector of values is a pure type conversion if all values are defined by
 /// the same operation and the operation has the `kPureTypeConversionMarker`
 /// attribute.
 static bool isPureTypeConversion(const ValueVector &values) {
   assert(!values.empty() && "expected non-empty value vector");
-  Operation *op = values.front().getDefiningOp();
-  for (Value v : llvm::drop_begin(values))
-    if (v.getDefiningOp() != op)
-      return false;
+  Operation *op = getCommonDefiningOp(values);
   return op && op->hasAttr(kPureTypeConversionMarker);
 }
 
-ValueVector ConversionValueMapping::lookup(const ValueVector &from,
-                                           bool skipPureTypeConversions) const 
{
-  // If possible, replace each value with (one or multiple) mapped values.
-  ValueVector next;
-  for (Value v : from) {
-    auto it = mapping.find({v});
-    if (it != mapping.end()) {
-      llvm::append_range(next, it->second);
-    } else {
-      next.push_back(v);
-    }
-  }
-  if (next != from) {
-    // At least one value was replaced.
-    return next;
-  }
-
-  // Otherwise: Check if there is a mapping for the entire vector. Such
-  // mappings are materializations. (N:M mapping are not supported for value
-  // replacements.)
-  //
-  // Note: From a correctness point of view, materializations do not have to
-  // be stored (and looked up) in the mapping. But for performance reasons,
-  // we choose to reuse existing IR (when possible) instead of creating it
-  // multiple times.
+ValueVector ConversionValueMapping::lookup(const ValueVector &from) const {
   auto it = mapping.find(from);
   if (it == mapping.end()) {
     // No mapping found: The lookup stops here.
@@ -874,7 +850,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), config(config) {}
+      : context(ctx), config(config), notifyingRewriter(ctx, config.listener) 
{}
 
   
//===--------------------------------------------------------------------===//
   // State Management
@@ -896,6 +872,7 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   /// failure.
   template <typename RewriteTy, typename... Args>
   void appendRewrite(Args &&...args) {
+    assert(config.allowPatternRollback && "appending rewrites is not allowed");
     rewrites.push_back(
         std::make_unique<RewriteTy>(*this, std::forward<Args>(args)...));
   }
@@ -922,15 +899,8 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   bool wasOpReplaced(Operation *op) const;
 
   /// Lookup the most recently mapped values with the desired types in the
-  /// mapping.
-  ///
-  /// Special cases:
-  /// - If the desired type range is empty, simply return the most recently
-  ///   mapped values.
-  /// - If there is no mapping to the desired types, also return the most
-  ///   recently mapped values.
-  /// - If there is no mapping for the given values at all, return the given
-  ///   value.
+  /// mapping, taking into account only replacements. Perform a best-effort
+  /// search for existing materializations with the desired types.
   ///
   /// If `skipPureTypeConversions` is "true", materializations that are pure
   /// type conversions are not considered.
@@ -1099,6 +1069,9 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   ConversionValueMapping mapping;
 
   /// Ordered list of block operations (creations, splits, motions).
+  /// This vector is maintained only if `allowPatternRollback` is set to
+  /// "true". Otherwise, all IR rewrites are materialized immediately and no
+  /// bookkeeping is needed.
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
   /// A set of operations that should no longer be considered for legalization.
@@ -1122,6 +1095,10 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   /// by the current pattern.
   SetVector<Block *> patternInsertedBlocks;
 
+  /// A set of unresolved materializations that were created by the current
+  /// pattern.
+  DenseSet<UnrealizedConversionCastOp> patternMaterializations;
+
   /// A mapping for looking up metadata of unresolved materializations.
   DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       unresolvedMaterializations;
@@ -1137,6 +1114,23 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   /// Dialect conversion configuration.
   const ConversionConfig &config;
 
+  /// A set of erased operations. This set is utilized only if
+  /// `allowPatternRollback` is set to "false". Conceptually, this set is
+  /// similar to `replacedOps` (which is maintained when the flag is set to
+  /// "true"). However, erasing from a DenseSet is more efficient than erasing
+  /// from a SetVector.
+  DenseSet<Operation *> erasedOps;
+
+  /// A set of erased blocks. This set is utilized only if
+  /// `allowPatternRollback` is set to "false".
+  DenseSet<Block *> erasedBlocks;
+
+  /// A rewriter that notifies the listener (if any) about all IR
+  /// modifications. This rewriter is utilized only if `allowPatternRollback`
+  /// is set to "false". If the flag is set to "true", the listener is notified
+  /// with a separate mechanism (e.g., in `IRRewrite::commit`).
+  IRRewriter notifyingRewriter;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -1173,11 +1167,8 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
-  if (!repl)
-    return;
-
+static void performReplaceBlockArg(RewriterBase &rewriter, BlockArgument arg,
+                                   Value repl) {
   if (isa<BlockArgument>(repl)) {
     rewriter.replaceAllUsesWith(arg, repl);
     return;
@@ -1194,6 +1185,13 @@ void ReplaceBlockArgRewrite::commit(RewriterBase 
&rewriter) {
   });
 }
 
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(arg, converter);
+  if (!repl)
+    return;
+  performReplaceBlockArg(rewriter, arg, repl);
+}
+
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase({arg}); }
 
 void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
@@ -1279,6 +1277,63 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
 ValueVector ConversionPatternRewriterImpl::lookupOrDefault(
     Value from, TypeRange desiredTypes, bool skipPureTypeConversions) const {
+
+  // Helper function that looks up a single value.
+  auto lookup = [&](const ValueVector &values) -> ValueVector {
+    assert(!values.empty() && "expected non-empty value vector");
+
+    // If the pattern rollback is enabled, use the mapping to look up the
+    // values.
+    if (config.allowPatternRollback)
+      return mapping.lookup(values);
+
+    // Otherwise, look up values by examining the IR. All replacements have
+    // already been materialized in IR.
+    Operation *op = getCommonDefiningOp(values);
+    if (!op)
+      return {};
+    auto castOp = dyn_cast<UnrealizedConversionCastOp>(op);
+    if (!castOp)
+      return {};
+    if (!this->unresolvedMaterializations.contains(castOp))
+      return {};
+    if (castOp.getOutputs() != values)
+      return {};
+    return castOp.getInputs();
+  };
+
+  auto composedLookup = [&](const ValueVector &values) -> ValueVector {
+    // If possible, replace each value with (one or multiple) mapped values.
+    ValueVector next;
+    for (Value v : values) {
+      ValueVector r = lookup({v});
+      if (!r.empty()) {
+        llvm::append_range(next, r);
+      } else {
+        next.push_back(v);
+      }
+    }
+    if (next != values) {
+      // At least one value was replaced.
+      return next;
+    }
+
+    // Otherwise: Check if there is a mapping for the entire vector. Such
+    // mappings are materializations. (N:M mappings are not supported for value
+    // replacements.)
+    //
+    // Note: From a correctness point of view, materializations do not have to
+    // be stored (and looked up) in the mapping. But for performance reasons,
+    // we choose to reuse existing IR (when possible) instead of creating it
+    // multiple times.
+    ValueVector r = lookup(values);
+    if (r.empty()) {
+      // No mapping found: The lookup stops here.
+      return {};
+    }
+    return r;
+  };
+
   // Try to find the deepest values that have the desired types. If there is no
   // such mapping, simply return the deepest values.
   ValueVector desiredValue;
@@ -1300,7 +1355,7 @@ ValueVector 
ConversionPatternRewriterImpl::lookupOrDefault(
       desiredValue = current;
 
     // Lookup next value in the mapping.
-    ValueVector next = mapping.lookup(current, skipPureTypeConversions);
+    ValueVector next = composedLookup(current);
     if (next.empty())
       break;
     current = std::move(next);
@@ -1345,15 +1400,8 @@ void 
ConversionPatternRewriterImpl::resetState(RewriterState state,
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
                                                  StringRef patternName) {
   for (auto &rewrite :
-       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
-    if (!config.allowPatternRollback &&
-        !isa<UnresolvedMaterializationRewrite>(rewrite)) {
-      // Unresolved materializations can always be rolled back (erased).
-      llvm::report_fatal_error("pattern '" + patternName +
-                               "' rollback of IR modifications requested");
-    }
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
     rewrite->rollback();
-  }
   rewrites.resize(numRewritesToKeep);
 }
 
@@ -1417,12 +1465,12 @@ LogicalResult 
ConversionPatternRewriterImpl::remapValues(
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
   // Check to see if this operation is ignored or was replaced.
-  return replacedOps.count(op) || ignoredOps.count(op);
+  return wasOpReplaced(op) || ignoredOps.count(op);
 }
 
 bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
   // Check to see if this operation was replaced.
-  return replacedOps.count(op);
+  return replacedOps.count(op) || erasedOps.count(op);
 }
 
 
//===----------------------------------------------------------------------===//
@@ -1506,7 +1554,8 @@ Block 
*ConversionPatternRewriterImpl::applySignatureConversion(
   // a bit more efficient, so we try to do that when possible.
   bool fastPath = !config.listener;
   if (fastPath) {
-    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+    if (config.allowPatternRollback)
+      appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
     newBlock->getOperations().splice(newBlock->end(), block->getOperations());
   } else {
     while (!block->empty())
@@ -1554,7 +1603,8 @@ Block 
*ConversionPatternRewriterImpl::applySignatureConversion(
     replaceUsesOfBlockArgument(origArg, replArgs, converter);
   }
 
-  appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
+  if (config.allowPatternRollback)
+    appendRewrite<BlockTypeConversionRewrite>(/*origBlock=*/block, newBlock);
 
   // Erase the old block. (It is just unlinked for now and will be erased 
during
   // cleanup.)
@@ -1583,23 +1633,32 @@ ValueRange 
ConversionPatternRewriterImpl::buildUnresolvedMaterialization(
   // tracking the materialization like we do for other operations.
   OpBuilder builder(outputTypes.front().getContext());
   builder.setInsertionPoint(ip.getBlock(), ip.getPoint());
-  auto convertOp =
+  UnrealizedConversionCastOp convertOp =
       UnrealizedConversionCastOp::create(builder, loc, outputTypes, inputs);
   if (isPureTypeConversion)
     convertOp->setAttr(kPureTypeConversionMarker, builder.getUnitAttr());
-  if (!valuesToMap.empty())
-    mapping.map(valuesToMap, convertOp.getResults());
+
+  // Register the materialization.
   if (castOp)
     *castOp = convertOp;
   unresolvedMaterializations[convertOp] =
       UnresolvedMaterializationInfo(converter, kind, originalType);
-  appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
-                                                  std::move(valuesToMap));
+  if (config.allowPatternRollback) {
+    if (!valuesToMap.empty())
+      mapping.map(valuesToMap, convertOp.getResults());
+    appendRewrite<UnresolvedMaterializationRewrite>(convertOp,
+                                                    std::move(valuesToMap));
+  } else {
+    patternMaterializations.insert(convertOp);
+  }
   return convertOp.getResults();
 }
 
 Value ConversionPatternRewriterImpl::findOrBuildReplacementValue(
     Value value, const TypeConverter *converter) {
+  assert(config.allowPatternRollback &&
+         "this code path is valid only in rollback mode");
+
   // Try to find a replacement value with the same type in the conversion value
   // mapping. This includes cached materializations. We try to reuse those
   // instead of generating duplicate IR.
@@ -1661,26 +1720,119 @@ void 
ConversionPatternRewriterImpl::notifyOperationInserted(
       logger.getOStream() << " (was detached)";
     logger.getOStream() << "\n";
   });
-  assert(!wasOpReplaced(op->getParentOp()) &&
+
+  // In rollback mode, it is easier to misuse the API, so perform extra error
+  // checking.
+  assert(!(config.allowPatternRollback && wasOpReplaced(op->getParentOp())) &&
          "attempting to insert into a block within a replaced/erased op");
 
+  // In "no rollback" mode, the listener is always notified immediately.
+  if (!config.allowPatternRollback && config.listener)
+    config.listener->notifyOperationInserted(op, previous);
+
   if (wasDetached) {
-    // If the op was detached, it is most likely a newly created op.
-    // TODO: If the same op is inserted multiple times from a detached state,
-    // the rollback mechanism may erase the same op multiple times. This is a
-    // bug in the rollback-based dialect conversion driver.
-    appendRewrite<CreateOperationRewrite>(op);
+    // If the op was detached, it is most likely a newly created op. Add it to
+    // the set of newly created ops, so that it will be legalized. If this op
+    // is not a newly created op, it will be legalized a second time, which is
+    // inefficient but harmless.
     patternNewOps.insert(op);
+
+    if (config.allowPatternRollback) {
+      // TODO: If the same op is inserted multiple times from a detached
+      // state, the rollback mechanism may erase the same op multiple times.
+      // This is a bug in the rollback-based dialect conversion driver.
+      appendRewrite<CreateOperationRewrite>(op);
+    } else {
+      // In "no rollback" mode, there is an extra data structure for tracking
+      // erased operations that must be kept up to date.
+      erasedOps.erase(op);
+    }
     return;
   }
 
   // The op was moved from one place to another.
-  appendRewrite<MoveOperationRewrite>(op, previous);
+  if (config.allowPatternRollback)
+    appendRewrite<MoveOperationRewrite>(op, previous);
+}
+
+/// Given that `fromRange` is about to be replaced with `toRange`, compute
+/// replacement values with the types of `fromRange`.
+static SmallVector<Value>
+getReplacementValues(ConversionPatternRewriterImpl &impl, ValueRange fromRange,
+                     const SmallVector<SmallVector<Value>> &toRange,
+                     const TypeConverter *converter) {
+  assert(!impl.config.allowPatternRollback &&
+         "this code path is valid only in 'no rollback' mode");
+  SmallVector<Value> repls;
+  for (auto [from, to] : llvm::zip_equal(fromRange, toRange)) {
+    if (from.use_empty()) {
+      // The replaced value is dead. No replacement value is needed.
+      repls.push_back(Value());
+      continue;
+    }
+
+    if (to.empty()) {
+      // The replaced value is dropped. Materialize a replacement value "out of
+      // thin air".
+      Value srcMat = impl.buildUnresolvedMaterialization(
+          MaterializationKind::Source, computeInsertPoint(from), from.getLoc(),
+          /*valuesToMap=*/{}, /*inputs=*/ValueRange(),
+          /*outputTypes=*/from.getType(), /*originalType=*/Type(),
+          converter)[0];
+      repls.push_back(srcMat);
+      continue;
+    }
+
+    if (TypeRange(to) == TypeRange(from.getType())) {
+      // The replacement value already has the correct type. Use it directly.
+      repls.push_back(to[0]);
+      continue;
+    }
+
+    // The replacement value has the wrong type. Build a source materialization
+    // to the original type.
+    // TODO: This is a bit inefficient. We should try to reuse existing
+    // materializations if possible. This would require an extension of the
+    // `lookupOrDefault` API.
+    Value srcMat = impl.buildUnresolvedMaterialization(
+        MaterializationKind::Source, computeInsertPoint(to), from.getLoc(),
+        /*valuesToMap=*/{}, /*inputs=*/to, /*outputTypes=*/from.getType(),
+        /*originalType=*/Type(), converter)[0];
+    repls.push_back(srcMat);
+  }
+
+  return repls;
 }
 
 void ConversionPatternRewriterImpl::replaceOp(
     Operation *op, SmallVector<SmallVector<Value>> &&newValues) {
-  assert(newValues.size() == op->getNumResults());
+  assert(newValues.size() == op->getNumResults() &&
+         "incorrect number of replacement values");
+
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed: materialize all IR changes immediately.
+    SmallVector<Value> repls = getReplacementValues(
+        *this, op->getResults(), newValues, currentTypeConverter);
+    // Update internal data structures, so that there are no dangling pointers
+    // to erased IR.
+    op->walk([&](Operation *op) {
+      erasedOps.insert(op);
+      ignoredOps.remove(op);
+      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+        unresolvedMaterializations.erase(castOp);
+        patternMaterializations.erase(castOp);
+      }
+      // The original op will be erased, so remove it from the set of
+      // unlegalized ops.
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->erase(op);
+    });
+    op->walk([&](Block *block) { erasedBlocks.insert(block); });
+    // Replace the op with the replacement values and notify the listener.
+    notifyingRewriter.replaceOp(op, repls);
+    return;
+  }
+
   assert(!ignoredOps.contains(op) && "operation was already replaced");
 
   // Check if replaced op is an unresolved materialization, i.e., an
@@ -1720,11 +1872,46 @@ void ConversionPatternRewriterImpl::replaceOp(
 
 void ConversionPatternRewriterImpl::replaceUsesOfBlockArgument(
     BlockArgument from, ValueRange to, const TypeConverter *converter) {
+  // Pattern rollback is not allowed: replace the block argument immediately
+  // instead of recording a rewrite and a mapping entry.
   if (!config.allowPatternRollback) {
     SmallVector<Value> toConv = llvm::to_vector(to);
     SmallVector<Value> repls =
         getReplacementValues(*this, from, {toConv}, converter);
     IRRewriter r(from.getContext());
     Value repl = repls.front();
+    // Nothing to do if no replacement value was produced.
     if (!repl)
       return;
 
     performReplaceBlockArg(r, from, repl);
     return;
   }
 
   appendRewrite<ReplaceBlockArgRewrite>(from.getOwner(), from, converter);
   mapping.map(from, to);
 }
 
 void ConversionPatternRewriterImpl::eraseBlock(Block *block) {
+  if (!config.allowPatternRollback) {
+    // Pattern rollback is not allowed: materialize all IR changes immediately.
+    // Update internal data structures, so that there are no dangling pointers
+    // to erased IR.
+    block->walk([&](Operation *op) {
+      erasedOps.insert(op);
+      ignoredOps.remove(op);
+      if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op)) {
+        unresolvedMaterializations.erase(castOp);
+        patternMaterializations.erase(castOp);
+      }
+      // The original op will be erased, so remove it from the set of
+      // unlegalized ops.
+      if (config.unlegalizedOps)
+        config.unlegalizedOps->erase(op);
+    });
+    block->walk([&](Block *block) { erasedBlocks.insert(block); });
+    // Erase the block and notify the listener.
+    notifyingRewriter.eraseBlock(block);
+    return;
+  }
+
   assert(!wasOpReplaced(block->getParentOp()) &&
          "attempting to erase a block within a replaced/erased op");
   appendRewrite<EraseBlockRewrite>(block);
@@ -1758,23 +1945,37 @@ void ConversionPatternRewriterImpl::notifyBlockInserted(
           logger.getOStream() << " (was detached)";
         logger.getOStream() << "\n";
       });
-  assert(!wasOpReplaced(newParentOp) &&
+
+  // In rollback mode, it is easier to misuse the API, so perform extra error
+  // checking.
+  assert(!(config.allowPatternRollback && wasOpReplaced(newParentOp)) &&
          "attempting to insert into a region within a replaced/erased op");
   (void)newParentOp;
 
+  // In "no rollback" mode, the listener is always notified immediately.
+  if (!config.allowPatternRollback && config.listener)
+    config.listener->notifyBlockInserted(block, previous, previousIt);
+
   patternInsertedBlocks.insert(block);
 
   if (wasDetached) {
     // If the block was detached, it is most likely a newly created block.
-    // TODO: If the same block is inserted multiple times from a detached 
state,
-    // the rollback mechanism may erase the same block multiple times. This is 
a
-    // bug in the rollback-based dialect conversion driver.
-    appendRewrite<CreateBlockRewrite>(block);
+    if (config.allowPatternRollback) {
+      // TODO: If the same block is inserted multiple times from a detached
+      // state, the rollback mechanism may erase the same block multiple times.
+      // This is a bug in the rollback-based dialect conversion driver.
+      appendRewrite<CreateBlockRewrite>(block);
+    } else {
+      // In "no rollback" mode, there is an extra data structure for tracking
+      // erased blocks that must be kept up to date.
+      erasedBlocks.erase(block);
+    }
     return;
   }
 
   // The block was moved from one place to another.
-  appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
+  if (config.allowPatternRollback)
+    appendRewrite<MoveBlockRewrite>(block, previous, previousIt);
 }
 
 void ConversionPatternRewriterImpl::inlineBlockBefore(Block *source,
@@ -1807,6 +2008,10 @@ ConversionPatternRewriter::ConversionPatternRewriter(
 
 ConversionPatternRewriter::~ConversionPatternRewriter() = default;
 
+/// Return the conversion configuration stored in the rewriter implementation.
+const ConversionConfig &ConversionPatternRewriter::getConfig() const {
+  return impl->config;
+}
+
 void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
   assert(op && newOp && "expected non-null op");
   replaceOp(op, newOp->getResults());
@@ -1950,7 +2155,7 @@ void ConversionPatternRewriter::inlineBlockBefore(Block 
*source, Block *dest,
   // a bit more efficient, so we try to do that when possible.
   bool fastPath = !impl->config.listener;
 
-  if (fastPath)
+  if (fastPath && impl->config.allowPatternRollback)
     impl->inlineBlockBefore(source, dest, before);
 
   // Replace all uses of block arguments.
@@ -1976,6 +2181,11 @@ void ConversionPatternRewriter::inlineBlockBefore(Block 
*source, Block *dest,
 }
 
 void ConversionPatternRewriter::startOpModification(Operation *op) {
+  if (!impl->config.allowPatternRollback) {
+    // Pattern rollback is not allowed: no extra bookkeeping is needed.
+    PatternRewriter::startOpModification(op);
+    return;
+  }
   assert(!impl->wasOpReplaced(op) &&
          "attempting to modify a replaced/erased op");
 #ifndef NDEBUG
@@ -1985,20 +2195,29 @@ void 
ConversionPatternRewriter::startOpModification(Operation *op) {
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
-  assert(!impl->wasOpReplaced(op) &&
-         "attempting to modify a replaced/erased op");
-  PatternRewriter::finalizeOpModification(op);
   impl->patternModifiedOps.insert(op);
+  // Pattern rollback is not allowed: finalize the modification right away and
+  // notify the configured listener (if any).
   if (!impl->config.allowPatternRollback) {
     PatternRewriter::finalizeOpModification(op);
     if (getConfig().listener)
       getConfig().listener->notifyOperationModified(op);
     return;
   }
 
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
 #ifndef NDEBUG
   assert(!impl->wasOpReplaced(op) &&
          "attempting to modify a replaced/erased op");
   assert(impl->pendingRootUpdates.erase(op) &&
          "operation did not have a pending in-place update");
 #endif
 }
 
 void ConversionPatternRewriter::cancelOpModification(Operation *op) {
+  if (!impl->config.allowPatternRollback) {
+    PatternRewriter::cancelOpModification(op);
+    return;
+  }
 #ifndef NDEBUG
   assert(impl->pendingRootUpdates.erase(op) &&
          "operation did not have a pending in-place update");
@@ -2355,6 +2574,37 @@ OperationLegalizer::legalizeWithFold(Operation *op,
   return success();
 }
 
+/// Report a fatal error indicating that newly produced or modified IR could
+/// not be legalized. The error message names the offending `pattern` and lists
+/// the names of the ops in `newOps` and `modifiedOps`, as well as the names of
+/// the parent ops of the blocks in `insertedBlocks` (or a placeholder for
+/// blocks without a parent op). The error is raised via
+/// `llvm::report_fatal_error`.
+static void
+reportNewIrLegalizationFatalError(const Pattern &pattern,
+                                  const SetVector<Operation *> &newOps,
+                                  const SetVector<Operation *> &modifiedOps,
+                                  const SetVector<Block *> &insertedBlocks) {
+  StringRef detachedBlockStr = "(detached block)";
+  // Names of the ops that were created by the pattern.
+  std::string newOpNames = llvm::join(
+      llvm::map_range(
+          newOps, [](Operation *op) { return op->getName().getStringRef(); }),
+      ", ");
+  // Names of the ops that were modified in place by the pattern.
+  std::string modifiedOpNames = llvm::join(
+      llvm::map_range(
+          modifiedOps,
+          [](Operation *op) { return op->getName().getStringRef(); }),
+      ", ");
+  // Blocks are identified by the name of their parent op, if there is one.
+  std::string insertedBlockNames = llvm::join(
+      llvm::map_range(insertedBlocks,
+                      [&](Block *block) {
+                        if (block->getParentOp())
+                          return block->getParentOp()->getName().getStringRef();
+                        return detachedBlockStr;
+                      }),
+      ", ");
+  llvm::report_fatal_error(
+      "pattern '" + pattern.getDebugName() +
+      "' produced IR that could not be legalized. " + "new ops: {" +
+      newOpNames + "}, " + "modified ops: {" + modifiedOpNames + "}, " +
+      "inserted block into ops: {" + insertedBlockNames + "}");
+}
+
 LogicalResult
 OperationLegalizer::legalizeWithPattern(Operation *op,
                                         ConversionPatternRewriter &rewriter) {
@@ -2398,17 +2648,35 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   RewriterState curState = rewriterImpl.getCurrentState();
   auto onFailure = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
-#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
     if (!rewriterImpl.config.allowPatternRollback) {
-      // Returning "failure" after modifying IR is not allowed.
+      // Erase all unresolved materializations.
+      for (auto op : rewriterImpl.patternMaterializations) {
+        rewriterImpl.unresolvedMaterializations.erase(op);
+        op.erase();
+      }
+      rewriterImpl.patternMaterializations.clear();
+#if 0
+      // Cheap pattern check that could have false positives. Can be enabled
+      // manually for debugging purposes. E.g., this check would report an API
+      // violation when an op is created and then erased in the same pattern.
+      if (!rewriterImpl.patternNewOps.empty() ||
+          !rewriterImpl.patternModifiedOps.empty() ||
+          !rewriterImpl.patternInsertedBlocks.empty()) {
+        llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
+                                 "' rollback of IR modifications requested");
+      }
+#endif
+#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+      // Expensive pattern check that can detect more API violations and has
+      // fewer false positives than the cheap check.
       if (checkOp) {
         OperationFingerPrint fingerPrintAfterPattern(checkOp);
         if (fingerPrintAfterPattern != *topLevelFingerPrint)
           llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
                                    "' returned failure but IR did change");
       }
-    }
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
+    }
     rewriterImpl.patternNewOps.clear();
     rewriterImpl.patternModifiedOps.clear();
     rewriterImpl.patternInsertedBlocks.clear();
@@ -2432,6 +2700,16 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
   // successfully applied.
   auto onSuccess = [&](const Pattern &pattern) {
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
+    if (!rewriterImpl.config.allowPatternRollback) {
+      // Eagerly erase unused materializations.
+      for (auto op : rewriterImpl.patternMaterializations) {
+        if (op->use_empty()) {
+          rewriterImpl.unresolvedMaterializations.erase(op);
+          op.erase();
+        }
+      }
+      rewriterImpl.patternMaterializations.clear();
+    }
     SetVector<Operation *> newOps = moveAndReset(rewriterImpl.patternNewOps);
     SetVector<Operation *> modifiedOps =
         moveAndReset(rewriterImpl.patternModifiedOps);
@@ -2442,8 +2720,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     appliedPatterns.erase(&pattern);
     if (failed(result)) {
       if (!rewriterImpl.config.allowPatternRollback)
-        llvm::report_fatal_error("pattern '" + pattern.getDebugName() +
-                                 "' produced IR that could not be legalized");
+        reportNewIrLegalizationFatalError(pattern, newOps, modifiedOps,
+                                          insertedBlocks);
       rewriterImpl.resetState(curState, pattern.getDebugName());
     }
     if (config.listener)
@@ -2522,6 +2800,9 @@ LogicalResult 
OperationLegalizer::legalizePatternBlockRewrites(
   // If the pattern moved or created any blocks, make sure the types of block
   // arguments get legalized.
   for (Block *block : insertedBlocks) {
+    if (impl.erasedBlocks.contains(block))
+      continue;
+
     // Only check blocks outside of the current operation.
     Operation *parentOp = block->getParentOp();
     if (!parentOp || parentOp == op || block->getNumArguments() == 0)
diff --git 
a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
 
b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 8f74976c59773..25a338df8d790 100644
--- 
a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ 
b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -6,6 +6,15 @@
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @main() {
   // This buffer is properly aligned. There should be no error.
   // CHECK-NOT: ^ memref is not aligned to 8
diff --git 
a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 26c731c921356..4c6a48d577a6c 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -5,6 +5,14 @@
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
   %cst = arith.constant 1.0 : f32
   memref.atomic_rmw addf %cst, %memref[%index] : (f32, memref<?xf32>) -> f32
diff --git 
a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
index 8b6308e9c1939..1ac10306395ad 100644
--- a/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/cast-runtime-verification.mlir
@@ -1,11 +1,20 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -test-cf-assert \
 // RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @cast_to_static_dim(%m: memref<?xf32>) -> memref<10xf32> {
   %0 = memref.cast %m : memref<?xf32> to memref<10xf32>
   return %0 : memref<10xf32>
diff --git 
a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
index 95b9db2832cee..be9417baf93df 100644
--- a/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/copy-runtime-verification.mlir
@@ -6,6 +6,15 @@
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 // Put memref.copy in a function, otherwise the memref.cast may fold.
 func.func @memcpy_helper(%src: memref<?xf32>, %dest: memref<?xf32>) {
   memref.copy %src, %dest : memref<?xf32> to memref<?xf32>
diff --git a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
index 2e3f271743c93..ef4af62459738 100644
--- a/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/dim-runtime-verification.mlir
@@ -6,6 +6,15 @@
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @main() {
   %c4 = arith.constant 4 : index
   %alloca = memref.alloca() : memref<1xf32>
diff --git 
a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
index b87e5bdf0970c..2e42648297875 100644
--- a/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/load-runtime-verification.mlir
@@ -1,12 +1,20 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -test-cf-assert \
 // RUN:     -expand-strided-metadata \
-// RUN:     -lower-affine \
+// RUN:     -test-cf-assert \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @load(%memref: memref<1xf32>, %index: index) {
     memref.load %memref[%index] :  memref<1xf32>
     return
diff --git 
a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index 12253fa3b5e83..dd000c6904bcb 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -5,6 +5,14 @@
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @store_dynamic(%memref: memref<?xf32>, %index: index) {
   %cst = arith.constant 1.0 : f32
   memref.store %cst, %memref[%index] :  memref<?xf32>
diff --git 
a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
index ec7e4085f2fa5..9fbe5bc60321e 100644
--- a/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/subview-runtime-verification.mlir
@@ -1,12 +1,22 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -test-cf-assert \
 // RUN:     -expand-strided-metadata \
 // RUN:     -lower-affine \
+// RUN:     -test-cf-assert \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -expand-strided-metadata \
+// RUN:     -lower-affine \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @subview(%memref: memref<1xf32>, %offset: index) {
     memref.subview %memref[%offset] [1] [1] : 
         memref<1xf32> to 
diff --git 
a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
index e4aab32d4a390..f37a6d6383c48 100644
--- a/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/cast-runtime-verification.mlir
@@ -8,6 +8,17 @@
 // RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func private @cast_to_static_dim(%t: tensor<?xf32>) -> tensor<10xf32> {
   %0 = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
   return %0 : tensor<10xf32>
diff --git a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
index c6d8f698b9433..e9e5c040c6488 100644
--- a/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/dim-runtime-verification.mlir
@@ -1,10 +1,20 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
-// RUN:     -one-shot-bufferize \
-// RUN:     -buffer-deallocation-pipeline \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline=private-function-dynamic-ownership \
 // RUN:     -test-cf-assert \
 // RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
-// RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
+// RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
 func.func @main() {
diff --git 
a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
index 8e3cab7be704d..73fcec4d7abcd 100644
--- a/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/Tensor/extract-runtime-verification.mlir
@@ -8,6 +8,17 @@
 // RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @extract(%tensor: tensor<1xf32>, %index: index) {
     tensor.extract %tensor[%index] :  tensor<1xf32>
     return
diff --git 
a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir 
b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
index 28f9be0fffe64..341a59e8b8102 100644
--- 
a/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
+++ 
b/mlir/test/Integration/Dialect/Tensor/extract_slice-runtime-verification.mlir
@@ -8,6 +8,17 @@
 // RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s
 
+// RUN: mlir-opt %s -generate-runtime-verification \
+// RUN:     -one-shot-bufferize="bufferize-function-boundaries" \
+// RUN:     -buffer-deallocation-pipeline=private-function-dynamic-ownership \
+// RUN:     -test-cf-assert \
+// RUN:     -convert-scf-to-cf \
+// RUN:     -convert-to-llvm="allow-pattern-rollback=0" \
+// RUN:     -reconcile-unrealized-casts | \
+// RUN: mlir-runner -e main -entry-point-result=void \
+// RUN:     -shared-libs=%tlir_runner_utils 2>&1 | \
+// RUN: FileCheck %s
+
 func.func @extract_slice(%tensor: tensor<1xf32>, %offset: index) {
     tensor.extract_slice %tensor[%offset] [1] [1] : tensor<1xf32> to 
tensor<1xf32>
     return
diff --git a/mlir/test/Transforms/test-legalizer.mlir 
b/mlir/test/Transforms/test-legalizer.mlir
index 5630d1540e4d5..9a04da7904863 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,9 +1,14 @@
-// RUN: mlir-opt -allow-unregistered-dialect -split-input-file 
-test-legalize-patterns -verify-diagnostics -profile-actions-to=- %s | 
FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file 
-test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics %s | 
FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file 
-test-legalize-patterns="allow-pattern-rollback=1" -verify-diagnostics 
-profile-actions-to=- %s | FileCheck %s --check-prefix=CHECK-PROFILER
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file 
-test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | 
FileCheck %s
+
+// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "B"
+// CHECK-PROFILER: "name": "apply-pattern", "cat": "PERF", "ph": "E"
+// CHECK-PROFILER: "name": "apply-conversion", "cat": "PERF", "ph": "E"
+// CHECK-PROFILER: "name": "pass-execution", "cat": "PERF", "ph": "E"
 
-//      CHECK: "name": "pass-execution", "cat": "PERF", "ph": "B"
-//      CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "B"
-//      CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "B"
-//      CHECK: "name": "apply-pattern", "cat": "PERF", "ph": "E"
 // Note: Listener notifications appear after the pattern application because
 // the conversion driver sends all notifications at the end of the conversion
 // in bulk.
@@ -11,8 +16,6 @@
 // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
 // CHECK-NEXT: notifyOperationModified: func.return
 // CHECK-NEXT: notifyOperationErased: test.illegal_op_a
-//      CHECK: "name": "apply-conversion", "cat": "PERF", "ph": "E"
-//      CHECK: "name": "pass-execution", "cat": "PERF", "ph": "E"
 // CHECK-LABEL: verifyDirectPattern
 func.func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -29,7 +32,9 @@ func.func @verifyDirectPattern() -> i32 {
 // CHECK-NEXT: notifyOperationErased: test.illegal_op_c
 // CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
 // CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
-// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+// Note: func.return is modified a second time when running in no-rollback
+//       mode.
+//      CHECK: notifyOperationErased: test.illegal_op_e
 
 // CHECK-LABEL: verifyLargerBenefit
 func.func @verifyLargerBenefit() -> i32 {
@@ -70,7 +75,7 @@ func.func @remap_call_1_to_1(%arg0: i64) {
 // CHECK:      notifyBlockInserted into func.func: was unlinked
 
 // Contents of the old block are moved to the new block.
-// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact 
position unknown
+// CHECK-NEXT: notifyOperationInserted: test.return
 
 // The old block is erased.
 // CHECK-NEXT: notifyBlockErased
@@ -409,8 +414,10 @@ func.func @test_remap_block_arg() {
 
 // CHECK-LABEL: func @test_multiple_1_to_n_replacement()
 //       CHECK:   %[[legal_op:.*]]:4 = "test.legal_op"() : () -> (f16, f16, 
f16, f16)
-//       CHECK:   %[[cast:.*]] = "test.cast"(%[[legal_op]]#0, %[[legal_op]]#1, 
%[[legal_op]]#2, %[[legal_op]]#3) : (f16, f16, f16, f16) -> f16
-//       CHECK:   "test.valid"(%[[cast]]) : (f16) -> ()
+// Note: There is a bug in the rollback-based conversion driver: it emits a
+// "test.cast" : (f16, f16, f16, f16) -> f16, when it should be emitting
+// three consecutive casts of (f16, f16) -> f16.
+//       CHECK:   "test.valid"(%{{.*}}) : (f16) -> ()
 func.func @test_multiple_1_to_n_replacement() {
   %0 = "test.multiple_1_to_n_replacement"() : () -> (f16)
   "test.invalid"(%0) : (f16) -> ()
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp 
b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 7150401bdbdce..189d5a1bb5fb6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1116,8 +1116,8 @@ struct TestNonRootReplacement : public RewritePattern {
     auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType);
     auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType);
 
-    rewriter.replaceOp(illegalOp, legalOp);
     rewriter.replaceOp(op, illegalOp);
+    rewriter.replaceOp(illegalOp, legalOp);
     return success();
   }
 };
@@ -1301,6 +1301,7 @@ class TestMultiple1ToNReplacement : public 
ConversionPattern {
     // Helper function that replaces the given op with a new op of the given
     // name and doubles each result (1 -> 2 replacement of each result).
     auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+      rewriter.setInsertionPointAfter(op);
       SmallVector<Type> types;
       for (Type t : op->getResultTypes()) {
         types.push_back(t);
@@ -1499,6 +1500,7 @@ struct TestLegalizePatternDriver
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
       ConversionConfig config;
+      config.allowPatternRollback = allowPatternRollback;
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
@@ -1520,6 +1522,7 @@ struct TestLegalizePatternDriver
       });
 
       ConversionConfig config;
+      config.allowPatternRollback = allowPatternRollback;
       DumpNotifications dumpNotifications;
       config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
@@ -1535,6 +1538,7 @@ struct TestLegalizePatternDriver
     // Analyze the convertible operations.
     DenseSet<Operation *> legalizedOps;
     ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
     config.legalizableOps = &legalizedOps;
     if (failed(applyAnalysisConversion(getOperation(), target,
                                        std::move(patterns), config)))
@@ -1555,6 +1559,10 @@ struct TestLegalizePatternDriver
           clEnumValN(ConversionMode::Full, "full", "Perform a full 
conversion"),
           clEnumValN(ConversionMode::Partial, "partial",
                      "Perform a partial conversion"))};
+
+  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+                                    llvm::cl::desc("Allow pattern rollback"),
+                                    llvm::cl::init(true)};
 };
 } // namespace
 

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to