| Issue | 86765 |
|---|---|
| Summary | [MLIR] Top-down applyPatternsAndFoldGreedily fails to remove dead code that bottom-up applyPatternsAndFoldGreedily removes |
| Labels | mlir |
| Assignees | |
| Reporter | mlevesquedion |
Consider the following IR:
```mlir
func.func @foo(%arg0: i1, %arg1: i32) {
  %0 = arith.constant 0 : i32
  %1 = arith.bitcast %arg1 : i32 to i32
  %if = scf.if %arg0 -> (i32) {
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  %dead_leaf = arith.addi %if, %if : i32
  func.return
}
```
When applying a single iteration of `applyPatternsAndFoldGreedily` with `config.useTopDownTraversal = true`, we get:
```mlir
func.func @foo(%arg0: i1, %arg1: i32) {
  %c0_i32 = arith.constant 0 : i32
  return
}
```
That's pretty good, but `%c0_i32` is clearly dead and should probably have been removed. Indeed, with a bottom-up traversal (`config.useTopDownTraversal = false`), we get:
```mlir
func.func @foo(%arg0: i1, %arg1: i32) {
  return
}
```
Interestingly, if we use `%dead_leaf = arith.bitcast %if : i32 to i32` instead of `%dead_leaf = arith.addi %if, %if : i32`, both traversal orders remove all of the dead code.
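For reference, this is the input IR with that one substitution made; everything else is unchanged from the example at the top of this issue.

```mlir
func.func @foo(%arg0: i1, %arg1: i32) {
  %0 = arith.constant 0 : i32
  %1 = arith.bitcast %arg1 : i32 to i32
  %if = scf.if %arg0 -> (i32) {
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  // Unlike the arith.addi variant, this dead leaf uses %if only once.
  // With this leaf, both traversal orders remove all of the dead code.
  %dead_leaf = arith.bitcast %if : i32 to i32
  func.return
}
```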
Note also that this pattern can be chained. In that case, a bottom-up traversal still removes all dead values, whereas a top-down traversal needs one iteration per repetition of the pattern to remove them all. This is in fact how I ran into the issue in the first place: I was working with a fairly large (generated) function where a top-down traversal took 10 iterations to fully remove the dead code, whereas a bottom-up traversal needed only one.
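The chained IR itself isn't included here. As a purely hypothetical illustration, one way to repeat the pattern is to feed each `arith.addi` into the `else` branch of the next `scf.if`, so that only the final `arith.addi` is a dead leaf at the start. The function name `@chained` and the value names are made up, and whether this exact shape reproduces the iteration counts described above is an assumption that has not been checked.

```mlir
// Hypothetical chained form: two repetitions of the original pattern.
func.func @chained(%arg0: i1, %arg1: i32) {
  %c0 = arith.constant 0 : i32
  %if0 = scf.if %arg0 -> (i32) {
    scf.yield %c0 : i32
  } else {
    scf.yield %arg1 : i32
  }
  // Not dead yet: used by the else branch of %if1 below.
  %sum0 = arith.addi %if0, %if0 : i32
  %c1 = arith.constant 1 : i32
  %if1 = scf.if %arg0 -> (i32) {
    scf.yield %c1 : i32
  } else {
    scf.yield %sum0 : i32
  }
  // The only dead leaf in the initial IR.
  %dead_leaf = arith.addi %if1, %if1 : i32
  func.return
}
```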
For additional context, see this thread in the MLIR Discord: https://discordapp.com/channels/636084430946959380/1221989192826097774
In order to produce the above examples, I used `mlir-opt --canonicalize`, with the following patch applied:
```diff
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee..200220a297ab 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -43,24 +43,16 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
     // Set the config from possible pass options set in the meantime.
-    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.useTopDownTraversal = true;
     config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
+    config.maxIterations = 1;
     config.maxNumRewrites = maxNumRewrites;
-    RewritePatternSet owningPatterns(context);
-    for (auto *dialect : context->getLoadedDialects())
-      dialect->getCanonicalizationPatterns(owningPatterns);
-    for (RegisteredOperationName op : context->getRegisteredOperations())
-      op.getCanonicalizationPatterns(owningPatterns, context);
-
-    patterns = std::make_shared<FrozenRewritePatternSet>(
-        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
 
   void runOnOperation() override {
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), RewritePatternSet(&getContext()), config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
```
(The same results can be obtained by writing a new pass that doesn't use any patterns and merely calls `applyPatternsAndFoldGreedily` with the desired `config`. The above is just an easy way to isolate the behavior of interest.)
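The expected behavior could be pinned down with a lit/FileCheck test along the following lines. This is a sketch that assumes an `mlir-opt` built with the patch above, so that `--canonicalize` runs a single top-down iteration with no patterns. With the behavior described in this issue, the `CHECK-NOT: arith.constant` line is expected to fail; it should pass once the top-down traversal removes the dead constant the way the bottom-up traversal does.

```mlir
// RUN: mlir-opt %s --canonicalize | FileCheck %s
// Assumes mlir-opt was built with the Canonicalizer.cpp patch above
// (top-down traversal, maxIterations = 1, empty pattern set).

// CHECK-LABEL: func.func @foo
// CHECK-NOT:     arith.constant
// CHECK-NOT:     scf.if
// CHECK:         return
func.func @foo(%arg0: i1, %arg1: i32) {
  %0 = arith.constant 0 : i32
  %1 = arith.bitcast %arg1 : i32 to i32
  %if = scf.if %arg0 -> (i32) {
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  %dead_leaf = arith.addi %if, %if : i32
  func.return
}
```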
This was observed at `HEAD` (specifically commit `4d03a9ecc697a11f0edd3c31440a7cae3398e24a`) at the time of writing.
_______________________________________________
llvm-bugs mailing list
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-bugs