https://github.com/matthias-springer updated 
https://github.com/llvm/llvm-project/pull/173505

>From 529a15981141a35691f5cddaa326b15d8da80aa8 Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Thu, 25 Dec 2025 12:55:26 +0000
Subject: [PATCH 1/2] [mlir][SCF] Fold unused `index_switch` results

---
 mlir/lib/Dialect/SCF/IR/SCF.cpp         | 52 ++++++++++++++++++++++++-
 mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++
 2 files changed, 82 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 652414f6cbe54..0a123112cf68f 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -4711,9 +4711,59 @@ struct FoldConstantCase : 
OpRewritePattern<scf::IndexSwitchOp> {
   }
 };
 
+/// Canonicalization pattern that folds away dead (unused) results of
+/// "scf.index_switch" ops.
+struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> {
+  using OpRewritePattern<IndexSwitchOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(IndexSwitchOp op,
+                                PatternRewriter &rewriter) const override {
+    // Mark results without uses as dead; collect the types of the others.
+    BitVector deadResults(op.getNumResults(), false);
+    SmallVector<Type> newResultTypes;
+    for (auto [idx, result] : llvm::enumerate(op.getResults())) {
+      if (!result.use_empty()) {
+        newResultTypes.push_back(result.getType());
+      } else {
+        deadResults[idx] = true;
+      }
+    }
+    if (!deadResults.any())
+      return rewriter.notifyMatchFailure(op, "no dead results to fold");
+
+    // Create a new op with only the live result types and move regions over.
+    auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes,
+                                       op.getArg(), op.getCases(),
+                                       op.getCaseRegions().size());
+    auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) {
+      rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin());
+      // Drop the yield operands that correspond to dead results.
+      Operation *terminator = newRegion.front().getTerminator();
+      assert(isa<YieldOp>(terminator) && "expected yield op");
+      rewriter.modifyOpInPlace(
+          terminator, [&]() { terminator->eraseOperands(deadResults); });
+    };
+    for (auto [oldRegion, newRegion] :
+         llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions()))
+      inlineCaseRegion(oldRegion, newRegion);
+    inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion());
+
+    // Replace op; dead results have no uses and map to a null Value.
+    SmallVector<Value> newResults(op.getNumResults(), Value());
+    unsigned nextNewResult = 0;
+    for (unsigned idx = 0; idx < op.getNumResults(); ++idx) {
+      if (deadResults[idx])
+        continue;
+      newResults[idx] = newOp.getResult(nextNewResult++);
+    }
+    rewriter.replaceOp(op, newResults);
+    return success();
+  }
+};
+
 void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
-  results.add<FoldConstantCase>(context);
+  results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context);
 }
 
 
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir 
b/mlir/test/Dialect/SCF/canonicalize.mlir
index ac590fc0c47b9..d5d0aee3bbe25 100644
--- a/mlir/test/Dialect/SCF/canonicalize.mlir
+++ b/mlir/test/Dialect/SCF/canonicalize.mlir
@@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0()  {
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @dead_index_switch_result(
+//  CHECK-SAME:     %[[arg0:.*]]: index
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10
+//   CHECK-DAG:   %[[c11:.*]] = arith.constant 11
+//       CHECK:   %[[switch:.*]] = scf.index_switch %[[arg0]] -> index
+//       CHECK:   case 1 {
+//       CHECK:     memref.store %[[c10]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   } 
+//       CHECK:   default {
+//       CHECK:     memref.store %[[c11]]
+//       CHECK:     scf.yield %[[arg0]] : index
+//       CHECK:   }
+//       CHECK:   return %[[switch]]
+func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> 
index {
+  %non_live, %live = scf.index_switch %arg0 -> i32, index
+  case 1 {
+    %c10 = arith.constant 10 : i32
+    memref.store %c10, %arg1[] : memref<i32>
+    scf.yield %c10, %arg0 : i32, index
+  }
+  default {
+    %c11 = arith.constant 11 : i32
+    memref.store %c11, %arg1[] : memref<i32>
+    scf.yield %c11, %arg0 : i32, index
+  }
+  return %live : index
+}

>From db702280cd8b6e4eaf0ea46bab9ea3473ad4c50e Mon Sep 17 00:00:00 2001
From: Matthias Springer <[email protected]>
Date: Wed, 24 Dec 2025 13:26:00 +0000
Subject: [PATCH 2/2] tmp commit

simple test working

draft: do not erase IR, just replace uses
---
 mlir/lib/Transforms/RemoveDeadValues.cpp     | 429 ++++++-------------
 mlir/test/Transforms/remove-dead-values.mlir | 128 ++++--
 2 files changed, 221 insertions(+), 336 deletions(-)

diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp 
b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..421fdb43a4554 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -96,6 +96,7 @@ struct OperandsToCleanup {
   BitVector nonLive;
   Operation *callee =
       nullptr; // Optional: For CallOpInterface ops, stores the callee function
+  bool replaceWithPoison = false;
 };
 
 struct BlockArgsToCleanup {
@@ -199,9 +200,9 @@ static void collectNonLiveValues(DenseSet<Value> 
&nonLiveSet, ValueRange range,
   }
 }
 
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+/// Erase the i-th result of `op` iff toErase[i] is 1.
+static void eraseResults(RewriterBase &rewriter, Operation *op,
+                         BitVector toErase) {
   assert(op->getNumResults() == toErase.size() &&
          "expected the number of results in `op` and the size of `toErase` to "
          "be the same");
@@ -210,7 +211,6 @@ static void dropUsesAndEraseResults(Operation *op, 
BitVector toErase) {
   for (OpResult result : op->getResults())
     if (!toErase[result.getResultNumber()])
       newResultTypes.push_back(result.getType());
-  IRRewriter rewriter(op);
   rewriter.setInsertionPointAfter(op);
   OperationState state(op->getLoc(), op->getName().getStringRef(),
                        op->getOperands(), newResultTypes, op->getAttrs());
@@ -226,14 +226,12 @@ static void dropUsesAndEraseResults(Operation *op, 
BitVector toErase) {
   unsigned indexOfNextNewCallOpResultToReplace = 0;
   for (auto [index, result] : llvm::enumerate(op->getResults())) {
     assert(result && "expected result to be non-null");
-    if (toErase[index]) {
-      result.dropAllUses();
-    } else {
+    if (!toErase[index]) {
       result.replaceAllUsesWith(
           newOp->getResult(indexOfNextNewCallOpResultToReplace++));
     }
   }
-  op->erase();
+  rewriter.eraseOp(op);
 }
 
 /// Convert a list of `Operand`s to a list of `OpOperand`s.
@@ -448,277 +446,74 @@ static void 
processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
     return;
   }
 
-  // Mark live results of `regionBranchOp` in `liveResults`.
-  auto markLiveResults = [&](BitVector &liveResults) {
-    liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
-  };
-
-  // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
-  auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
-    for (Region &region : regionBranchOp->getRegions()) {
-      if (region.empty())
-        continue;
-      SmallVector<Value> arguments(region.front().getArguments());
-      BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
-      liveArgs[&region] = regionLiveArgs;
+  // Compute values that we definitely want to keep.
+  DenseSet<Value> valuesToKeep;
+  for (Value result : regionBranchOp->getResults()) {
+    if (hasLive(result, nonLiveSet, la))
+      valuesToKeep.insert(result);
+  }
+  for (Region &region : regionBranchOp->getRegions()) {
+    if (region.empty())
+      continue;
+    for (Value arg : region.front().getArguments()) {
+      if (hasLive(arg, nonLiveSet, la))
+        valuesToKeep.insert(arg);
     }
-  };
+  }
 
-  // Return the successors of `region` if the latter is not null. Else return
-  // the successors of `regionBranchOp`.
-  auto getSuccessors = [&](RegionBranchPoint point) {
+  // Mapping from operands to forwarded successor inputs. An operand can be
+  // forwarded to multiple successors.
+  DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+  auto helper = [&](RegionBranchPoint point) {
     SmallVector<RegionSuccessor> successors;
     regionBranchOp.getSuccessorRegions(point, successors);
-    return successors;
-  };
-
-  // Return the operands of `terminator` that are forwarded to `successor` if
-  // the former is not null. Else return the operands of `regionBranchOp`
-  // forwarded to `successor`.
-  auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
-                                    Operation *terminator = nullptr) {
-    OperandRange operands =
-        terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
-                         .getSuccessorOperands(successor)
-                   : regionBranchOp.getEntrySuccessorOperands(successor);
-    SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
-    return opOperands;
-  };
-
-  // Mark the non-forwarded operands of `regionBranchOp` in
-  // `nonForwardedOperands`.
-  auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
-    nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
-    for (const RegionSuccessor &successor :
-         getSuccessors(RegionBranchPoint::parent())) {
-      for (OpOperand *opOperand : getForwardedOpOperands(successor))
-        nonForwardedOperands.reset(opOperand->getOperandNumber());
+    for (const RegionSuccessor &successor : successors) {
+      // Handle branch from point --> successor.
+      ValueRange argsOrResults = successor.getSuccessorInputs();
+      OperandRange operands =
+        point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) 
: 
cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull())
+                       .getSuccessorOperands(successor);
+      assert(argsOrResults.size() == operands.size() && "expected the same 
number of successor inputs as forwarded operands");
+
+      for (auto [opOperand, input] :
+           llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) {
+        operandToSuccessorInputs[opOperand].push_back(input);
+      }
     }
   };
 
-  // Mark the non-forwarded terminator operands of the various regions of
-  // `regionBranchOp` in `nonForwardedRets`.
-  auto markNonForwardedReturnValues =
-      [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          // TODO: this isn't correct in face of multiple terminators.
-          Operation *terminator = region.front().getTerminator();
-          nonForwardedRets[terminator] =
-              BitVector(terminator->getNumOperands(), true);
-          for (const RegionSuccessor &successor :
-               getSuccessors(RegionBranchPoint(
-                   cast<RegionBranchTerminatorOpInterface>(terminator)))) {
-            for (OpOperand *opOperand :
-                 getForwardedOpOperands(successor, terminator))
-              
nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
-          }
-        }
-      };
-
-  // Update `valuesToKeep` (which is expected to correspond to operands or
-  // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
-  // `region`. When `valuesToKeep` correspond to operands, `region` is null.
-  // Else, `region` is the parent region of the terminator.
-  auto updateOperandsOrTerminatorOperandsToKeep =
-      [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
-          DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) 
{
-        Operation *terminator =
-            region ? region->front().getTerminator() : nullptr;
-        RegionBranchPoint point =
-            terminator
-                ? RegionBranchPoint(
-                      cast<RegionBranchTerminatorOpInterface>(terminator))
-                : RegionBranchPoint::parent();
-
-        for (const RegionSuccessor &successor : getSuccessors(point)) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(successor, terminator),
-                         successor.getSuccessorInputs())) {
-            size_t operandNum = opOperand->getOperandNumber();
-            bool updateBasedOn =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            valuesToKeep[operandNum] = valuesToKeep[operandNum] | 
updateBasedOn;
-          }
-        }
-      };
-
-  // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
-  // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if 
a
-  // value is modified, else, false.
-  auto recomputeResultsAndArgsToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
-          bool &resultsOrArgsToKeepChanged) {
-        resultsOrArgsToKeepChanged = false;
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on 
`operandsToKeep`.
-        for (const RegionSuccessor &successor :
-             getSuccessors(RegionBranchPoint::parent())) {
-          Region *successorRegion = successor.getSuccessor();
-          for (auto [opOperand, input] :
-               llvm::zip(getForwardedOpOperands(successor),
-                         successor.getSuccessorInputs())) {
-            bool recomputeBasedOn =
-                operandsToKeep[opOperand->getOperandNumber()];
-            bool toRecompute =
-                successorRegion
-                    ? argsToKeep[successorRegion]
-                                [cast<BlockArgument>(input).getArgNumber()]
-                    : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-            if (!toRecompute && recomputeBasedOn)
-              resultsOrArgsToKeepChanged = true;
-            if (successorRegion) {
-              argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                              .getArgNumber()] =
-                  argsToKeep[successorRegion]
-                            [cast<BlockArgument>(input).getArgNumber()] |
-                  recomputeBasedOn;
-            } else {
-              resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                  resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                  recomputeBasedOn;
-            }
-          }
-        }
-
-        // Recompute `resultsToKeep` and `argsToKeep` based on
-        // `terminatorOperandsToKeep`.
-        for (Region &region : regionBranchOp->getRegions()) {
-          if (region.empty())
-            continue;
-          Operation *terminator = region.front().getTerminator();
-          for (const RegionSuccessor &successor :
-               getSuccessors(RegionBranchPoint(
-                   cast<RegionBranchTerminatorOpInterface>(terminator)))) {
-            Region *successorRegion = successor.getSuccessor();
-            for (auto [opOperand, input] :
-                 llvm::zip(getForwardedOpOperands(successor, terminator),
-                           successor.getSuccessorInputs())) {
-              bool recomputeBasedOn =
-                  terminatorOperandsToKeep[region.back().getTerminator()]
-                                          [opOperand->getOperandNumber()];
-              bool toRecompute =
-                  successorRegion
-                      ? argsToKeep[successorRegion]
-                                  [cast<BlockArgument>(input).getArgNumber()]
-                      : resultsToKeep[cast<OpResult>(input).getResultNumber()];
-              if (!toRecompute && recomputeBasedOn)
-                resultsOrArgsToKeepChanged = true;
-              if (successorRegion) {
-                argsToKeep[successorRegion][cast<BlockArgument>(input)
-                                                .getArgNumber()] =
-                    argsToKeep[successorRegion]
-                              [cast<BlockArgument>(input).getArgNumber()] |
-                    recomputeBasedOn;
-              } else {
-                resultsToKeep[cast<OpResult>(input).getResultNumber()] =
-                    resultsToKeep[cast<OpResult>(input).getResultNumber()] |
-                    recomputeBasedOn;
-              }
-            }
-          }
-        }
-      };
-
-  // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
-  // `operandsToKeep`, and `terminatorOperandsToKeep`.
-  auto markValuesToKeep =
-      [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
-          BitVector &operandsToKeep,
-          DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
-        bool resultsOrArgsToKeepChanged = true;
-        // We keep updating and recomputing the values until we reach a point
-        // where they stop changing.
-        while (resultsOrArgsToKeepChanged) {
-          // Update the operands that need to be kept.
-          updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
-                                                   resultsToKeep, argsToKeep);
-
-          // Update the terminator operands that need to be kept.
-          for (Region &region : regionBranchOp->getRegions()) {
-            if (region.empty())
-              continue;
-            updateOperandsOrTerminatorOperandsToKeep(
-                terminatorOperandsToKeep[region.back().getTerminator()],
-                resultsToKeep, argsToKeep, &region);
-          }
-
-          // Recompute the results and arguments that need to be kept.
-          recomputeResultsAndArgsToKeep(
-              resultsToKeep, argsToKeep, operandsToKeep,
-              terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
-        }
-      };
-
-  // Scenario 2.
-  // At this point, we know that every non-forwarded operand of 
`regionBranchOp`
-  // is live.
-
-  // Stores the results of `regionBranchOp` that we want to keep.
-  BitVector resultsToKeep;
-  // Stores the mapping from regions of `regionBranchOp` to their arguments 
that
-  // we want to keep.
-  DenseMap<Region *, BitVector> argsToKeep;
-  // Stores the operands of `regionBranchOp` that we want to keep.
-  BitVector operandsToKeep;
-  // Stores the mapping from region terminators in `regionBranchOp` to their
-  // operands that we want to keep.
-  DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
-  // Initializing the above variables...
-
-  // The live results of `regionBranchOp` definitely need to be kept.
-  markLiveResults(resultsToKeep);
-  // Similarly, the live arguments of the regions in `regionBranchOp` 
definitely
-  // need to be kept.
-  markLiveArgs(argsToKeep);
-  // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
-  // A live forwarded operand can be removed but no non-forwarded operand can 
be
-  // removed since it "controls" the flow of data in this control flow op.
-  markNonForwardedOperands(operandsToKeep);
-  // Similarly, the non-forwarded terminator operands of the regions in
-  // `regionBranchOp` definitely need to be kept.
-  markNonForwardedReturnValues(terminatorOperandsToKeep);
-
-  // Mark the values (results, arguments, operands, and terminator operands)
-  // that we want to keep.
-  markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
-                   terminatorOperandsToKeep);
-
-  // Do (1).
-  cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
-  // Do (2.a) and (2.b).
+  // TODO: Add example.
+  helper(RegionBranchPoint::parent());
   for (Region &region : regionBranchOp->getRegions()) {
     if (region.empty())
       continue;
-    BitVector argsToRemove = argsToKeep[&region].flip();
-    cl.blocks.push_back({&region.front(), argsToRemove});
-    collectNonLiveValues(nonLiveSet, region.front().getArguments(),
-                         argsToRemove);
+    helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(
+        region.front().getTerminator())));
   }
 
-  // Do (2.c).
-  for (Region &region : regionBranchOp->getRegions()) {
-    if (region.empty())
+  DenseMap<Operation *, BitVector> deadOperandsPerOp;
+  for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+    // If one of the successor inputs is live, the respective operand must be
+    // kept. In that case, all matching successor inputs must be kept.
+    bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+      return valuesToKeep.contains(input);
+    });
+    if (anyAlive)
       continue;
-    Operation *terminator = region.front().getTerminator();
-    cl.operands.push_back(
-        {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+    // All successor inputs are dead: ub.poison can be passed as operand.
+    BitVector &deadOperands =
+        deadOperandsPerOp
+            .try_emplace(opOperand->getOwner(),
+                         opOperand->getOwner()->getNumOperands(), false)
+            .first->second;
+    deadOperands.set(opOperand->getOperandNumber());
   }
 
-  // Do (3) and (4).
-  BitVector resultsToRemove = resultsToKeep.flip();
-  collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(),
-                       resultsToRemove);
-  cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove});
+  for (auto [op, deadOperands] : deadOperandsPerOp) {
+    cl.operands.push_back(
+        {op, deadOperands, nullptr, /*replaceWithPoison=*/true});
+  }
 }
 
 /// Steps to process a `BranchOpInterface` operation:
@@ -778,11 +573,37 @@ static void processBranchOp(BranchOpInterface branchOp, 
RunLivenessAnalysis &la,
   }
 }
 
+static SmallVector<Value> createPoisonedValues(OpBuilder &b,
+                                               ValueRange values) {
+  return llvm::map_to_vector(values, [&](Value value) {
+    if (value.getUses().empty())
+      return Value();
+    return ub::PoisonOp::create(b, value.getLoc(), 
value.getType()).getResult();
+  });
+}
+
+namespace {
+struct TrackingListener : public RewriterBase::Listener {
+  void notifyOperationErased(Operation *op) override {
+    if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.erase(poisonOp);
+  }
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    if (auto poisonOp = dyn_cast<ub::PoisonOp>(op))
+      poisonOps.insert(poisonOp);
+  }
+  DenseSet<ub::PoisonOp> poisonOps;
+};
+} // namespace
 /// Removes dead values collected in RDVFinalCleanupList.
 /// To be run once when all dead values have been collected.
-static void cleanUpDeadVals(RDVFinalCleanupList &list) {
+static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) {
   LDBG() << "Starting cleanup of dead values...";
 
+  TrackingListener listener;
+  IRRewriter rewriter(ctx, &listener);
+
   // 1. Blocks, We must remove the block arguments and successor operands 
before
   // deleting the operation, as they may reside in the region operation.
   LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists";
@@ -800,10 +621,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     });
     // Note: Iterate from the end to make sure that that indices of not yet
     // processes arguments do not change.
+    rewriter.setInsertionPointToStart(b.b);
     for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) {
       if (!b.nonLiveArgs[i])
         continue;
-      b.b->getArgument(i).dropAllUses();
+      b.b->getArgument(i).replaceAllUsesWith(
+          createPoisonedValues(rewriter, b.b->getArgument(i)).front());
       b.b->eraseArgument(i);
     }
   }
@@ -832,22 +655,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     }
   }
 
-  // 3. Operations
-  LDBG() << "Cleaning up " << list.operations.size() << " operations";
-  for (Operation *op : list.operations) {
-    LDBG() << "Erasing operation: "
-           << OpWithFlags(op,
-                          
OpPrintingFlags().skipRegions().printGenericOpForm());
-    if (op->hasTrait<OpTrait::IsTerminator>()) {
-      // When erasing a terminator, insert an unreachable op in its place.
-      OpBuilder b(op);
-      ub::UnreachableOp::create(b, op->getLoc());
-    }
-    op->dropAllUses();
-    op->erase();
-  }
-
-  // 4. Functions
+  // 3. Functions
   LDBG() << "Cleaning up " << list.functions.size() << " functions";
   // Record which function arguments were erased so we can shrink call-site
   // argument segments for CallOpInterface operations (e.g. ops using
@@ -864,12 +672,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
       llvm::interleaveComma(f.nonLiveRets.set_bits(), os);
       os << "]";
     });
-    // Drop all uses of the dead arguments.
-    for (auto deadIdx : f.nonLiveArgs.set_bits())
-      f.funcOp.getArgument(deadIdx).dropAllUses();
     // Some functions may not allow erasing arguments or results. These calls
     // return failure in such cases without modifying the function, so it's 
okay
     // to proceed.
+    bool hasBody = !f.funcOp.getFunctionBody().empty();
+    if (hasBody) {
+      rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front());
+      for (auto deadIdx : f.nonLiveArgs.set_bits()) {
+        f.funcOp.getArgument(deadIdx).replaceAllUsesWith(
+            createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx))
+                .front());
+      }
+    }
     if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) {
       // Record only if we actually erased something.
       if (f.nonLiveArgs.any())
@@ -878,7 +692,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
     (void)f.funcOp.eraseResults(f.nonLiveRets);
   }
 
-  // 5. Operands
+  // 4. Operands
   LDBG() << "Cleaning up " << list.operands.size() << " operand lists";
   for (OperandsToCleanup &o : list.operands) {
     // Handle call-specific cleanup only when we have a cached callee 
reference.
@@ -923,11 +737,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
            << OpWithFlags(o.op,
                           
OpPrintingFlags().skipRegions().printGenericOpForm());
       });
-      o.op->eraseOperands(o.nonLive);
+      if (o.replaceWithPoison) {
+        rewriter.setInsertionPoint(o.op);
+        for (auto deadIdx : o.nonLive.set_bits()) {
+          o.op->setOperand(
+              deadIdx, createPoisonedValues(rewriter, 
o.op->getOperand(deadIdx))
+                           .front());
+        }
+      } else {
+        o.op->eraseOperands(o.nonLive);
+      }
     }
   }
 
-  // 6. Results
+  // 5. Results
   LDBG() << "Cleaning up " << list.results.size() << " result lists";
   for (auto &r : list.results) {
     LDBG_OS([&](raw_ostream &os) {
@@ -937,8 +760,34 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) {
          << OpWithFlags(r.op,
                         OpPrintingFlags().skipRegions().printGenericOpForm());
     });
-    dropUsesAndEraseResults(r.op, r.nonLive);
+    rewriter.setInsertionPoint(r.op);
+    for (auto deadIdx : r.nonLive.set_bits()) {
+      r.op->getResult(deadIdx).replaceAllUsesWith(
+          createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front());
+    }
+    eraseResults(rewriter, r.op, r.nonLive);
+  }
+
+  // 6. Operations
+  LDBG() << "Cleaning up " << list.operations.size() << " operations";
+  for (Operation *op : list.operations) {
+    LDBG() << "Erasing operation: "
+           << OpWithFlags(op,
+                          
OpPrintingFlags().skipRegions().printGenericOpForm());
+    rewriter.setInsertionPoint(op);
+    if (op->hasTrait<OpTrait::IsTerminator>()) {
+      // When erasing a terminator, insert an unreachable op in its place.
+      ub::UnreachableOp::create(rewriter, op->getLoc());
+    }
+    rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults()));
+  }
+
+  // 7. Remove all dead poison ops.
+  for (ub::PoisonOp poisonOp : listener.poisonOps) {
+    if (poisonOp.use_empty())
+      poisonOp.erase();
   }
+
   LDBG() << "Finished cleanup of dead values";
 }
 
@@ -977,7 +826,7 @@ void RemoveDeadValues::runOnOperation() {
     }
   });
 
-  cleanUpDeadVals(finalCleanupList);
+  cleanUpDeadVals(module->getContext(), finalCleanupList);
 }
 
 std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() {
diff --git a/mlir/test/Transforms/remove-dead-values.mlir 
b/mlir/test/Transforms/remove-dead-values.mlir
index 71306676d48e9..7bbb06c6488f0 100644
--- a/mlir/test/Transforms/remove-dead-values.mlir
+++ b/mlir/test/Transforms/remove-dead-values.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics 
| FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -remove-dead-values -canonicalize -split-input-file | 
FileCheck %s --check-prefix=CHECK-CANONICALIZE
 
 // The IR is updated regardless of memref.global private constant
 //
@@ -55,19 +56,20 @@ func.func 
@acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0:
 
 // Checking that iter_args are properly handled
 //
+// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value
 func.func @cleanable_loop_iter_args_value(%arg0: index) -> index {
   %c0 = arith.constant 0 : index
   %c1 = arith.constant 1 : index
   %c10 = arith.constant 10 : index
   %non_live = arith.constant 0 : index
-  // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 
iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
+  // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 
step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) {
   %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 
iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) {
-    // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index
+    // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : 
index
     %new_live = arith.addi %live_arg, %i : index
-    // CHECK: scf.yield [[SUM:%.+]]
+    // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]]
     scf.yield %new_live, %non_live_arg : index, index
   }
-  // CHECK: return [[RESULT]] : index
+  // CHECK-CANONICALIZE: return [[RESULT]] : index
   return %result : index
 }
 
@@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> 
index {
 #map = affine_map<(d0, d1, d2) -> (0, d1, d2)>
 #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
 module {
-  func.func @main() {
+  // CHECK-LABEL: @dead_linalg_generic
+  func.func @dead_linalg_generic() {
     %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32>
     %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32>
     // CHECK-NOT: arith.constant
@@ -230,17 +233,17 @@ func.func @main() -> (i32, i32) {
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
-// CHECK-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
-// CHECK-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
-// CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
-// CHECK-NEXT:      scf.yield %[[live_1]] : i32
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]]
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]]
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[live_1]] : i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#0
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) {
   %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) {
     %live_0 = arith.addi %arg4, %arg4 : i32
@@ -284,21 +287,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
-// CHECK-NEXT:    %[[c0:.*]] = arith.constant 0
-// CHECK-NEXT:    %[[c1:.*]] = arith.constant 1
-// CHECK-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
-// CHECK-NEXT:      func.call @identity() : () -> ()
-// CHECK-NEXT:      scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32
-// CHECK-NEXT:    } do {
-// CHECK-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
-// CHECK-NEXT:      scf.yield %[[arg5]], %[[arg6]] : i32, i32
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return %[[live_and_non_live]]#0 : i32
-// CHECK-NEXT:  }
-// CHECK:       func.func private @identity() {
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 {
+// CHECK-CANONICALIZE-NEXT:    %[[c0:.*]] = arith.constant 0
+// CHECK-CANONICALIZE-NEXT:    %[[c1:.*]] = arith.constant 1
+// CHECK-CANONICALIZE-NEXT:    %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) {
+// CHECK-CANONICALIZE-NEXT:      func.call @identity() : () -> ()
+// CHECK-CANONICALIZE-NEXT:      scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    } do {
+// CHECK-CANONICALIZE-NEXT:    ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32):
+// CHECK-CANONICALIZE-NEXT:      scf.yield %[[arg6]], %[[arg5]] : i32, i32
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return %[[live_and_non_live]]#1 : i32
+// CHECK-CANONICALIZE-NEXT:  }
+// CHECK-CANONICALIZE:       func.func private @identity() {
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) {
   %c0 = arith.constant 0 : i32
   %c1 = arith.constant 1 : i32
@@ -325,17 +328,17 @@ func.func private @identity(%arg1 : i32) -> (i32) {
 //
 // Note that this cleanup cannot be done by the `canonicalize` pass.
 //
-// CHECK:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
-// CHECK-NEXT:    scf.index_switch %[[arg0]]
-// CHECK-NEXT:    case 1 {
-// CHECK-NEXT:      %[[c10:.*]] = arith.constant 10
-// CHECK-NEXT:      memref.store %[[c10]], %[[arg1]][]
-// CHECK-NEXT:      scf.yield
-// CHECK-NEXT:    }
-// CHECK-NEXT:    default {
-// CHECK-NEXT:    }
-// CHECK-NEXT:    return
-// CHECK-NEXT:  }
+// CHECK-CANONICALIZE:       func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) {
+// CHECK-CANONICALIZE-NEXT:      %[[c10:.*]] = arith.constant 10
+// CHECK-CANONICALIZE-NEXT:    scf.index_switch %[[arg0]]
+// CHECK-CANONICALIZE-NEXT:    case 1 {
+// CHECK-CANONICALIZE-NEXT:      memref.store %[[c10]], %[[arg1]][]
+// CHECK-CANONICALIZE-NEXT:      scf.yield
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    default {
+// CHECK-CANONICALIZE-NEXT:    }
+// CHECK-CANONICALIZE-NEXT:    return
+// CHECK-CANONICALIZE-NEXT:  }
 func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) {
   %non_live = scf.index_switch %arg0 -> i32
   case 1 {
@@ -540,9 +543,9 @@ module {
 }
 
 // CHECK-LABEL: func @test_zero_operands
-// CHECK: memref.alloca_scope
-// CHECK: memref.store
-// CHECK-NOT: memref.alloca_scope.return
+// CaHECK: memref.alloca_scope
+// CaHECK: memref.store
+// CaHECK-NOT: memref.alloca_scope.return
 
 // -----
 
@@ -714,3 +717,36 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64
 ^bb2:
   return %arg1 : i64
 }
+
+// -----
+
+// CHECK-LABEL: func @scf_while_dead_iter_args()
+// CHECK: %[[c5:.*]] = arith.constant 5 : i32
+// CHECK: %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) {
+// CHECK:   vector.print %[[arg0]]
+// CHECK:   %[[cmpi:.*]] = arith.cmpi
+// CHECK:   %[[p0:.*]] = ub.poison : i32
+// CHECK:   scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]]
+// CHECK: } do {
+// CHECK: ^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32):
+// CHECK:   %[[p1:.*]] = ub.poison : i32
+// CHECK:   scf.yield %[[p1]]
+// CHECK: }
+// CHECK: return %[[while]]#0
+func.func @scf_while_dead_iter_args() -> i32 {
+  %c5 = arith.constant 5 : i32
+  %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) {
+    vector.print %arg0 : i32
+    // Note: This condition is always "false". (And the liveness analysis
+    // can figure that out.)
+    %cmp2 = arith.cmpi slt, %arg0, %c5 : i32
+    scf.condition(%cmp2) %arg0, %arg0 : i32, i32
+  } do {
+  ^bb0(%arg1: i32, %arg2: i32):
+    %x = scf.execute_region -> i32 {
+      scf.yield %arg2 : i32
+    }
+    scf.yield %x : i32
+  }
+  return %result#0 : i32
+}

_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to