llvmbot wrote:
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
This commit simplifies the `remove-dead-values` pass and fixes a bug in the
handling of `RegionBranchOpInterface` ops. The pass used to produce invalid IR
("null value found") for the newly added test case.
`remove-dead-values` is a pass for additional IR simplification that cannot be
performed by the canonicalizer pass. Based on a liveness analysis, it erases
dead values / IR. (The liveness analysis is a dataflow analysis that has more
information about the IR than a canonicalization pattern, which can see only
"local" information.)
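To make the liveness idea concrete, here is a minimal C++ sketch of backward liveness on a toy, single-block IR. It is not MLIR's `RunLivenessAnalysis`, which is a sparse backward dataflow analysis that also reasons across regions and calls. The `ToyOp` struct and `findDeadOps` function are hypothetical. An op is treated as live if it has side effects or defines a live value, and the operands of a live op become live. Every other op is reported as erasable.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <string>
#include <vector>

// Hypothetical toy IR: one straight-line block, SSA values named by ints.
struct ToyOp {
  std::string name;
  std::vector<int> operands;
  std::vector<int> results;
  bool hasSideEffects = false;
};

// Backward liveness over a single block. A single reverse walk suffices
// because, in straight-line SSA code, every use comes after its definition.
// Returns the indices of erasable ops, in reverse program order.
std::vector<std::size_t> findDeadOps(const std::vector<ToyOp> &ops) {
  std::set<int> liveValues;
  std::vector<std::size_t> deadOps;
  for (std::size_t i = ops.size(); i-- > 0;) {
    const ToyOp &op = ops[i];
    // An op is live if it has side effects or defines a live value.
    bool live = op.hasSideEffects;
    for (int r : op.results)
      live = live || liveValues.count(r) > 0;
    if (!live) {
      deadOps.push_back(i);
      continue;
    }
    // All operands of a live op are live.
    for (int v : op.operands)
      liveValues.insert(v);
  }
  return deadOps;
}
```

For example, a side-effect-free `arith.subi` whose result is unused is reported as dead. So is an `arith.muli` whose only user is that `arith.subi`, because its result becomes dead transitively.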
Region-based ops are difficult to handle. The liveness analysis may determine
that an SSA value is dead. However, that does not mean that the value can
actually be removed: doing so may violate the region data flow (as modeled by
the `RegionBranchOpInterface`). As an example, consider the case where a region
branch terminator may dispatch to one of two region successors with the same
forwarded values. A successor input (block argument) can be erased only if the
corresponding input is dead in both successors.
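This constraint can be sketched as follows, assuming (hypothetically) that a terminator forwards one shared operand list to every successor, so that position `i` of the operand list corresponds to input `i` of each successor. `erasablePositions` is an illustrative helper, not part of the pass: a position may be dropped only if the input at that position is dead in every successor.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model: a region branch terminator forwards the same
// `numForwarded` operands to several successors. liveness[s][i] is true if
// the i-th successor input of successor s is live. Because all successors
// share one operand list, position i is erasable only if it is dead in
// every successor.
std::vector<bool> erasablePositions(
    const std::vector<std::vector<bool>> &liveness, std::size_t numForwarded) {
  std::vector<bool> erasable(numForwarded, true);
  for (const std::vector<bool> &successor : liveness) {
    assert(successor.size() == numForwarded &&
           "each successor must have one input per forwarded operand");
    for (std::size_t i = 0; i < numForwarded; ++i)
      if (successor[i])
        erasable[i] = false;  // Still needed by this successor.
  }
  return erasable;
}
```

With two successors whose live inputs are at positions 0 and 2, respectively, only position 1 can be erased. Erasing position 0 or 2 would break the successor that still uses it.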
Before this commit, the pass used complex logic to determine when it is safe
to erase an SSA value. That logic was broken. The new implementation does not
remove any block arguments or op results of region-based ops. Instead, operands
of region-based ops and region branch terminators are replaced with `ub.poison`
if all successor inputs that they are forwarded to are dead. This simplifies
the IR well enough for the canonicalizer to perform the remaining region
simplification (i.e., dropping block arguments etc.).
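The new decision rule can be modeled in a few lines. This is a hedged, standalone C++ sketch that mirrors the `operandToSuccessorInputs` / `deadOperandsPerOp` logic in the patch below using standard containers. `Operand`, `findPoisonableOperands`, and the string value names are hypothetical stand-ins for `OpOperand *`, `Value`, and `BitVector`. An operand is marked for replacement with `ub.poison` only if none of the successor inputs it may be forwarded to is live.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an OpOperand: (owner op name, operand number).
using Operand = std::pair<std::string, unsigned>;

// For every forwarded operand, `operandToSuccessorInputs` lists the successor
// inputs (block arguments / op results) that it may be forwarded to. The
// result maps each owner op to a per-operand "replace with poison" mask.
std::map<std::string, std::vector<bool>> findPoisonableOperands(
    const std::map<Operand, std::vector<std::string>> &operandToSuccessorInputs,
    const std::set<std::string> &liveValues,
    const std::map<std::string, unsigned> &numOperands) {
  std::map<std::string, std::vector<bool>> deadOperandsPerOp;
  for (const auto &[operand, inputs] : operandToSuccessorInputs) {
    // If any successor input is live, the operand must be kept.
    bool anyAlive = false;
    for (const std::string &input : inputs)
      anyAlive = anyAlive || liveValues.count(input) > 0;
    if (anyAlive)
      continue;
    const auto &[owner, operandNumber] = operand;
    // Create an all-false mask the first time an operand of `owner` is seen.
    std::vector<bool> &mask =
        deadOperandsPerOp.try_emplace(owner, numOperands.at(owner), false)
            .first->second;
    assert(operandNumber < mask.size() && "operand number out of range");
    mask[operandNumber] = true;
  }
  return deadOperandsPerOp;
}
```

As a variation of the `scf.while` example in the patch: if `scf.condition` forwards its operand 1 to `%arg0` and `%0` (both dead) and its operand 2 to `%arg1` and `%1` (where `%1` is live), only operand 1 is marked, giving the mask `{false, true, false}` for a three-operand `scf.condition`.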
RFC:
https://discourse.llvm.org/t/rfc-delegate-simplification-of-region-based-ops-from-remove-dead-values-to-canonicalizer/89194
Depends on #173560.
---
Patch is 38.42 KiB, truncated to 20.00 KiB below, full version:
https://github.com/llvm/llvm-project/pull/173505.diff
2 Files Affected:
- (modified) mlir/lib/Transforms/RemoveDeadValues.cpp (+184-314)
- (modified) mlir/test/Transforms/remove-dead-values.mlir (+109-44)
``````````diff
diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp
index 62ce5e0bbb77e..a347f335c9c1e 100644
--- a/mlir/lib/Transforms/RemoveDeadValues.cpp
+++ b/mlir/lib/Transforms/RemoveDeadValues.cpp
@@ -94,8 +94,11 @@ struct ResultsToCleanup {
struct OperandsToCleanup {
Operation *op;
BitVector nonLive;
- Operation *callee =
- nullptr; // Optional: For CallOpInterface ops, stores the callee function
+ // Optional: For CallOpInterface ops, stores the callee function.
+ Operation *callee = nullptr;
+ // Determines whether the operand should be replaced with a ub.poison result
+ // or erased entirely.
+ bool replaceWithPoison = false;
};
struct BlockArgsToCleanup {
@@ -199,9 +202,9 @@ static void collectNonLiveValues(DenseSet<Value> &nonLiveSet, ValueRange range,
}
}
-/// Drop the uses of the i-th result of `op` and then erase it iff toErase[i]
-/// is 1.
-static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
+/// Erase the i-th result of `op` iff toErase[i] is 1.
+static void eraseResults(RewriterBase &rewriter, Operation *op,
+ BitVector toErase) {
assert(op->getNumResults() == toErase.size() &&
"expected the number of results in `op` and the size of `toErase` to "
"be the same");
@@ -210,7 +213,6 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
for (OpResult result : op->getResults())
if (!toErase[result.getResultNumber()])
newResultTypes.push_back(result.getType());
- IRRewriter rewriter(op);
rewriter.setInsertionPointAfter(op);
OperationState state(op->getLoc(), op->getName().getStringRef(),
op->getOperands(), newResultTypes, op->getAttrs());
@@ -226,14 +228,12 @@ static void dropUsesAndEraseResults(Operation *op, BitVector toErase) {
unsigned indexOfNextNewCallOpResultToReplace = 0;
for (auto [index, result] : llvm::enumerate(op->getResults())) {
assert(result && "expected result to be non-null");
- if (toErase[index]) {
- result.dropAllUses();
- } else {
+ if (!toErase[index]) {
result.replaceAllUsesWith(
newOp->getResult(indexOfNextNewCallOpResultToReplace++));
}
}
- op->erase();
+ rewriter.eraseOp(op);
}
/// Convert a list of `Operand`s to a list of `OpOperand`s.
@@ -404,30 +404,20 @@ static void processFuncOp(FunctionOpInterface funcOp, Operation *module,
///
/// Scenario 1: If the operation has no memory effects and none of its results
/// are live:
-/// (1') Enqueue all its uses for deletion.
-/// (2') Enqueue the branch itself for deletion.
+/// 1.1. Enqueue all its uses for deletion.
+/// 1.2. Enqueue the branch itself for deletion.
///
/// Scenario 2: Otherwise:
-/// (1) Collect its unnecessary operands (operands forwarded to unnecessary
-/// results or arguments).
-/// (2) Process each of its regions.
-/// (3) Collect the uses of its unnecessary results (results forwarded from
-/// unnecessary operands
-/// or terminator operands).
-/// (4) Add these results to the deletion list.
-///
-/// Processing a region includes:
-/// (a) Collecting the uses of its unnecessary arguments (arguments forwarded
-/// from unnecessary operands
-/// or terminator operands).
-/// (b) Collecting these unnecessary arguments.
-/// (c) Collecting its unnecessary terminator operands (terminator operands
-/// forwarded to unnecessary results
-/// or arguments).
+/// 2.1. Collect block arguments and op results that we would like to keep,
+/// based on their liveness.
+/// 2.2. Find all operands that are forwarded to only dead region successor
+/// inputs. I.e., forwarded to block arguments / op results that we do
+/// not want to keep.
+/// 2.3. Enqueue all such operands for replacement with ub.poison.
///
-/// Value Flow Note: In this operation, values flow as follows:
-/// - From operands and terminator operands (successor operands)
-/// - To arguments and results (successor inputs).
+/// Note: In scenario 2, block arguments and op results are not removed.
+/// However, the IR is simplified such that canonicalization patterns can
+/// remove them later.
static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
RunLivenessAnalysis &la,
DenseSet<Value> &nonLiveSet,
@@ -441,284 +431,103 @@ static void processRegionBranchOp(RegionBranchOpInterface regionBranchOp,
// case, a non-forwarded operand of `regionBranchOp` could be live/non-live.
// It could never be live because of this op but its liveness could have been
// attributed to something else.
- // Do (1') and (2').
if (isMemoryEffectFree(regionBranchOp.getOperation()) &&
!hasLive(regionBranchOp->getResults(), nonLiveSet, la)) {
cl.operations.push_back(regionBranchOp.getOperation());
return;
}
- // Mark live results of `regionBranchOp` in `liveResults`.
- auto markLiveResults = [&](BitVector &liveResults) {
- liveResults = markLives(regionBranchOp->getResults(), nonLiveSet, la);
- };
-
- // Mark live arguments in the regions of `regionBranchOp` in `liveArgs`.
- auto markLiveArgs = [&](DenseMap<Region *, BitVector> &liveArgs) {
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- SmallVector<Value> arguments(region.front().getArguments());
- BitVector regionLiveArgs = markLives(arguments, nonLiveSet, la);
- liveArgs[&region] = regionLiveArgs;
+ // Compute values that are alive.
+ DenseSet<Value> valuesToKeep;
+ for (Value result : regionBranchOp->getResults()) {
+ if (hasLive(result, nonLiveSet, la))
+ valuesToKeep.insert(result);
+ }
+ for (Region &region : regionBranchOp->getRegions()) {
+ if (region.empty())
+ continue;
+ for (Value arg : region.front().getArguments()) {
+ if (hasLive(arg, nonLiveSet, la))
+ valuesToKeep.insert(arg);
}
- };
+ }
- // Return the successors of `region` if the latter is not null. Else return
- // the successors of `regionBranchOp`.
- auto getSuccessors = [&](RegionBranchPoint point) {
+ // Mapping from operands to forwarded successor inputs. An operand can be
+ // forwarded to multiple successors.
+ DenseMap<OpOperand *, SmallVector<Value>> operandToSuccessorInputs;
+ auto helper = [&](RegionBranchPoint point) {
SmallVector<RegionSuccessor> successors;
regionBranchOp.getSuccessorRegions(point, successors);
- return successors;
- };
-
- // Return the operands of `terminator` that are forwarded to `successor` if
- // the former is not null. Else return the operands of `regionBranchOp`
- // forwarded to `successor`.
- auto getForwardedOpOperands = [&](const RegionSuccessor &successor,
- Operation *terminator = nullptr) {
- OperandRange operands =
- terminator ? cast<RegionBranchTerminatorOpInterface>(terminator)
- .getSuccessorOperands(successor)
- : regionBranchOp.getEntrySuccessorOperands(successor);
- SmallVector<OpOperand *> opOperands = operandsToOpOperands(operands);
- return opOperands;
- };
-
- // Mark the non-forwarded operands of `regionBranchOp` in
- // `nonForwardedOperands`.
- auto markNonForwardedOperands = [&](BitVector &nonForwardedOperands) {
- nonForwardedOperands.resize(regionBranchOp->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- for (OpOperand *opOperand : getForwardedOpOperands(successor))
- nonForwardedOperands.reset(opOperand->getOperandNumber());
+ for (const RegionSuccessor &successor : successors) {
+ // Handle branch from point --> successor.
+ ValueRange argsOrResults = successor.getSuccessorInputs();
+ OperandRange operands =
+ point.isParent() ? regionBranchOp.getEntrySuccessorOperands(successor)
+ : cast<RegionBranchTerminatorOpInterface>(
+ point.getTerminatorPredecessorOrNull())
+ .getSuccessorOperands(successor);
+ assert(
+ argsOrResults.size() == operands.size() &&
+ "expected the same number of successor inputs as forwarded operands");
+
+ for (auto [opOperand, input] :
+ llvm::zip_equal(operandsToOpOperands(operands), argsOrResults)) {
+ operandToSuccessorInputs[opOperand].push_back(input);
+ }
}
};
- // Mark the non-forwarded terminator operands of the various regions of
- // `regionBranchOp` in `nonForwardedRets`.
- auto markNonForwardedReturnValues =
- [&](DenseMap<Operation *, BitVector> &nonForwardedRets) {
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- // TODO: this isn't correct in face of multiple terminators.
- Operation *terminator = region.front().getTerminator();
- nonForwardedRets[terminator] =
- BitVector(terminator->getNumOperands(), true);
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- for (OpOperand *opOperand :
- getForwardedOpOperands(successor, terminator))
- nonForwardedRets[terminator].reset(opOperand->getOperandNumber());
- }
- }
- };
-
- // Update `valuesToKeep` (which is expected to correspond to operands or
- // terminator operands) based on `resultsToKeep` and `argsToKeep`, given
- // `region`. When `valuesToKeep` correspond to operands, `region` is null.
- // Else, `region` is the parent region of the terminator.
- auto updateOperandsOrTerminatorOperandsToKeep =
- [&](BitVector &valuesToKeep, BitVector &resultsToKeep,
- DenseMap<Region *, BitVector> &argsToKeep, Region *region = nullptr) {
- Operation *terminator =
- region ? region->front().getTerminator() : nullptr;
- RegionBranchPoint point =
- terminator
- ? RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator))
- : RegionBranchPoint::parent();
-
- for (const RegionSuccessor &successor : getSuccessors(point)) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- size_t operandNum = opOperand->getOperandNumber();
- bool updateBasedOn =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- valuesToKeep[operandNum] = valuesToKeep[operandNum] | updateBasedOn;
- }
- }
- };
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep` and
- // `terminatorOperandsToKeep`. Store true in `resultsOrArgsToKeepChanged` if a
- // value is modified, else, false.
- auto recomputeResultsAndArgsToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep,
- bool &resultsOrArgsToKeepChanged) {
- resultsOrArgsToKeepChanged = false;
-
- // Recompute `resultsToKeep` and `argsToKeep` based on `operandsToKeep`.
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint::parent())) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- operandsToKeep[opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
-
- // Recompute `resultsToKeep` and `argsToKeep` based on
- // `terminatorOperandsToKeep`.
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- Operation *terminator = region.front().getTerminator();
- for (const RegionSuccessor &successor :
- getSuccessors(RegionBranchPoint(
- cast<RegionBranchTerminatorOpInterface>(terminator)))) {
- Region *successorRegion = successor.getSuccessor();
- for (auto [opOperand, input] :
- llvm::zip(getForwardedOpOperands(successor, terminator),
- successor.getSuccessorInputs())) {
- bool recomputeBasedOn =
- terminatorOperandsToKeep[region.back().getTerminator()]
- [opOperand->getOperandNumber()];
- bool toRecompute =
- successorRegion
- ? argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()]
- : resultsToKeep[cast<OpResult>(input).getResultNumber()];
- if (!toRecompute && recomputeBasedOn)
- resultsOrArgsToKeepChanged = true;
- if (successorRegion) {
- argsToKeep[successorRegion][cast<BlockArgument>(input)
- .getArgNumber()] =
- argsToKeep[successorRegion]
- [cast<BlockArgument>(input).getArgNumber()] |
- recomputeBasedOn;
- } else {
- resultsToKeep[cast<OpResult>(input).getResultNumber()] =
- resultsToKeep[cast<OpResult>(input).getResultNumber()] |
- recomputeBasedOn;
- }
- }
- }
- }
- };
-
- // Mark the values that we want to keep in `resultsToKeep`, `argsToKeep`,
- // `operandsToKeep`, and `terminatorOperandsToKeep`.
- auto markValuesToKeep =
- [&](BitVector &resultsToKeep, DenseMap<Region *, BitVector> &argsToKeep,
- BitVector &operandsToKeep,
- DenseMap<Operation *, BitVector> &terminatorOperandsToKeep) {
- bool resultsOrArgsToKeepChanged = true;
- // We keep updating and recomputing the values until we reach a point
- // where they stop changing.
- while (resultsOrArgsToKeepChanged) {
- // Update the operands that need to be kept.
- updateOperandsOrTerminatorOperandsToKeep(operandsToKeep,
- resultsToKeep, argsToKeep);
-
- // Update the terminator operands that need to be kept.
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
- continue;
- updateOperandsOrTerminatorOperandsToKeep(
- terminatorOperandsToKeep[region.back().getTerminator()],
- resultsToKeep, argsToKeep, &region);
- }
-
- // Recompute the results and arguments that need to be kept.
- recomputeResultsAndArgsToKeep(
- resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep, resultsOrArgsToKeepChanged);
- }
- };
-
- // Scenario 2.
- // At this point, we know that every non-forwarded operand of `regionBranchOp`
- // is live.
-
- // Stores the results of `regionBranchOp` that we want to keep.
- BitVector resultsToKeep;
- // Stores the mapping from regions of `regionBranchOp` to their arguments that
- // we want to keep.
- DenseMap<Region *, BitVector> argsToKeep;
- // Stores the operands of `regionBranchOp` that we want to keep.
- BitVector operandsToKeep;
- // Stores the mapping from region terminators in `regionBranchOp` to their
- // operands that we want to keep.
- DenseMap<Operation *, BitVector> terminatorOperandsToKeep;
-
- // Initializing the above variables...
-
- // The live results of `regionBranchOp` definitely need to be kept.
- markLiveResults(resultsToKeep);
- // Similarly, the live arguments of the regions in `regionBranchOp` definitely
- // need to be kept.
- markLiveArgs(argsToKeep);
- // The non-forwarded operands of `regionBranchOp` definitely need to be kept.
- // A live forwarded operand can be removed but no non-forwarded operand can be
- // removed since it "controls" the flow of data in this control flow op.
- markNonForwardedOperands(operandsToKeep);
- // Similarly, the non-forwarded terminator operands of the regions in
- // `regionBranchOp` definitely need to be kept.
- markNonForwardedReturnValues(terminatorOperandsToKeep);
-
- // Mark the values (results, arguments, operands, and terminator operands)
- // that we want to keep.
- markValuesToKeep(resultsToKeep, argsToKeep, operandsToKeep,
- terminatorOperandsToKeep);
-
- // Do (1).
- cl.operands.push_back({regionBranchOp, operandsToKeep.flip()});
-
- // Do (2.a) and (2.b).
+ // Example:
+ //
+ // %0 = scf.while : () -> i32 {
+ // scf.condition(...) %forwarded_value : i32
+ // } do {
+ // ^bb0(%arg0: i32):
+ // scf.yield
+ // }
+ // // No uses of %0.
+ //
+ // In the above example, %forwarded_value is forwarded to %arg0 and %0. Both
+ // %arg0 and %0 are dead, so %forwarded_value can be replaced with a
+ // ub.poison result.
+ //
+ // operandToSuccessorInputs[%forwarded_value] = {%arg0, %0}
+ //
+ helper(RegionBranchPoint::parent());
for (Region &region : regionBranchOp->getRegions()) {
if (region.empty())
continue;
- BitVector argsToRemove = argsToKeep[&region].flip();
- cl.blocks.push_back({&region.front(), argsToRemove});
- collectNonLiveValues(nonLiveSet, region.front().getArguments(),
- argsToRemove);
+ helper(RegionBranchPoint(cast<RegionBranchTerminatorOpInterface>(
+ region.front().getTerminator())));
}
- // Do (2.c).
- for (Region &region : regionBranchOp->getRegions()) {
- if (region.empty())
+ DenseMap<Operation *, BitVector> deadOperandsPerOp;
+ for (auto [opOperand, successorInputs] : operandToSuccessorInputs) {
+ // If one of the successor inputs is live, the respective operand must be
+ // kept.
+ bool anyAlive = llvm::any_of(successorInputs, [&](Value input) {
+ return valuesToKeep.contains(input);
+ });
+ if (anyAlive)
continue;
- Operation *terminator = region.front().getTerminator();
- cl.operands.push_back(
- {terminator, terminatorOperandsToKeep[terminator].flip()});
+
+ // All successor inputs are dead: ub.poison can be passed as operand.
+ // Create an entry in `deadOperandsPerOp` (initialized to "false", i.e.,
+ // no "dead" op operands) if it's the first time that we are seeing an op
+ // operand for this op. Otherwise, just take the existing bit vector from
+ // the map.
+ BitVector &deadOperands =
+ deadOperandsPerOp
+ .try_emplace(opOperand->getOwner(),
+ opOperand->getOwner()->getNumOperands(), false)
+ .first->second;
+ deadOperands.set(opOperand->getOperandNumber());
}
- // Do (3) and (4).
- BitVector resultsToRemove = ...
[truncated]
``````````
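The result bookkeeping in the new `eraseResults` (rebuild the op with only the kept result types, then redirect each kept result to the next result of the new op via `indexOfNextNewCallOpResultToReplace`) amounts to an old-to-new index map. Below is a minimal sketch with a hypothetical `remapResults` helper, using `std::vector<bool>` in place of `BitVector`. It illustrates the indexing only and does not touch any MLIR API.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical model of the index bookkeeping in eraseResults: given a mask
// of results to erase, return for each old result the index of the matching
// result on the rebuilt op, or std::nullopt if that result is erased.
std::vector<std::optional<unsigned>> remapResults(
    const std::vector<bool> &toErase) {
  std::vector<std::optional<unsigned>> newIndex(toErase.size());
  unsigned nextNewResult = 0;  // Plays the role of the running counter.
  for (std::size_t i = 0; i < toErase.size(); ++i)
    if (!toErase[i])
      newIndex[i] = nextNewResult++;
  return newIndex;
}
```

A running counter suffices because the rebuilt op keeps the surviving results in their original relative order.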
</details>
https://github.com/llvm/llvm-project/pull/173505
_______________________________________________
llvm-branch-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits