Author: Mehdi Amini
Date: 2025-09-12T14:50:53+01:00
New Revision: db67d6aa9438f09f72c8574264e1d6c98bfc2d5f

URL: https://github.com/llvm/llvm-project/commit/db67d6aa9438f09f72c8574264e1d6c98bfc2d5f
DIFF: https://github.com/llvm/llvm-project/commit/db67d6aa9438f09f72c8574264e1d6c98bfc2d5f.diff

LOG: Revert "[mlir][Transforms] Fix crash in `reconcile-unrealized-casts` (#158067)"

This reverts commit 03e3ce82b926a4c138e6e0bacfcd1d5572c3e380.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
    mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
    mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
    mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f8caae3ce9995..a096f82a4cfd8 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1428,9 +1428,6 @@ struct ConversionConfig {
 ///
 /// In the above example, %0 can be used instead of %3 and all cast ops are
 /// folded away.
-void reconcileUnrealizedCasts(
-    const DenseSet<UnrealizedConversionCastOp> &castOps,
-    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);
 void reconcileUnrealizedCasts(
     ArrayRef<UnrealizedConversionCastOp> castOps,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps = nullptr);

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index d53e1e78f2027..df9700f11200f 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3100,7 +3100,6 @@ unsigned OperationLegalizer::applyCostModelToPatterns(
 //===----------------------------------------------------------------------===//
 // OperationConverter
 //===----------------------------------------------------------------------===//
-
 namespace {
 enum OpConversionMode {
   /// In this mode, the conversion will ignore failed conversions to allow
@@ -3118,13 +3117,6 @@ enum OpConversionMode {
 } // namespace
 
 namespace mlir {
-
-// Predeclaration only.
-static void reconcileUnrealizedCasts(
-    const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
-        &castOps,
-    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps);
-
 // This class converts operations to a given conversion target via a set of
 // rewrite patterns. The conversion behaves differently depending on the
 // conversion mode.
@@ -3272,13 +3264,18 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   // After a successful conversion, apply rewrites.
   rewriterImpl.applyRewrites();
 
-  // Reconcile all UnrealizedConversionCastOps that were inserted by the
-  // dialect conversion frameworks. (Not the ones that were inserted by
-  // patterns.)
+  // Gather all unresolved materializations.
+  SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
       &materializations = rewriterImpl.unresolvedMaterializations;
+  for (auto it : materializations)
+    allCastOps.push_back(it.first);
+
+  // Reconcile all UnrealizedConversionCastOps that were inserted by the
+  // dialect conversion frameworks. (Not the one that were inserted by
+  // patterns.)
   SmallVector<UnrealizedConversionCastOp> remainingCastOps;
-  reconcileUnrealizedCasts(materializations, &remainingCastOps);
+  reconcileUnrealizedCasts(allCastOps, &remainingCastOps);
 
   // Drop markers.
   for (UnrealizedConversionCastOp castOp : remainingCastOps)
@@ -3306,19 +3303,20 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 // Reconcile Unrealized Casts
 //===----------------------------------------------------------------------===//
 
-/// Try to reconcile all given UnrealizedConversionCastOps and store the
-/// left-over ops in `remainingCastOps` (if provided). See documentation in
-/// DialectConversion.h for more details.
-/// The `isCastOpOfInterestFn` is used to filter the cast ops to proceed: the
-/// algorithm may visit an operand (or user) which is a cast op, but will not
-/// try to reconcile it if not in the filtered set.
-template <typename RangeT>
-static void reconcileUnrealizedCastsImpl(
-    RangeT castOps,
-    function_ref<bool(UnrealizedConversionCastOp)> isCastOpOfInterestFn,
+void mlir::reconcileUnrealizedCasts(
+    ArrayRef<UnrealizedConversionCastOp> castOps,
     SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
-  // A worklist of cast ops to process.
   SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
+  // This set is maintained only if `remainingCastOps` is provided.
+  DenseSet<Operation *> erasedOps;
+
+  // Helper function that adds all operands to the worklist that are an
+  // unrealized_conversion_cast op result.
+  auto enqueueOperands = [&](UnrealizedConversionCastOp castOp) {
+    for (Value v : castOp.getInputs())
+      if (auto inputCastOp = v.getDefiningOp<UnrealizedConversionCastOp>())
+        worklist.insert(inputCastOp);
+  };
 
   // Helper function that return the unrealized_conversion_cast op that
   // defines all inputs of the given op (in the same order). Return "nullptr"
@@ -3339,110 +3337,39 @@ static void reconcileUnrealizedCastsImpl(
   // Process ops in the worklist bottom-to-top.
   while (!worklist.empty()) {
     UnrealizedConversionCastOp castOp = worklist.pop_back_val();
+    if (castOp->use_empty()) {
+      // DCE: If the op has no users, erase it. Add the operands to the
+      // worklist to find additional DCE opportunities.
+      enqueueOperands(castOp);
+      if (remainingCastOps)
+        erasedOps.insert(castOp.getOperation());
+      castOp->erase();
+      continue;
+    }
 
     // Traverse the chain of input cast ops to see if an op with the same
     // input types can be found.
     UnrealizedConversionCastOp nextCast = castOp;
     while (nextCast) {
       if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
-        if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
-              return v.getDefiningOp() == castOp;
-            })) {
-          // Ran into a cycle.
-          break;
-        }
-
         // Found a cast where the input types match the output types of the
-        // matched op. We can directly use those inputs.
+        // matched op. We can directly use those inputs and the matched op can
+        // be removed.
+        enqueueOperands(castOp);
         castOp.replaceAllUsesWith(nextCast.getInputs());
+        if (remainingCastOps)
+          erasedOps.insert(castOp.getOperation());
+        castOp->erase();
         break;
       }
       nextCast = getInputCast(nextCast);
     }
   }
 
-  // A set of all alive cast ops. I.e., ops whose results are (transitively)
-  // used by an op that is not a cast op.
-  DenseSet<Operation *> liveOps;
-
-  // Helper function that marks the given op and transitively reachable input
-  // cast ops as alive.
-  auto markOpLive = [&](Operation *rootOp) {
-    SmallVector<Operation *> worklist;
-    worklist.push_back(rootOp);
-    while (!worklist.empty()) {
-      Operation *op = worklist.pop_back_val();
-      if (liveOps.insert(op).second) {
-        // Successfully inserted: process reachable input cast ops.
-        for (Value v : op->getOperands())
-          if (auto castOp = v.getDefiningOp<UnrealizedConversionCastOp>())
-            if (isCastOpOfInterestFn(castOp))
-              worklist.push_back(castOp);
-      }
-    }
-  };
-
-  // Find all alive cast ops.
-  for (UnrealizedConversionCastOp op : castOps) {
-    // The op may have been marked live already as being an operand of another
-    // live cast op.
-    if (liveOps.contains(op.getOperation()))
-      continue;
-    // If any of the users is not a cast op, mark the current op (and its
-    // input ops) as live.
-    if (llvm::any_of(op->getUsers(), [&](Operation *user) {
-          auto castOp = dyn_cast<UnrealizedConversionCastOp>(user);
-          return !castOp || !isCastOpOfInterestFn(castOp);
-        }))
-      markOpLive(op);
-  }
-
-  // Erase all dead cast ops.
-  for (UnrealizedConversionCastOp op : castOps) {
-    if (liveOps.contains(op)) {
-      // Op is alive and was not erased. Add it to the remaining cast ops.
-      if (remainingCastOps)
+  if (remainingCastOps)
+    for (UnrealizedConversionCastOp op : castOps)
+      if (!erasedOps.contains(op.getOperation()))
         remainingCastOps->push_back(op);
-      continue;
-    }
-
-    // Op is dead. Erase it.
-    op->dropAllUses();
-    op->erase();
-  }
-}
-
-void mlir::reconcileUnrealizedCasts(
-    ArrayRef<UnrealizedConversionCastOp> castOps,
-    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
-  // Set of all cast ops for faster lookups.
-  DenseSet<UnrealizedConversionCastOp> castOpSet;
-  for (UnrealizedConversionCastOp op : castOps)
-    castOpSet.insert(op);
-  reconcileUnrealizedCasts(castOpSet, remainingCastOps);
-}
-
-void mlir::reconcileUnrealizedCasts(
-    const DenseSet<UnrealizedConversionCastOp> &castOps,
-    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
-  reconcileUnrealizedCastsImpl(
-      llvm::make_range(castOps.begin(), castOps.end()),
-      [&](UnrealizedConversionCastOp castOp) {
-        return castOps.contains(castOp);
-      },
-      remainingCastOps);
-}
-
-static void mlir::reconcileUnrealizedCasts(
-    const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationInfo>
-        &castOps,
-    SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
-  reconcileUnrealizedCastsImpl(
-      castOps.keys(),
-      [&](UnrealizedConversionCastOp castOp) {
-        return castOps.contains(castOp);
-      },
-      remainingCastOps);
 }
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index ac5ca321c066f..3573114f5e038 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,53 +194,3 @@ func.func @emptyCast() -> index {
     %0 = builtin.unrealized_conversion_cast to index
     return %0 : index
 }
-
-// -----
-
-// CHECK-LABEL: test.graph_region
-//  CHECK-NEXT:   "test.return"() : () -> ()
-test.graph_region {
-  %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
-  %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
-  %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
-  "test.return"() : () -> ()
-}
-
-// -----
-
-// CHECK-LABEL: test.graph_region
-//  CHECK-NEXT:   %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
-//  CHECK-NEXT:   %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
-//  CHECK-NEXT:   %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
-//  CHECK-NEXT:   "test.user"(%[[cast2]]) : (i32) -> ()
-//  CHECK-NEXT:   "test.return"() : () -> ()
-test.graph_region {
-  %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
-  %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
-  %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
-  "test.user"(%2) : (i32) -> ()
-  "test.return"() : () -> ()
-}
-
-// -----
-
-// CHECK-LABEL: test.graph_region
-//  CHECK-NEXT:   "test.return"() : () -> ()
-test.graph_region {
-  %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
-  "test.return"() : () -> ()
-}
-
-// -----
-
-// CHECK-LABEL: test.graph_region
-//  CHECK-NEXT:   %[[c0:.*]] = arith.constant
-//  CHECK-NEXT:   %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
-//  CHECK-NEXT:   "test.user"(%[[cast]]#0) : (i32) -> ()
-//  CHECK-NEXT:   "test.return"() : () -> ()
-test.graph_region {
-  %cst = arith.constant 0 : i32
-  %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
-  "test.user"(%0) : (i32) -> ()
-  "test.return"() : () -> ()
-}

diff  --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 01a826a638606..25a338df8d790 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,8 +1,7 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -expand-strided-metadata \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s

diff  --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 1144a7caf36e8..4c6a48d577a6c 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,7 +1,6 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s

diff  --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index 82e63805cd027..dd000c6904bcb 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,7 +1,6 @@
 // RUN: mlir-opt %s -generate-runtime-verification \
 // RUN:     -test-cf-assert \
-// RUN:     -convert-to-llvm \
-// RUN:     -reconcile-unrealized-casts | \
+// RUN:     -convert-to-llvm | \
 // RUN: mlir-runner -e main -entry-point-result=void \
 // RUN:     -shared-libs=%mlir_runner_utils 2>&1 | \
 // RUN: FileCheck %s


        
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to