llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-scf Author: Matthias Springer (matthias-springer) <details> <summary>Changes</summary> Add a canonicalization pattern to fold unused `scf.index_switch` results. --- Full diff: https://github.com/llvm/llvm-project/pull/173560.diff 2 Files Affected: - (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+51-1) - (modified) mlir/test/Dialect/SCF/canonicalize.mlir (+31) ``````````diff diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 652414f6cbe54..0a123112cf68f 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -4711,9 +4711,59 @@ struct FoldConstantCase : OpRewritePattern<scf::IndexSwitchOp> { } }; +/// Canonicalization pattern that folds away dead results of +/// "scf.index_switch" ops. +struct FoldUnusedIndexSwitchResults : OpRewritePattern<IndexSwitchOp> { + using OpRewritePattern<IndexSwitchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(IndexSwitchOp op, + PatternRewriter &rewriter) const override { + // Find dead results. + BitVector deadResults(op.getNumResults(), false); + SmallVector<Type> newResultTypes; + for (auto [idx, result] : llvm::enumerate(op.getResults())) { + if (!result.use_empty()) { + newResultTypes.push_back(result.getType()); + } else { + deadResults[idx] = true; + } + } + if (!deadResults.any()) + return rewriter.notifyMatchFailure(op, "no dead results to fold"); + + // Create new op without dead results and inline case regions. + auto newOp = IndexSwitchOp::create(rewriter, op.getLoc(), newResultTypes, + op.getArg(), op.getCases(), + op.getCaseRegions().size()); + auto inlineCaseRegion = [&](Region &oldRegion, Region &newRegion) { + rewriter.inlineRegionBefore(oldRegion, newRegion, newRegion.begin()); + // Remove respective operands from yield op. 
+ Operation *terminator = newRegion.front().getTerminator(); + assert(isa<YieldOp>(terminator) && "expected yield op"); + rewriter.modifyOpInPlace( + terminator, [&]() { terminator->eraseOperands(deadResults); }); + }; + for (auto [oldRegion, newRegion] : + llvm::zip_equal(op.getCaseRegions(), newOp.getCaseRegions())) + inlineCaseRegion(oldRegion, newRegion); + inlineCaseRegion(op.getDefaultRegion(), newOp.getDefaultRegion()); + + // Replace op with new op. + SmallVector<Value> newResults(op.getNumResults(), Value()); + unsigned nextNewResult = 0; + for (unsigned idx = 0; idx < op.getNumResults(); ++idx) { + if (deadResults[idx]) + continue; + newResults[idx] = newOp.getResult(nextNewResult++); + } + rewriter.replaceOp(op, newResults); + return success(); + } +}; + void IndexSwitchOp::getCanonicalizationPatterns(RewritePatternSet &results, MLIRContext *context) { - results.add<FoldConstantCase>(context); + results.add<FoldConstantCase, FoldUnusedIndexSwitchResults>(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index ac590fc0c47b9..d5d0aee3bbe25 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -2171,3 +2171,34 @@ func.func @scf_for_all_step_size_0() { } return } + +// ----- + +// CHECK-LABEL: func @dead_index_switch_result( +// CHECK-SAME: %[[arg0:.*]]: index +// CHECK-DAG: %[[c10:.*]] = arith.constant 10 +// CHECK-DAG: %[[c11:.*]] = arith.constant 11 +// CHECK: %[[switch:.*]] = scf.index_switch %[[arg0]] -> index +// CHECK: case 1 { +// CHECK: memref.store %[[c10]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: default { +// CHECK: memref.store %[[c11]] +// CHECK: scf.yield %[[arg0]] : index +// CHECK: } +// CHECK: return %[[switch]] +func.func @dead_index_switch_result(%arg0 : index, %arg1 : memref<i32>) -> index { + %non_live, %live = scf.index_switch 
%arg0 -> i32, index + case 1 { + %c10 = arith.constant 10 : i32 + memref.store %c10, %arg1[] : memref<i32> + scf.yield %c10, %arg0 : i32, index + } + default { + %c11 = arith.constant 11 : i32 + memref.store %c11, %arg1[] : memref<i32> + scf.yield %c11, %arg0 : i32, index + } + return %live : index +} `````````` </details> https://github.com/llvm/llvm-project/pull/173560 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
