https://github.com/joker-eph updated https://github.com/llvm/llvm-project/pull/154038
>From a33b1af2c37a62191ba346c43ec238e7e653d29e Mon Sep 17 00:00:00 2001 From: Mehdi Amini <joker....@gmail.com> Date: Sun, 17 Aug 2025 14:24:35 -0700 Subject: [PATCH] [MLIR] Stop visiting unreachable blocks in the walkAndApplyPatterns driver This is similar to the fix to the greedy driver in #153957, except that instead of removing unreachable code, we just ignore it. Operations like: %add = arith.addi %add, %add : i64 are legal in unreachable code. Unfortunately many patterns would be unsafe to apply on such IR and can lead to crashes or infinite loops. --- .../Transforms/WalkPatternRewriteDriver.h | 2 ++ .../Utils/WalkPatternRewriteDriver.cpp | 27 +++++++++++++++++++ .../IR/test-walk-pattern-rewrite-driver.mlir | 20 ++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h index 6d62ae3dd43dc..7d5c1d5cebb26 100644 --- a/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/WalkPatternRewriteDriver.h @@ -27,6 +27,8 @@ namespace mlir { /// This is intended as the simplest and most lightweight pattern rewriter in /// cases when a simple walk gets the job done. /// +/// The driver will skip unreachable blocks. +/// /// Note: Does not apply patterns to the given operation itself. void walkAndApplyPatterns(Operation *op, const FrozenRewritePatternSet &patterns, diff --git a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp index 2111e29120567..1382550e0f7e6 100644 --- a/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/WalkPatternRewriteDriver.cpp @@ -27,6 +27,26 @@ namespace mlir { +// Find all reachable blocks in the region and add them to the reachableBlocks +// set.
+static void findReachableBlocks(Region &region,
+                                DenseSet<Block *> &reachableBlocks) {
+  // An empty region has no entry block; region.front() would be invalid.
+  if (region.empty())
+    return;
+  Block *entryBlock = &region.front();
+  reachableBlocks.insert(entryBlock);
+  // Traverse the CFG and add every block reachable from the entry block.
+  SmallVector<Block *> worklist({entryBlock});
+  while (!worklist.empty()) {
+    Block *block = worklist.pop_back_val();
+    // insert().second is true only the first time a block is reached.
+    for (Block *successor : block->back().getSuccessors())
+      if (reachableBlocks.insert(successor).second)
+        worklist.push_back(successor);
+  }
+}
+
 namespace {
 struct WalkAndApplyPatternsAction final
     : tracing::ActionImpl<WalkAndApplyPatternsAction> {
@@ -98,6 +118,8 @@ void walkAndApplyPatterns(Operation *op,
       regionIt = region->begin();
       if (regionIt != region->end())
         blockIt = regionIt->begin();
+      if (!llvm::hasSingleElement(*region))
+        findReachableBlocks(*region, reachableBlocks);
     }
     // Advance the iterator to the next reachable operation.
     void advance() {
@@ -105,6 +127,9 @@ void walkAndApplyPatterns(Operation *op,
       hasVisitedRegions = false;
       if (blockIt == regionIt->end()) {
         ++regionIt;
+        while (regionIt != region->end() &&
+               !reachableBlocks.contains(&*regionIt))
+          ++regionIt;
         if (regionIt != region->end())
           blockIt = regionIt->begin();
         return;
@@ -121,6 +146,8 @@ void walkAndApplyPatterns(Operation *op,
     Region::iterator regionIt;
     // The Operation currently being iterated over.
     Block::iterator blockIt;
+    // Reachable blocks, computed once at construction (multi-block only).
+    DenseSet<Block *> reachableBlocks;
     // Whether we've visited the nested regions of the current op already.
bool hasVisitedRegions = false; }; diff --git a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir index c75c478ec3734..c3063416b0360 100644 --- a/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir +++ b/mlir/test/IR/test-walk-pattern-rewrite-driver.mlir @@ -119,3 +119,23 @@ func.func @erase_nested_block() -> i32 { }): () -> (i32) return %a : i32 } + + +// CHECK-LABEL: func.func @unreachable_replace_with_new_op +// CHECK: "test.new_op" +// CHECK: "test.replace_with_new_op" +// CHECK-SAME: unreachable +// CHECK: "test.new_op" +func.func @unreachable_replace_with_new_op() { + "test.br"()[^bb1] : () -> () +^bb1: + %a = "test.replace_with_new_op"() : () -> (i32) + "test.br"()[^end] : () -> () // Jumps over ^unreachable; ^end must still be visited. +^unreachable: + %b = "test.replace_with_new_op"() {test.unreachable} : () -> (i32) + return +^end: + %c = "test.replace_with_new_op"() : () -> (i32) + return +} + _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits