Issue 86765
Summary [MLIR] Top-down applyPatternsAndFoldGreedily fails to remove dead code that bottom-up applyPatternsAndFoldGreedily removes
Labels mlir
Assignees
Reporter mlevesquedion
Consider the following IR:

```
func.func @foo(%arg0: i1, %arg1: i32) {
  %0 = arith.constant 0 : i32
  %1 = arith.bitcast %arg1 : i32 to i32
  %if = scf.if %arg0 -> (i32) {
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  %dead_leaf = arith.addi %if, %if : i32
  func.return
}
```

When applying a single iteration of `applyPatternsAndFoldGreedily` with `config.useTopDownTraversal = true`, we get:

```
func.func @foo(%arg0: i1, %arg1: i32) {
  %c0_i32 = arith.constant 0 : i32
  return
}
```

That's pretty good, but `%c0_i32` is clearly dead, and probably should have been removed. Indeed, when doing a bottom-up traversal (`config.useTopDownTraversal = false`) we get:

```
func.func @foo(%arg0: i1, %arg1: i32) {
  return
}
```

Interestingly, if instead of `%dead_leaf = arith.addi %if, %if : i32`, we use `%dead_leaf = arith.bitcast %if : i32 to i32`, both types of traversal manage to remove all the dead code.

Note also that this pattern can be chained. In that case a bottom-up traversal still removes all dead values in a single iteration, whereas a top-down traversal needs one iteration per repetition of the pattern to remove them all. This is in fact how I ran into this in the first place: I was working with a fairly large (generated) function where it took 10 iterations to fully remove all the dead code with a top-down traversal, whereas a bottom-up traversal required a single iteration.
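
As a rough illustration of what "chained" could look like, the sketch below repeats the `scf.if` + `arith.addi` shape twice, with the second `scf.if` yielding the result of the first `arith.addi`. This is an assumed reconstruction, not the generated function mentioned above: the value names `%if0`, `%add0`, and `%if1` are made up for the example, and the number of top-down iterations it actually needs has not been checked.

```
// Hypothetical chained variant of @foo (illustration only).
func.func @foo_chained(%arg0: i1, %arg1: i32) {
  %0 = arith.constant 0 : i32
  %1 = arith.bitcast %arg1 : i32 to i32
  // First repetition: same shape as the original example.
  %if0 = scf.if %arg0 -> (i32) {
    scf.yield %0 : i32
  } else {
    scf.yield %1 : i32
  }
  %add0 = arith.addi %if0, %if0 : i32
  // Second repetition: consumes the result of the first one.
  %if1 = scf.if %arg0 -> (i32) {
    scf.yield %add0 : i32
  } else {
    scf.yield %1 : i32
  }
  %dead_leaf = arith.addi %if1, %if1 : i32
  func.return
}
```

Every value here is transitively dead, so a complete dead-code cleanup would leave only the `return`.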

For additional context, see this thread in the MLIR Discord: https://discordapp.com/channels/636084430946959380/1221989192826097774

In order to produce the above examples, I used `mlir-opt --canonicalize`, with the following patch applied:

```
diff --git a/mlir/lib/Transforms/Canonicalizer.cpp b/mlir/lib/Transforms/Canonicalizer.cpp
index d50019bd6aee..200220a297ab 100644
--- a/mlir/lib/Transforms/Canonicalizer.cpp
+++ b/mlir/lib/Transforms/Canonicalizer.cpp
@@ -43,24 +43,16 @@ struct Canonicalizer : public impl::CanonicalizerBase<Canonicalizer> {
   /// execution.
   LogicalResult initialize(MLIRContext *context) override {
     // Set the config from possible pass options set in the meantime.
-    config.useTopDownTraversal = topDownProcessingEnabled;
+    config.useTopDownTraversal = true;
     config.enableRegionSimplification = enableRegionSimplification;
-    config.maxIterations = maxIterations;
+    config.maxIterations = 1;
     config.maxNumRewrites = maxNumRewrites;

-    RewritePatternSet owningPatterns(context);
-    for (auto *dialect : context->getLoadedDialects())
-      dialect->getCanonicalizationPatterns(owningPatterns);
-    for (RegisteredOperationName op : context->getRegisteredOperations())
-      op.getCanonicalizationPatterns(owningPatterns, context);
-
-    patterns = std::make_shared<FrozenRewritePatternSet>(
-        std::move(owningPatterns), disabledPatterns, enabledPatterns);
     return success();
   }
   void runOnOperation() override {
     LogicalResult converged =
-        applyPatternsAndFoldGreedily(getOperation(), *patterns, config);
+        applyPatternsAndFoldGreedily(getOperation(), RewritePatternSet(&getContext()), config);
     // Canonicalization is best-effort. Non-convergence is not a pass failure.
     if (testConvergence && failed(converged))
       signalPassFailure();
 ```
 
 (The same results can be obtained by writing a new pass that doesn't use any patterns and merely calls `applyPatternsAndFoldGreedily` with the desired `config`. The above is just an easy way to isolate the behavior of interest.)
 
This was observed at `HEAD`, specifically `4d03a9ecc697a11f0edd3c31440a7cae3398e24a`, at the time of writing.