================ @@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { } }; +class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { +public: + using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; + + inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, + cir::YieldOp yieldOp, + mlir::Block *destination) const { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), + destination); + } + + // Return the new defaultDestination block. + Block *condBrToRangeDestination(cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + const APInt &lowerBound, + const APInt &upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + mlir::Block *resBlock = rewriter.createBlock(defaultDestination); + cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); + cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); + + cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + + cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::BinOp diffValue = + rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. 
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( + op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, + uDiffValue, uRangeLength); + rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, + defaultDestination); + return resBlock; + } + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector<CaseOp> cases; + op.collectCases(cases); + + // Empty switch statement: just erase it. + if (cases.empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + // Create exit block from the next node of cir.switch op. + mlir::Block *exitBlock = rewriter.splitBlock( + rewriter.getBlock(), op->getNextNode()->getIterator()); + + // We lower cir.switch op in the following process: + // 1. Inline the region from the switch op after switch op. + // 2. Traverse each cir.case op: + // a. Record the entry block, block arguments and condition for every + // case. b. Inline the case region after the case op. + // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the + // recorded block and conditions. + + // inline everything from switch body between the switch op and the exit + // block. + { + cir::YieldOp switchYield = nullptr; + // Clear switch operation. + for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) ---------------- andykaylor wrote:
```suggestion
      for (mlir::Block &block : llvm::make_early_inc_range(op.getBody().getBlocks()))
```

https://github.com/llvm/llvm-project/pull/139154
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits