https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505
>From 529a15981141a35691f5cddaa326b15d8da80aa8 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Thu, 25 Dec 2025 12:55:26 +0000 Subject: [PATCH 1/2] [mlir][SCF] Fold unused `index_switch` results --- mlir/lib/Dialect/SCF/IR/SCF.cpp | 52 ++++++++++++++++++++++++- mlir/test/Dialect/SCF/canonicalize.mlir | 31 +++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 652414f6cbe54..0a123112cf68f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; +/// Canonicalization pattern that folds away dead results of +/// "scf.index_switch" ops. +struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { + using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexSwitchOp op, + PatternRewriter &rewriter) const override { + // Find dead results. + BitVector deadResults(op.getNumResults(), false); + SmallVector<Type> newResultTypes; + for (auto [idx, result] : llvm::enumerate(op.getResults())) { + if (!result.use_empty()) { + newResultTypes.push_back(result.getType()); + } else { + deadResults[idx] = true; + } + } + if (!deadResults.any()) + return rewriter.notifyMatchFailure(op, "no dead results to fold"); + + // Create new op without dead results and inline case regions. + auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes, + op.getArg(), op.getCases(), + op.getCaseRegions().size()); + auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) { + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); + // Remove respective operands from yield op. 
+ Operation *terminator = newRegion.front().getTerminator(); + assert(isa<YieldOp>(terminator) && "expected yield op"); + rewriter.modifyOpInPlace( + terminator, [&]() { terminator->eraseOperands(deadResults); }); + }; + for (auto [oldRegion, newRegion] : + llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions())) + inlineCaseRegion(oldRegion, newRegion); + inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion()); + + // Replace op with new op. + SmallVector<Value> newResults(op.getNumResults(), Value()); + unsigned nextNewResult = 0; + for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + if (deadResults[idx]) + continue; + newResults[idx] = newOp.getResult(nextNewResult++); + } + rewriter.replaceOp(op, newResults); + return success(); + } +}; + void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase>(context); + results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index ac590fc0c47b9..d5d0aee3bbe25 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0() { } return } + +// ----- + +// CHECK-LABEL: func @dead_index_switch_result( +// CHECK-SAME: %[[arg0:.*]]: index +// CHECK-DAG: %[[c10:.*]] = arith.constant 10 +// CHECK-DAG: %[[c11:.*]] = arith.constant 11 +// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: case 1 { +// CHECK: memref.store %[[c10]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: default { +// CHECK: memref.store %[[c11]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: return %[[switch]] +func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index { + %non_live, %live = scf.index_switch 
%arg0 -> i32, index + case 1 { + %c10 = arith.constant 10 : i32 + memref.store %c10, %arg1[] : memref<i32> + scf.yield %c10, %arg0 : i32, index + } + default { + %c11 = arith.constant 11 : i32 + memref.store %c11, %arg1[] : memref<i32> + scf.yield %c11, %arg0 : i32, index + } + return %live : index +} >From db702280cd8b6e4eaf0ea46bab9ea3473ad4c50e Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 24 Dec 2025 13:26:00 +0000 Subject: [PATCH 2/2] tmp commit simple test working draft: do not erase IR, just replace uses --- mlir/lib/Transforms/RemoveDeadValues.cpp | 429 ++++++------------- mlir/test/Transforms/remove-dead-values.mlir | 128 ++++-- 2 files changed, 221 insertions(+), 336 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 62ce5e0bbb77e..421fdb43a4554 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -96,6 +96,7 @@ struct OperandsToCleanup { BitVector nonLive; Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function + bool replaceWithPoison = false; }; struct BlockArgsToCleanup { @@ -199,9 +200,9 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, } } -/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] -/// is 1. -static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { +/// Erase the i-th result of `op` iff toErase[i] is 1. 
+static void eraseResults(RewriterBase &rewriter, Operation *op, + BitVector toErase) { assert(op->getNumResults() == toErase.size() && "expected the number of results in `op` and the size of `toErase` to " "be the same"); @@ -210,7 +211,6 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { for (OpResult result : op->getResults()) if (!toErase[result.getResultNumber()]) newResultTypes.push_back(result.getType()); - IRRewriter rewriter(op); rewriter.setInsertionPointAfter(op); OperationState state(op->getLoc(), op->getName().getStringRef(), op->getOperands(), newResultTypes, op->getAttrs()); @@ -226,14 +226,12 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { unsigned indexOfNextNewCallOpResultToReplace = 0; for (auto [index, result] : llvm::enumerate(op->getResults())) { assert(result && "expected result to be non-null"); - if (toErase[index]) { - result.dropAllUses(); - } else { + if (!toErase[index]) { result.replaceAllUsesWith( newOp->getResult(indexOfNextNewCallOpResultToReplace++)); } } - op->erase(); + rewriter.eraseOp(op); } /// Convert a list of `Operand`s to a list of `OpOperand`s. @@ -448,277 +446,74 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, return; } - // Mark live results of `regionBranchOp` in `liveResults`. - auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); - }; - - // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. - auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) { - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - SmallVector<Value> arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); - liveArgs[&region] = regionLiveArgs; + // Compute values that we definitely want to keep. 
+ DenseSet<Value> valuesToKeep; + for (Value result : regionBranchOp->getResults()) { + if (hasLive(result, nonLiveSet, la)) + valuesToKeep.insert(result); + } + for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + for (Value arg : region.front().getArguments()) { + if (hasLive(arg, nonLiveSet, la)) + valuesToKeep.insert(arg); } - }; + } - // Return the successors of `region` if the latter is not null. Else return - // the successors of `regionBranchOp`. - auto getSuccessors = [&](RegionBranchPoint point) { + // Mapping from operands to forwarded successor inputs. An operand can be + // forwarded to multiple successors. + DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs; + auto helper = [&](RegionBranchPoint point) { SmallVector<RegionSuccessor> successors; regionBranchOp.getSuccessorRegions(point, successors); - return successors; - }; - - // Return the operands of `terminator` that are forwarded to `successor` if - // the former is not null. Else return the operands of `regionBranchOp` - // forwarded to `successor`. - auto getForwardedOpOperands = [&](const RegionSuccessor &successor, - Operation *terminator = nullptr) { - OperandRange operands = - terminator ? cast<RegionBranchTerminatorOpInterface>(terminator) - .getSuccessorOperands(successor) - : regionBranchOp.getEntrySuccessorOperands(successor); - SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands); - return opOperands; - }; - - // Mark the non-forwarded operands of `regionBranchOp` in - // `nonForwardedOperands`. 
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { - nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - for (OpOperand *opOperand : getForwardedOpOperands(successor)) - nonForwardedOperands.reset(opOperand->getOperandNumber()); + for (const RegionSuccessor &successor : successors) { + // Handle branch from point --> successor. + ValueRange argsOrResults = successor.getSuccessorInputs(); + OperandRange operands = + point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) : cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull()) + .getSuccessorOperands(successor); + assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands"); + + for (auto [opOperand, input] : + llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) { + operandToSuccessorInputs[opOperand].push_back(input); + } } }; - // Mark the non-forwarded terminator operands of the various regions of - // `regionBranchOp` in `nonForwardedRets`. - auto markNonForwardedReturnValues = - [&](DenseMap<Operation *, BitVector> &nonForwardedRets) { - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - // TODO: this isn't correct in face of multiple terminators. 
- Operation *terminator = region.front().getTerminator(); - nonForwardedRets[terminator] = - BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)))) { - for (OpOperand *opOperand : - getForwardedOpOperands(successor, terminator)) - nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); - } - } - }; - - // Update `valuesToKeep` (which is expected to correspond to operands or - // terminator operands) based on `resultsToKeep` and `argsToKeep`, given - // `region`. When `valuesToKeep` correspond to operands, `region` is null. - // Else, `region` is the parent region of the terminator. - auto updateOperandsOrTerminatorOperandsToKeep = - [&](BitVector &valuesToKeep, BitVector &resultsToKeep, - DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { - Operation *terminator = - region ? region->front().getTerminator() : nullptr; - RegionBranchPoint point = - terminator - ? RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)) - : RegionBranchPoint::parent(); - - for (const RegionSuccessor &successor : getSuccessors(point)) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor, terminator), - successor.getSuccessorInputs())) { - size_t operandNum = opOperand->getOperandNumber(); - bool updateBasedOn = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn; - } - } - }; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and - // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a - // value is modified, else, false. 
- auto recomputeResultsAndArgsToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep, - bool &resultsOrArgsToKeepChanged) { - resultsOrArgsToKeepChanged = false; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - operandsToKeep[opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - - // Recompute `resultsToKeep` and `argsToKeep` based on - // `terminatorOperandsToKeep`. 
- for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)))) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor, terminator), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - terminatorOperandsToKeep[region.back().getTerminator()] - [opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - } - }; - - // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`, - // `operandsToKeep`, and `terminatorOperandsToKeep`. - auto markValuesToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) { - bool resultsOrArgsToKeepChanged = true; - // We keep updating and recomputing the values until we reach a point - // where they stop changing. - while (resultsOrArgsToKeepChanged) { - // Update the operands that need to be kept. - updateOperandsOrTerminatorOperandsToKeep(operandsToKeep, - resultsToKeep, argsToKeep); - - // Update the terminator operands that need to be kept. 
- for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - updateOperandsOrTerminatorOperandsToKeep( - terminatorOperandsToKeep[region.back().getTerminator()], - resultsToKeep, argsToKeep, &region); - } - - // Recompute the results and arguments that need to be kept. - recomputeResultsAndArgsToKeep( - resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep, resultsOrArgsToKeepChanged); - } - }; - - // Scenario 2. - // At this point, we know that every non-forwarded operand of `regionBranchOp` - // is live. - - // Stores the results of `regionBranchOp` that we want to keep. - BitVector resultsToKeep; - // Stores the mapping from regions of `regionBranchOp` to their arguments that - // we want to keep. - DenseMap<Region *, BitVector> argsToKeep; - // Stores the operands of `regionBranchOp` that we want to keep. - BitVector operandsToKeep; - // Stores the mapping from region terminators in `regionBranchOp` to their - // operands that we want to keep. - DenseMap<Operation *, BitVector> terminatorOperandsToKeep; - - // Initializing the above variables... - - // The live results of `regionBranchOp` definitely need to be kept. - markLiveResults(resultsToKeep); - // Similarly, the live arguments of the regions in `regionBranchOp` definitely - // need to be kept. - markLiveArgs(argsToKeep); - // The non-forwarded operands of `regionBranchOp` definitely need to be kept. - // A live forwarded operand can be removed but no non-forwarded operand can be - // removed since it "controls" the flow of data in this control flow op. - markNonForwardedOperands(operandsToKeep); - // Similarly, the non-forwarded terminator operands of the regions in - // `regionBranchOp` definitely need to be kept. - markNonForwardedReturnValues(terminatorOperandsToKeep); - - // Mark the values (results, arguments, operands, and terminator operands) - // that we want to keep. 
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep); - - // Do (1). - cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); - - // Do (2.a) and (2.b). + // TODO: Add example. + helper(RegionBranchPoint::parent()); for (Region &region : regionBranchOp->getRegions()) { if (region.empty()) continue; - BitVector argsToRemove = argsToKeep[&region].flip(); - cl.blocks.push_back({&region.front(), argsToRemove}); - collectNonLiveValues(nonLiveSet, region.front().getArguments(), - argsToRemove); + helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>( + region.front().getTerminator()))); } - // Do (2.c). - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) + DenseMap<Operation *, BitVector> deadOperandsPerOp; + for (auto [opOperand, successorInputs] : operandToSuccessorInputs) { + // If one of the successor inputs is live, the respective operand must be + // kept. In that case, all matching successor inputs must be kept. + bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { + return valuesToKeep.contains(input); + }); + if (anyAlive) continue; - Operation *terminator = region.front().getTerminator(); - cl.operands.push_back( - {terminator, terminatorOperandsToKeep[terminator].flip()}); + + // All successor inputs are dead: ub.poison can be passed as operand. + BitVector &deadOperands = + deadOperandsPerOp + .try_emplace(opOperand->getOwner(), + opOperand->getOwner()->getNumOperands(), false) + .first->second; + deadOperands.set(opOperand->getOperandNumber()); } - // Do (3) and (4). 
- BitVector resultsToRemove = resultsToKeep.flip(); - collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), - resultsToRemove); - cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); + for (auto [op, deadOperands] : deadOperandsPerOp) { + cl.operands.push_back( + {op, deadOperands, nullptr, /*replaceWithPoison=*/true}); + } } /// Steps to process a `BranchOpInterface` operation: @@ -778,11 +573,37 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, } } +static SmallVector<Value> createPoisonedValues(OpBuilder &b, + ValueRange values) { + return llvm::map_to_vector(values, [&](Value value) { + if (value.getUses().empty()) + return Value(); + return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult(); + }); +} + +namespace { +struct TrackingListener : public RewriterBase::Listener { + void notifyOperationErased(Operation *op) override { + if (auto poisonOp = dyn_cast<ub::PoisonOp>(op)) + poisonOps.erase(poisonOp); + } + void notifyOperationInserted(Operation *op, + OpBuilder::InsertPoint previous) override { + if (auto poisonOp = dyn_cast<ub::PoisonOp>(op)) + poisonOps.insert(poisonOp); + } + DenseSet<ub::PoisonOp> poisonOps; +}; +} // namespace /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. -static void cleanUpDeadVals(RDVFinalCleanupList &list) { +static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { LDBG() << "Starting cleanup of dead values..."; + TrackingListener listener; + IRRewriter rewriter(ctx, &listener); + // 1. Blocks, We must remove the block arguments and successor operands before // deleting the operation, as they may reside in the region operation. 
LDBG() << "Cleaning up " << list.blocks.size() << " block argument lists"; @@ -800,10 +621,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { }); // Note: Iterate from the end to make sure that that indices of not yet // processes arguments do not change. + rewriter.setInsertionPointToStart(b.b); for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; - b.b->getArgument(i).dropAllUses(); + b.b->getArgument(i).replaceAllUsesWith( + createPoisonedValues(rewriter, b.b->getArgument(i)).front()); b.b->eraseArgument(i); } } @@ -832,22 +655,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } } - // 3. Operations - LDBG() << "Cleaning up " << list.operations.size() << " operations"; - for (Operation *op : list.operations) { - LDBG() << "Erasing operation: " - << OpWithFlags(op, - OpPrintingFlags().skipRegions().printGenericOpForm()); - if (op->hasTrait<OpTrait::IsTerminator>()) { - // When erasing a terminator, insert an unreachable op in its place. - OpBuilder b(op); - ub::UnreachableOp::create(b, op->getLoc()); - } - op->dropAllUses(); - op->erase(); - } - - // 4. Functions + // 3. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; // Record which function arguments were erased so we can shrink call-site // argument segments for CallOpInterface operations (e.g. ops using @@ -864,12 +672,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { llvm::interleaveComma(f.nonLiveRets.set_bits(), os); os << "]"; }); - // Drop all uses of the dead arguments. - for (auto deadIdx : f.nonLiveArgs.set_bits()) - f.funcOp.getArgument(deadIdx).dropAllUses(); // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. 
+ bool hasBody = !f.funcOp.getFunctionBody().empty(); + if (hasBody) { + rewriter.setInsertionPointToStart(&f.funcOp.getFunctionBody().front()); + for (auto deadIdx : f.nonLiveArgs.set_bits()) { + f.funcOp.getArgument(deadIdx).replaceAllUsesWith( + createPoisonedValues(rewriter, f.funcOp.getArgument(deadIdx)) + .front()); + } + } if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) { // Record only if we actually erased something. if (f.nonLiveArgs.any()) @@ -878,7 +692,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { (void)f.funcOp.eraseResults(f.nonLiveRets); } - // 5. Operands + // 4. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperandsToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. @@ -923,11 +737,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(o.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - o.op->eraseOperands(o.nonLive); + if (o.replaceWithPoison) { + rewriter.setInsertionPoint(o.op); + for (auto deadIdx : o.nonLive.set_bits()) { + o.op->setOperand( + deadIdx, createPoisonedValues(rewriter, o.op->getOperand(deadIdx)) + .front()); + } + } else { + o.op->eraseOperands(o.nonLive); + } } } - // 6. Results + // 5. Results LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { LDBG_OS([&](raw_ostream &os) { @@ -937,8 +760,34 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(r.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - dropUsesAndEraseResults(r.op, r.nonLive); + rewriter.setInsertionPoint(r.op); + for (auto deadIdx : r.nonLive.set_bits()) { + r.op->getResult(deadIdx).replaceAllUsesWith( + createPoisonedValues(rewriter, r.op->getResult(deadIdx)).front()); + } + eraseResults(rewriter, r.op, r.nonLive); + } + + // 6. 
Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; + for (Operation *op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, + OpPrintingFlags().skipRegions().printGenericOpForm()); + rewriter.setInsertionPoint(op); + if (op->hasTrait<OpTrait::IsTerminator>()) { + // When erasing a terminator, insert an unreachable op in its place. + ub::UnreachableOp::create(rewriter, op->getLoc()); + } + rewriter.replaceOp(op, createPoisonedValues(rewriter, op->getResults())); + } + + // 7. Remove all dead poison ops. + for (ub::PoisonOp poisonOp : listener.poisonOps) { + if (poisonOp.use_empty()) + poisonOp.erase(); } + LDBG() << "Finished cleanup of dead values"; } @@ -977,7 +826,7 @@ void RemoveDeadValues::runOnOperation() { } }); - cleanUpDeadVals(finalCleanupList); + cleanUpDeadVals(module->getContext(), finalCleanupList); } std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() { diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 71306676d48e9..7bbb06c6488f0 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -1,4 +1,5 @@ -// RUN: mlir-opt %s -remove-dead-values -split-input-file -verify-diagnostics | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values -split-input-file | FileCheck %s +// RUN: mlir-opt %s -remove-dead-values -canonicalize -split-input-file | FileCheck %s --check-prefix=CHECK-CANONICALIZE // The IR is updated regardless of memref.global private constant // @@ -55,19 +56,20 @@ func.func @acceptable_ir_has_cleanable_loop_of_conditional_and_branch_op(%arg0: // Checking that iter_args are properly handled // +// CHECK-CANONICALIZE-LABEL: func @cleanable_loop_iter_args_value func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index %non_live = arith.constant 0 : index - // CHECK: 
[[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { + // CHECK-CANONICALIZE: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { - // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index + // CHECK-CANONICALIZE: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index %new_live = arith.addi %live_arg, %i : index - // CHECK: scf.yield [[SUM:%.+]] + // CHECK-CANONICALIZE: scf.yield [[SUM:%.+]] scf.yield %new_live, %non_live_arg : index, index } - // CHECK: return [[RESULT]] : index + // CHECK-CANONICALIZE: return [[RESULT]] : index return %result : index } @@ -79,7 +81,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { #map = affine_map<(d0, d1, d2) -> (0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> module { - func.func @main() { + // CHECK-LABEL: @dead_linalg_generic + func.func @dead_linalg_generic() { %cst_3 = arith.constant dense<54> : tensor<1x25x13xi32> %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> // CHECK-NOT: arith.constant @@ -230,17 +233,17 @@ func.func @main() -> (i32, i32) { // // Note that this cleanup cannot be done by the `canonicalize` pass. 
// -// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { -// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] -// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] -// CHECK-NEXT: scf.yield %[[live_1]] : i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 -// CHECK-NEXT: } +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { +// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { +// CHECK-CANONICALIZE-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] +// CHECK-CANONICALIZE-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } do { +// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CHECK-CANONICALIZE-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] +// CHECK-CANONICALIZE-NEXT: scf.yield %[[live_1]] : i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#0 +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) { %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) { %live_0 = arith.addi %arg4, %arg4 : i32 @@ -284,21 +287,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o // // Note that this cleanup cannot be done by the `canonicalize` pass. 
// -// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 -// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { -// CHECK-NEXT: func.call @identity() : () -> () -// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32 -// CHECK-NEXT: } -// CHECK: func.func private @identity() { -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { +// CHECK-CANONICALIZE-NEXT: %[[c0:.*]] = arith.constant 0 +// CHECK-CANONICALIZE-NEXT: %[[c1:.*]] = arith.constant 1 +// CHECK-CANONICALIZE-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { +// CHECK-CANONICALIZE-NEXT: func.call @identity() : () -> () +// CHECK-CANONICALIZE-NEXT: scf.condition(%[[arg2]]) %[[arg3]], %[[arg4]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } do { +// CHECK-CANONICALIZE-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CHECK-CANONICALIZE-NEXT: scf.yield %[[arg6]], %[[arg5]] : i32, i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: return %[[live_and_non_live]]#1 : i32 +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE: func.func private @identity() { +// CHECK-CANONICALIZE-NEXT: return +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 @@ -325,17 +328,17 @@ func.func private 
@identity(%arg1 : i32) -> (i32) { // // Note that this cleanup cannot be done by the `canonicalize` pass. // -// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { -// CHECK-NEXT: scf.index_switch %[[arg0]] -// CHECK-NEXT: case 1 { -// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 -// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][] -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: default { -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } +// CHECK-CANONICALIZE: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { +// CHECK-CANONICALIZE-NEXT: %[[c10:.*]] = arith.constant 10 +// CHECK-CANONICALIZE-NEXT: scf.index_switch %[[arg0]] +// CHECK-CANONICALIZE-NEXT: case 1 { +// CHECK-CANONICALIZE-NEXT: memref.store %[[c10]], %[[arg1]][] +// CHECK-CANONICALIZE-NEXT: scf.yield +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: default { +// CHECK-CANONICALIZE-NEXT: } +// CHECK-CANONICALIZE-NEXT: return +// CHECK-CANONICALIZE-NEXT: } func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) { %non_live = scf.index_switch %arg0 -> i32 case 1 { @@ -540,9 +543,9 @@ module { } // CHECK-LABEL: func @test_zero_operands -// CHECK: memref.alloca_scope -// CHECK: memref.store -// CHECK-NOT: memref.alloca_scope.return +// CaHECK: memref.alloca_scope +// CaHECK: memref.store +// CaHECK-NOT: memref.alloca_scope.return // ----- @@ -714,3 +717,36 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64 ^bb2: return %arg1 : i64 } + +// ----- + +// CHECK-LABEL: func @scf_while_dead_iter_args() +// CHECK: %[[c5:.*]] = arith.constant 5 : i32 +// CHECK: %[[while:.*]]:2 = scf.while (%[[arg0:.*]] = %[[c5]]) : (i32) -> (i32, i32) { +// CHECK: vector.print %[[arg0]] +// CHECK: %[[cmpi:.*]] = arith.cmpi +// CHECK: %[[p0:.*]] = ub.poison : i32 +// CHECK: scf.condition(%[[cmpi]]) %[[arg0]], %[[p0]] +// CHECK: } do { +// CHECK: 
^bb0(%[[arg1:.*]]: i32, %[[arg2:.*]]: i32): +// CHECK: %[[p1:.*]] = ub.poison : i32 +// CHECK: scf.yield %[[p1]] +// CHECK: } +// CHECK: return %[[while]]#0 +func.func @scf_while_dead_iter_args() -> i32 { + %c5 = arith.constant 5 : i32 + %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) { + vector.print %arg0 : i32 + // Note: This condition is always "false". (And the liveness analysis + // can figure that out.) + %cmp2 = arith.cmpi slt, %arg0, %c5 : i32 + scf.condition(%cmp2) %arg0, %arg0 : i32, i32 + } do { + ^bb0(%arg1: i32, %arg2: i32): + %x = scf.execute_region -> i32 { + scf.yield %arg2 : i32 + } + scf.yield %x : i32 + } + return %result#0 : i32 +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
