https://github.com/matthias-springer created 
https://github.com/llvm/llvm-project/pull/136490

This commit adds a new flag to `ConversionConfig` to disallow the rollback of 
IR modifications. This commit is in preparation for the One-Shot Dialect 
Conversion refactoring, which will remove the ability to roll back IR 
modifications from the conversion driver.

RFC: 
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083

By default, this flag is set to "true", i.e., the rollback of IR modifications 
is allowed. When set to "false", the conversion driver will report a fatal LLVM 
error when an IR rollback is requested. The name of the pattern that triggered 
the rollback is included in the error message. Moreover, the original IR is no 
longer restored after a failed conversion.

Example:
```
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 
offset :11:8: error: pattern '(anonymous namespace)::CmpFOpNanKernelPattern' 
produced IR that could not be legalized
  %0 = arith.cmpf ord, %arg0, %arg1 fastmath<fast> : f32
       ^
within split at llvm-project/mlir/test/Conversion/ArithToSPIRV/fast-math.mlir:1 
offset :11:8: note: see current operation: %1 = "arith.cmpf"(%arg0, %arg1) 
<{fastmath = #arith.fastmath<fast>, predicate = 7 : i64}> : (f32, f32) -> i1
pattern '(anonymous namespace)::CmpFOpNanKernelPattern' rollback of IR 
modifications requested
UNREACHABLE executed at 
llvm-project/mlir/lib/Transforms/Utils/DialectConversion.cpp:1231!
```

The majority of patterns in MLIR have already been updated such that they do 
not trigger any rollbacks, but a few SPIRV patterns remain. More information in 
the RFC.

Depends on #136489.


>From da7b88810b00eacf65884dedec09d3e05e24f8db Mon Sep 17 00:00:00 2001
From: Matthias Springer <msprin...@nvidia.com>
Date: Sun, 20 Apr 2025 13:39:20 +0200
Subject: [PATCH] no rollback flag

---
 .../mlir/Transforms/DialectConversion.h       | 20 +++++++
 .../Transforms/Utils/DialectConversion.cpp    | 57 ++++++++++++++-----
 2 files changed, 63 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h 
b/mlir/include/mlir/Transforms/DialectConversion.h
index b6ab252456e70..b65b3ea971f91 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1219,6 +1219,26 @@ struct ConversionConfig {
   /// materializations and instead inserts "builtin.unrealized_conversion_cast"
   /// ops to ensure that the resulting IR is valid.
   bool buildMaterializations = true;
+
+  /// If set to "true", pattern rollback is allowed. The conversion driver
+  /// rolls back IR modifications in the following situations.
+  ///
+  /// 1. Pattern implementation returns "failure" after modifying IR.
+  /// 2. Pattern produces IR (in-place modification or new IR) that is illegal
+  ///    and cannot be legalized by subsequent foldings / pattern applications.
+  ///
+  /// If set to "false", the conversion driver will produce an LLVM fatal error
+  /// instead of rolling back IR modifications. Moreover, in case of a failed
+  /// conversion, the original IR is not restored. The resulting IR may be a
+  /// mix of original and rewritten IR. (Same as a failed greedy pattern
+  /// rewrite.)
+  ///
+  /// Note: This flag was added in preparation of the One-Shot Dialect
+  /// Conversion refactoring, which will remove the ability to roll back IR
+  /// modifications from the conversion driver. Use this flag to ensure that
+  /// your patterns do not trigger any IR rollbacks. For details, see
+  /// 
https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
+  bool allowPatternRollback = true;
 };
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp 
b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 63225c6bbee7c..7a9da13427f91 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -861,8 +861,10 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   /// conversion process succeeds.
   void applyRewrites();
 
-  /// Reset the state of the rewriter to a previously saved point.
-  void resetState(RewriterState state);
+  /// Reset the state of the rewriter to a previously saved point. Optionally,
+  /// the name of the pattern that triggered the rollback can be specified for
+  /// debugging purposes.
+  void resetState(RewriterState state, StringRef patternName = "");
 
   /// Append a rewrite. Rewrites are committed upon success and rolled back 
upon
   /// failure.
@@ -873,8 +875,9 @@ struct ConversionPatternRewriterImpl : public 
RewriterBase::Listener {
   }
 
   /// Undo the rewrites (motions, splits) one by one in reverse order until
-  /// "numRewritesToKeep" rewrites remains.
-  void undoRewrites(unsigned numRewritesToKeep = 0);
+  /// "numRewritesToKeep" rewrites remain. Optionally, the name of the pattern
+  /// that triggered the rollback can be specified for debugging purposes.
+  void undoRewrites(unsigned numRewritesToKeep = 0, StringRef patternName = 
"");
 
   /// Remap the given values to those with potentially different types. Returns
   /// success if the values could be remapped, failure otherwise. 
`valueDiagTag`
@@ -1204,9 +1207,10 @@ RewriterState 
ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(rewrites.size(), ignoredOps.size(), replacedOps.size());
 }
 
-void ConversionPatternRewriterImpl::resetState(RewriterState state) {
+void ConversionPatternRewriterImpl::resetState(RewriterState state,
+                                               StringRef patternName) {
   // Undo any rewrites.
-  undoRewrites(state.numRewrites);
+  undoRewrites(state.numRewrites, patternName);
 
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
@@ -1216,10 +1220,19 @@ void 
ConversionPatternRewriterImpl::resetState(RewriterState state) {
     replacedOps.pop_back();
 }
 
-void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
+void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
+                                                 StringRef patternName) {
   for (auto &rewrite :
-       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep)))
+       llvm::reverse(llvm::drop_begin(rewrites, numRewritesToKeep))) {
+    if (!config.allowPatternRollback &&
+        !isa<UnresolvedMaterializationRewrite>(rewrite)) {
+      // Unresolved materializations can always be rolled back (erased).
+      std::string errorMessage = "pattern '" + std::string(patternName) +
+                                 "' rollback of IR modifications requested";
+      llvm_unreachable(errorMessage.c_str());
+    }
     rewrite->rollback();
+  }
   rewrites.resize(numRewritesToKeep);
 }
 
@@ -2158,7 +2171,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     });
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, failure());
-    rewriterImpl.resetState(curState);
+    rewriterImpl.resetState(curState, pattern.getDebugName());
     appliedPatterns.erase(&pattern);
   };
 
@@ -2168,8 +2181,13 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
     assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
     auto result = legalizePatternResult(op, pattern, rewriter, curState);
     appliedPatterns.erase(&pattern);
-    if (failed(result))
-      rewriterImpl.resetState(curState);
+    if (failed(result)) {
+      if (!rewriterImpl.config.allowPatternRollback)
+        op->emitError("pattern '")
+            << pattern.getDebugName()
+            << "' produced IR that could not be legalized";
+      rewriterImpl.resetState(curState, pattern.getDebugName());
+    }
     if (config.listener)
       config.listener->notifyPatternEnd(pattern, result);
     return result;
@@ -2674,9 +2692,20 @@ LogicalResult 
OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
   ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
 
-  for (auto *op : toConvert)
-    if (failed(convert(rewriter, op)))
-      return rewriterImpl.undoRewrites(), failure();
+  for (auto *op : toConvert) {
+    if (failed(convert(rewriter, op))) {
+      // Dialect conversion failed.
+      if (rewriterImpl.config.allowPatternRollback) {
+        // Rollback is allowed: restore the original IR.
+        rewriterImpl.undoRewrites();
+      } else {
+        // Rollback is not allowed: apply all modifications that have been
+        // performed so far.
+        rewriterImpl.applyRewrites();
+      }
+      return failure();
+    }
+  }
 
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to