https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/173505
>From 63d57b9e57b3460f0d361b496b982c9c24fcc463 Mon Sep 17 00:00:00 2001 From: Matthias Springer <[email protected]> Date: Wed, 24 Dec 2025 13:26:00 +0000 Subject: [PATCH] tmp commit simple test working draft: do not erase IR, just replace uses --- mlir/lib/Transforms/RemoveDeadValues.cpp | 404 ++++++------------- mlir/test/Transforms/remove-dead-values.mlir | 111 ++--- 2 files changed, 182 insertions(+), 333 deletions(-) diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 62ce5e0bbb77e..12009fcddb782 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -96,6 +96,7 @@ struct OperandsToCleanup { BitVector nonLive; Operation *callee = nullptr; // Optional: For CallOpInterface ops, stores the callee function + bool replaceWithPoison = false; }; struct BlockArgsToCleanup { @@ -199,9 +200,8 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range, } } -/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i] -/// is 1. -static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { +/// Erase the i-th result of `op` iff toErase[i] is 1. +static void eraseResults(Operation *op, BitVector toErase) { assert(op->getNumResults() == toErase.size() && "expected the number of results in `op` and the size of `toErase` to " "be the same"); @@ -226,9 +226,7 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) { unsigned indexOfNextNewCallOpResultToReplace = 0; for (auto [index, result] : llvm::enumerate(op->getResults())) { assert(result && "expected result to be non-null"); - if (toErase[index]) { - result.dropAllUses(); - } else { + if (!toErase[index]) { result.replaceAllUsesWith( newOp->getResult(indexOfNextNewCallOpResultToReplace++)); } @@ -448,277 +446,74 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp, return; } - // Mark live results of `regionBranchOp` in `liveResults`. 
- auto markLiveResults = [&](BitVector &liveResults) { - liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la); - }; - - // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`. - auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) { - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - SmallVector<Value> arguments(region.front().getArguments()); - BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la); - liveArgs[&region] = regionLiveArgs; + // Compute values that we definitely want to keep. + DenseSet<Value> valuesToKeep; + for (Value result : regionBranchOp->getResults()) { + if (hasLive(result, nonLiveSet, la)) + valuesToKeep.insert(result); + } + for (Region &region : regionBranchOp->getRegions()) { + if (region.empty()) + continue; + for (Value arg : region.front().getArguments()) { + if (hasLive(arg, nonLiveSet, la)) + valuesToKeep.insert(arg); } - }; + } - // Return the successors of `region` if the latter is not null. Else return - // the successors of `regionBranchOp`. - auto getSuccessors = [&](RegionBranchPoint point) { + // Mapping from operands to forwarded successor inputs. An operand can be + // forwarded to multiple successors. + DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs; + auto helper = [&](RegionBranchPoint point) { SmallVector<RegionSuccessor> successors; regionBranchOp.getSuccessorRegions(point, successors); - return successors; - }; - - // Return the operands of `terminator` that are forwarded to `successor` if - // the former is not null. Else return the operands of `regionBranchOp` - // forwarded to `successor`. - auto getForwardedOpOperands = [&](const RegionSuccessor &successor, - Operation *terminator = nullptr) { - OperandRange operands = - terminator ?
cast<RegionBranchTerminatorOpInterface>(terminator) - .getSuccessorOperands(successor) - : regionBranchOp.getEntrySuccessorOperands(successor); - SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands); - return opOperands; - }; - - // Mark the non-forwarded operands of `regionBranchOp` in - // `nonForwardedOperands`. - auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) { - nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - for (OpOperand *opOperand : getForwardedOpOperands(successor)) - nonForwardedOperands.reset(opOperand->getOperandNumber()); + for (const RegionSuccessor &successor : successors) { + // Handle branch from point --> successor. + ValueRange argsOrResults = successor.getSuccessorInputs(); + OperandRange operands = + point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor) : cast<RegionBranchTerminatorOpInterface>(point.getTerminatorPredecessorOrNull()) + .getSuccessorOperands(successor); + assert(argsOrResults.size() == operands.size() && "expected the same number of successor inputs as forwarded operands"); + + for (auto [opOperand, input] : + llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) { + operandToSuccessorInputs[opOperand].push_back(input); + } } }; - // Mark the non-forwarded terminator operands of the various regions of - // `regionBranchOp` in `nonForwardedRets`. - auto markNonForwardedReturnValues = - [&](DenseMap<Operation *, BitVector> &nonForwardedRets) { - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator(); - nonForwardedRets[terminator] = - BitVector(terminator->getNumOperands(), true); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)))) { - for (OpOperand *opOperand : - getForwardedOpOperands(successor, terminator)) - nonForwardedRets[terminator].reset(opOperand->getOperandNumber()); - } - } - }; - - // Update `valuesToKeep` (which is expected to correspond to operands or - // terminator operands) based on `resultsToKeep` and `argsToKeep`, given - // `region`. When `valuesToKeep` correspond to operands, `region` is null. - // Else, `region` is the parent region of the terminator. - auto updateOperandsOrTerminatorOperandsToKeep = - [&](BitVector &valuesToKeep, BitVector &resultsToKeep, - DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) { - Operation *terminator = - region ? region->front().getTerminator() : nullptr; - RegionBranchPoint point = - terminator - ? RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)) - : RegionBranchPoint::parent(); - - for (const RegionSuccessor &successor : getSuccessors(point)) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor, terminator), - successor.getSuccessorInputs())) { - size_t operandNum = opOperand->getOperandNumber(); - bool updateBasedOn = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn; - } - } - }; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and - // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a - // value is modified, else, false. 
- auto recomputeResultsAndArgsToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep, - bool &resultsOrArgsToKeepChanged) { - resultsOrArgsToKeepChanged = false; - - // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`. - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint::parent())) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - operandsToKeep[opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - - // Recompute `resultsToKeep` and `argsToKeep` based on - // `terminatorOperandsToKeep`. 
- for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - Operation *terminator = region.front().getTerminator(); - for (const RegionSuccessor &successor : - getSuccessors(RegionBranchPoint( - cast<RegionBranchTerminatorOpInterface>(terminator)))) { - Region *successorRegion = successor.getSuccessor(); - for (auto [opOperand, input] : - llvm::zip(getForwardedOpOperands(successor, terminator), - successor.getSuccessorInputs())) { - bool recomputeBasedOn = - terminatorOperandsToKeep[region.back().getTerminator()] - [opOperand->getOperandNumber()]; - bool toRecompute = - successorRegion - ? argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] - : resultsToKeep[cast<OpResult>(input).getResultNumber()]; - if (!toRecompute && recomputeBasedOn) - resultsOrArgsToKeepChanged = true; - if (successorRegion) { - argsToKeep[successorRegion][cast<BlockArgument>(input) - .getArgNumber()] = - argsToKeep[successorRegion] - [cast<BlockArgument>(input).getArgNumber()] | - recomputeBasedOn; - } else { - resultsToKeep[cast<OpResult>(input).getResultNumber()] = - resultsToKeep[cast<OpResult>(input).getResultNumber()] | - recomputeBasedOn; - } - } - } - } - }; - - // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`, - // `operandsToKeep`, and `terminatorOperandsToKeep`. - auto markValuesToKeep = - [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep, - BitVector &operandsToKeep, - DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) { - bool resultsOrArgsToKeepChanged = true; - // We keep updating and recomputing the values until we reach a point - // where they stop changing. - while (resultsOrArgsToKeepChanged) { - // Update the operands that need to be kept. - updateOperandsOrTerminatorOperandsToKeep(operandsToKeep, - resultsToKeep, argsToKeep); - - // Update the terminator operands that need to be kept.
- for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) - continue; - updateOperandsOrTerminatorOperandsToKeep( - terminatorOperandsToKeep[region.back().getTerminator()], - resultsToKeep, argsToKeep, &region); - } - - // Recompute the results and arguments that need to be kept. - recomputeResultsAndArgsToKeep( - resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep, resultsOrArgsToKeepChanged); - } - }; - - // Scenario 2. - // At this point, we know that every non-forwarded operand of `regionBranchOp` - // is live. - - // Stores the results of `regionBranchOp` that we want to keep. - BitVector resultsToKeep; - // Stores the mapping from regions of `regionBranchOp` to their arguments that - // we want to keep. - DenseMap<Region *, BitVector> argsToKeep; - // Stores the operands of `regionBranchOp` that we want to keep. - BitVector operandsToKeep; - // Stores the mapping from region terminators in `regionBranchOp` to their - // operands that we want to keep. - DenseMap<Operation *, BitVector> terminatorOperandsToKeep; - - // Initializing the above variables... - - // The live results of `regionBranchOp` definitely need to be kept. - markLiveResults(resultsToKeep); - // Similarly, the live arguments of the regions in `regionBranchOp` definitely - // need to be kept. - markLiveArgs(argsToKeep); - // The non-forwarded operands of `regionBranchOp` definitely need to be kept. - // A live forwarded operand can be removed but no non-forwarded operand can be - // removed since it "controls" the flow of data in this control flow op. - markNonForwardedOperands(operandsToKeep); - // Similarly, the non-forwarded terminator operands of the regions in - // `regionBranchOp` definitely need to be kept. - markNonForwardedReturnValues(terminatorOperandsToKeep); - - // Mark the values (results, arguments, operands, and terminator operands) - // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep, - terminatorOperandsToKeep); - - // Do (1). - cl.operands.push_back({regionBranchOp, operandsToKeep.flip()}); - - // Do (2.a) and (2.b). + // TODO: Add example. + helper(RegionBranchPoint::parent()); for (Region &region : regionBranchOp->getRegions()) { if (region.empty()) continue; - BitVector argsToRemove = argsToKeep[&region].flip(); - cl.blocks.push_back({&region.front(), argsToRemove}); - collectNonLiveValues(nonLiveSet, region.front().getArguments(), - argsToRemove); + helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>( + region.front().getTerminator()))); } - // Do (2.c). - for (Region &region : regionBranchOp->getRegions()) { - if (region.empty()) + DenseMap<Operation *, BitVector> deadOperandsPerOp; + for (auto [opOperand, successorInputs] : operandToSuccessorInputs) { + // If one of the successor inputs is live, the respective operand must be + // kept. In that case, all matching successor inputs must be kept. + bool anyAlive = llvm::any_of(successorInputs, [&](Value input) { + return valuesToKeep.contains(input); + }); + if (anyAlive) continue; - Operation *terminator = region.front().getTerminator(); - cl.operands.push_back( - {terminator, terminatorOperandsToKeep[terminator].flip()}); + + // All successor inputs are dead: ub.poison can be passed as operand. + BitVector &deadOperands = + deadOperandsPerOp + .try_emplace(opOperand->getOwner(), + opOperand->getOwner()->getNumOperands(), false) + .first->second; + deadOperands.set(opOperand->getOperandNumber()); } - // Do (3) and (4).
- BitVector resultsToRemove = resultsToKeep.flip(); - collectNonLiveValues(nonLiveSet, regionBranchOp.getOperation()->getResults(), - resultsToRemove); - cl.results.push_back({regionBranchOp.getOperation(), resultsToRemove}); + for (auto [op, deadOperands] : deadOperandsPerOp) { + cl.operands.push_back( + {op, deadOperands, nullptr, /*replaceWithPoison=*/true}); + } } /// Steps to process a `BranchOpInterface` operation: @@ -778,9 +573,18 @@ static void processBranchOp(BranchOpInterface branchOp, RunLivenessAnalysis &la, } } +static SmallVector<Value> createPoisonedValues(OpBuilder &b, + ValueRange values) { + return llvm::map_to_vector(values, [&](Value value) { + if (value.getUses().empty()) + return Value(); + return ub::PoisonOp::create(b, value.getLoc(), value.getType()).getResult(); + }); +} + /// Removes dead values collected in RDVFinalCleanupList. /// To be run once when all dead values have been collected. -static void cleanUpDeadVals(RDVFinalCleanupList &list) { +static void cleanUpDeadVals(OpBuilder &builder, RDVFinalCleanupList &list) { LDBG() << "Starting cleanup of dead values..."; // 1. Blocks, We must remove the block arguments and successor operands before @@ -800,10 +604,12 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { }); // Note: Iterate from the end to make sure that that indices of not yet // processes arguments do not change. + builder.setInsertionPointToStart(b.b); for (int i = b.nonLiveArgs.size() - 1; i >= 0; --i) { if (!b.nonLiveArgs[i]) continue; - b.b->getArgument(i).dropAllUses(); + b.b->getArgument(i).replaceAllUsesWith( + createPoisonedValues(builder, b.b->getArgument(i)).front()); b.b->eraseArgument(i); } } @@ -832,22 +638,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { } } - // 3. 
Operations - LDBG() << "Cleaning up " << list.operations.size() << " operations"; - for (Operation *op : list.operations) { - LDBG() << "Erasing operation: " - << OpWithFlags(op, - OpPrintingFlags().skipRegions().printGenericOpForm()); - if (op->hasTrait<OpTrait::IsTerminator>()) { - // When erasing a terminator, insert an unreachable op in its place. - OpBuilder b(op); - ub::UnreachableOp::create(b, op->getLoc()); - } - op->dropAllUses(); - op->erase(); - } - - // 4. Functions + // 5. Functions LDBG() << "Cleaning up " << list.functions.size() << " functions"; // Record which function arguments were erased so we can shrink call-site // argument segments for CallOpInterface operations (e.g. ops using @@ -864,12 +655,18 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { llvm::interleaveComma(f.nonLiveRets.set_bits(), os); os << "]"; }); - // Drop all uses of the dead arguments. - for (auto deadIdx : f.nonLiveArgs.set_bits()) - f.funcOp.getArgument(deadIdx).dropAllUses(); // Some functions may not allow erasing arguments or results. These calls // return failure in such cases without modifying the function, so it's okay // to proceed. + bool hasBody = !f.funcOp.getFunctionBody().empty(); + if (hasBody) { + builder.setInsertionPointToStart(&f.funcOp.getFunctionBody().front()); + for (auto deadIdx : f.nonLiveArgs.set_bits()) { + f.funcOp.getArgument(deadIdx).replaceAllUsesWith( + createPoisonedValues(builder, f.funcOp.getArgument(deadIdx)) + .front()); + } + } if (succeeded(f.funcOp.eraseArguments(f.nonLiveArgs))) { // Record only if we actually erased something. if (f.nonLiveArgs.any()) @@ -878,7 +675,7 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { (void)f.funcOp.eraseResults(f.nonLiveRets); } - // 5. Operands + // 6. Operands LDBG() << "Cleaning up " << list.operands.size() << " operand lists"; for (OperandsToCleanup &o : list.operands) { // Handle call-specific cleanup only when we have a cached callee reference. 
@@ -923,11 +720,20 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(o.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - o.op->eraseOperands(o.nonLive); + if (o.replaceWithPoison) { + builder.setInsertionPoint(o.op); + for (auto deadIdx : o.nonLive.set_bits()) { + o.op->setOperand( + deadIdx, + createPoisonedValues(builder, o.op->getOperand(deadIdx)).front()); + } + } else { + o.op->eraseOperands(o.nonLive); + } } } - // 6. Results + // 7. Results LDBG() << "Cleaning up " << list.results.size() << " result lists"; for (auto &r : list.results) { LDBG_OS([&](raw_ostream &os) { @@ -937,8 +743,29 @@ static void cleanUpDeadVals(RDVFinalCleanupList &list) { << OpWithFlags(r.op, OpPrintingFlags().skipRegions().printGenericOpForm()); }); - dropUsesAndEraseResults(r.op, r.nonLive); + builder.setInsertionPoint(r.op); + for (auto deadIdx : r.nonLive.set_bits()) { + r.op->getResult(deadIdx).replaceAllUsesWith( + createPoisonedValues(builder, r.op->getResult(deadIdx)).front()); + } + eraseResults(r.op, r.nonLive); } + + // 3. Operations + LDBG() << "Cleaning up " << list.operations.size() << " operations"; + for (Operation *op : list.operations) { + LDBG() << "Erasing operation: " + << OpWithFlags(op, + OpPrintingFlags().skipRegions().printGenericOpForm()); + builder.setInsertionPoint(op); + if (op->hasTrait<OpTrait::IsTerminator>()) { + // When erasing a terminator, insert an unreachable op in its place. 
+ ub::UnreachableOp::create(builder, op->getLoc()); + } + op->replaceAllUsesWith(createPoisonedValues(builder, op->getResults())); + op->erase(); + } + LDBG() << "Finished cleanup of dead values"; } @@ -977,7 +804,8 @@ void RemoveDeadValues::runOnOperation() { } }); - cleanUpDeadVals(finalCleanupList); + OpBuilder b(module->getContext()); + cleanUpDeadVals(b, finalCleanupList); } std::unique_ptr<Pass> mlir::createRemoveDeadValuesPass() { diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index 71306676d48e9..32acf518f9871 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -60,14 +60,14 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { %c1 = arith.constant 1 : index %c10 = arith.constant 10 : index %non_live = arith.constant 0 : index - // CHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { + // CaHECK: [[RESULT:%.+]] = scf.for [[ARG_1:%.*]] = %c0 to %c10 step %c1 iter_args([[ARG_2:%.*]] = %arg0) -> (index) { %result, %result_non_live = scf.for %i = %c0 to %c10 step %c1 iter_args(%live_arg = %arg0, %non_live_arg = %non_live) -> (index, index) { - // CHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index + // CaHECK: [[SUM:%.+]] = arith.addi [[ARG_2]], [[ARG_1]] : index %new_live = arith.addi %live_arg, %i : index - // CHECK: scf.yield [[SUM:%.+]] + // CaHECK: scf.yield [[SUM:%.+]] scf.yield %new_live, %non_live_arg : index, index } - // CHECK: return [[RESULT]] : index + // CaHECK: return [[RESULT]] : index return %result : index } @@ -79,7 +79,8 @@ func.func @cleanable_loop_iter_args_value(%arg0: index) -> index { #map = affine_map<(d0, d1, d2) -> (0, d1, d2)> #map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)> module { - func.func @main() { + // CHECK-LABEL: @dead_linalg_generic + func.func @dead_linalg_generic() { %cst_3 = arith.constant dense<54> : 
tensor<1x25x13xi32> %cst_7 = arith.constant dense<11> : tensor<1x25x13xi32> // CHECK-NOT: arith.constant @@ -230,17 +231,17 @@ func.func @main() -> (i32, i32) { // // Note that this cleanup cannot be done by the `canonicalize` pass. // -// CHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { -// CHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] -// CHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] -// CHECK-NEXT: scf.yield %[[live_1]] : i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 -// CHECK-NEXT: } +// CaHECK: func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%[[arg0:.*]]: i1, %[[arg1:.*]]: i32, %[[arg2:.*]]: i32) -> i32 { +// CaHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg4:.*]] = %[[arg2]]) : (i32) -> (i32, i32) { +// CaHECK-NEXT: %[[live_0:.*]] = arith.addi %[[arg4]], %[[arg4]] +// CaHECK-NEXT: scf.condition(%arg0) %[[live_0]], %[[arg4]] : i32, i32 +// CaHECK-NEXT: } do { +// CaHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CaHECK-NEXT: %[[live_1:.*]] = arith.addi %[[arg6]], %[[arg6]] +// CaHECK-NEXT: scf.yield %[[live_1]] : i32 +// CaHECK-NEXT: } +// CaHECK-NEXT: return %[[live_and_non_live]]#0 +// CaHECK-NEXT: } func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_operand(%arg0: i1, %arg1: i32, %arg2: i32) -> (i32) { %live, %non_live, %non_live_0 = scf.while (%arg3 = %arg1, %arg4 = %arg2) : (i32, i32) -> (i32, i32, i32) { %live_0 = arith.addi %arg4, %arg4 : i32 @@ -284,21 +285,21 @@ func.func @clean_region_branch_op_dont_remove_first_2_results_but_remove_first_o // // 
Note that this cleanup cannot be done by the `canonicalize` pass. // -// CHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { -// CHECK-NEXT: %[[c0:.*]] = arith.constant 0 -// CHECK-NEXT: %[[c1:.*]] = arith.constant 1 -// CHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { -// CHECK-NEXT: func.call @identity() : () -> () -// CHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32 -// CHECK-NEXT: } do { -// CHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): -// CHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32 -// CHECK-NEXT: } -// CHECK-NEXT: return %[[live_and_non_live]]#0 : i32 -// CHECK-NEXT: } -// CHECK: func.func private @identity() { -// CHECK-NEXT: return -// CHECK-NEXT: } +// CaHECK: func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%[[arg2:.*]]: i1) -> i32 { +// CaHECK-NEXT: %[[c0:.*]] = arith.constant 0 +// CaHECK-NEXT: %[[c1:.*]] = arith.constant 1 +// CaHECK-NEXT: %[[live_and_non_live:.*]]:2 = scf.while (%[[arg3:.*]] = %[[c0]], %[[arg4:.*]] = %[[c1]]) : (i32, i32) -> (i32, i32) { +// CaHECK-NEXT: func.call @identity() : () -> () +// CaHECK-NEXT: scf.condition(%[[arg2]]) %[[arg4]], %[[arg3]] : i32, i32 +// CaHECK-NEXT: } do { +// CaHECK-NEXT: ^bb0(%[[arg5:.*]]: i32, %[[arg6:.*]]: i32): +// CaHECK-NEXT: scf.yield %[[arg5]], %[[arg6]] : i32, i32 +// CaHECK-NEXT: } +// CaHECK-NEXT: return %[[live_and_non_live]]#0 : i32 +// CaHECK-NEXT: } +// CaHECK: func.func private @identity() { +// CaHECK-NEXT: return +// CaHECK-NEXT: } func.func @clean_region_branch_op_remove_last_2_results_last_2_arguments_and_last_operand(%arg2: i1) -> (i32) { %c0 = arith.constant 0 : i32 %c1 = arith.constant 1 : i32 @@ -325,17 +326,17 @@ func.func private @identity(%arg1 : i32) -> (i32) { // // Note that this cleanup cannot be done by the `canonicalize` pass. 
// -// CHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { -// CHECK-NEXT: scf.index_switch %[[arg0]] -// CHECK-NEXT: case 1 { -// CHECK-NEXT: %[[c10:.*]] = arith.constant 10 -// CHECK-NEXT: memref.store %[[c10]], %[[arg1]][] -// CHECK-NEXT: scf.yield -// CHECK-NEXT: } -// CHECK-NEXT: default { -// CHECK-NEXT: } -// CHECK-NEXT: return -// CHECK-NEXT: } +// CaHECK: func.func @clean_region_branch_op_remove_result(%[[arg0:.*]]: index, %[[arg1:.*]]: memref<i32>) { +// CaHECK-NEXT: scf.index_switch %[[arg0]] +// CaHECK-NEXT: case 1 { +// CaHECK-NEXT: %[[c10:.*]] = arith.constant 10 +// CaHECK-NEXT: memref.store %[[c10]], %[[arg1]][] +// CaHECK-NEXT: scf.yield +// CaHECK-NEXT: } +// CaHECK-NEXT: default { +// CaHECK-NEXT: } +// CaHECK-NEXT: return +// CaHECK-NEXT: } func.func @clean_region_branch_op_remove_result(%arg0 : index, %arg1 : memref<i32>) { %non_live = scf.index_switch %arg0 -> i32 case 1 { @@ -540,9 +541,9 @@ module { } // CHECK-LABEL: func @test_zero_operands -// CHECK: memref.alloca_scope -// CHECK: memref.store -// CHECK-NOT: memref.alloca_scope.return +// CaHECK: memref.alloca_scope +// CaHECK: memref.store +// CaHECK-NOT: memref.alloca_scope.return // ----- @@ -714,3 +715,23 @@ func.func private @remove_dead_branch_op(%c: i1, %arg0: i64, %arg1: i64) -> (i64 ^bb2: return %arg1 : i64 } + +// ----- + +func.func @scf_while_dead_iter_args() -> i32 { + %c5 = arith.constant 5 : i32 + %false = arith.constant false + %result:2 = scf.while (%arg0 = %c5) : (i32) -> (i32, i32) { + vector.print %arg0 : i32 + %cmp2 = arith.cmpi slt, %arg0, %c5 : i32 + scf.condition(%cmp2) {tag = "scf.condition"} %arg0, %arg0 : i32, i32 + } do { + ^bb0(%arg1: i32, %arg2: i32): + %x = scf.execute_region -> i32 { + scf.yield %arg2 : i32 + } + // TODO: not working yet when yielding %x instead of %arg2. 
+ scf.yield {tag = "scf.yield"} %arg2 : i32 + } attributes {tag = "scf.while"} + return %result#0 : i32 +} _______________________________________________ llvm-branch-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
