================ @@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { } }; +class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { +public: + using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; + + inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, + cir::YieldOp yieldOp, + mlir::Block *destination) const { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), + destination); + } + + // Return the new defaultDestination block. + Block *condBrToRangeDestination(cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + const APInt &lowerBound, + const APInt &upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + mlir::Block *resBlock = rewriter.createBlock(defaultDestination); + cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); + cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); + + cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + + cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::BinOp diffValue = + rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. + cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( + op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, + uDiffValue, uRangeLength); + rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, + defaultDestination); + return resBlock; + } + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector<CaseOp> cases; + op.collectCases(cases); + + // Empty switch statement: just erase it. + if (cases.empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + // Create exit block from the next node of cir.switch op. + mlir::Block *exitBlock = rewriter.splitBlock( + rewriter.getBlock(), op->getNextNode()->getIterator()); + + // We lower cir.switch op in the following process: + // 1. Inline the region from the switch op after switch op. + // 2. Traverse each cir.case op: + // a. Record the entry block, block arguments and condition for every + // case. b. Inline the case region after the case op. + // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the + // recorded block and conditions. + + // inline everything from switch body between the switch op and the exit + // block. + { + cir::YieldOp switchYield = nullptr; + // Clear switch operation. + for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) + if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) + switchYield = yieldOp; + + assert(!op.getBody().empty()); + mlir::Block *originalBlock = op->getBlock(); + mlir::Block *swopBlock = + rewriter.splitBlock(originalBlock, op->getIterator()); + rewriter.inlineRegionBefore(op.getBody(), exitBlock); + + if (switchYield) + rewriteYieldOp(rewriter, switchYield, exitBlock); + + rewriter.setInsertionPointToEnd(originalBlock); + rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + } + + // Allocate required data structures (disconsider default case in + // vectors). + llvm::SmallVector<mlir::APInt, 8> caseValues; + llvm::SmallVector<mlir::Block *, 8> caseDestinations; + llvm::SmallVector<mlir::ValueRange, 8> caseOperands; + + llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; + llvm::SmallVector<mlir::Block *> rangeDestinations; + llvm::SmallVector<mlir::ValueRange> rangeOperands; + + // Initialize default case as optional. + mlir::Block *defaultDestination = exitBlock; + mlir::ValueRange defaultOperands = exitBlock->getArguments(); + + // Digest the case statements values and bodies. + for (auto caseOp : cases) { + mlir::Region ®ion = caseOp.getCaseRegion(); + + // Found default case: save destination and operands. + switch (caseOp.getKind()) { + case cir::CaseOpKind::Default: + defaultDestination = ®ion.front(); + defaultOperands = defaultDestination->getArguments(); + break; + case cir::CaseOpKind::Range: + assert(caseOp.getValue().size() == 2 && + "Case range should have 2 case value"); + rangeValues.push_back( + {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(), + cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()}); + rangeDestinations.push_back(®ion.front()); + rangeOperands.push_back(rangeDestinations.back()->getArguments()); + break; + case cir::CaseOpKind::Anyof: + case cir::CaseOpKind::Equal: + // AnyOf cases kind can have multiple values, hence the loop below. + for (auto &value : caseOp.getValue()) { + caseValues.push_back(cast<cir::IntAttr>(value).getValue()); + caseDestinations.push_back(®ion.front()); + caseOperands.push_back(caseDestinations.back()->getArguments()); + } + break; + } + + // Handle break statements. + walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>( + region, [&](mlir::Operation *op) { + if (!isa<cir::BreakOp>(op)) + return mlir::WalkResult::advance(); + + lowerTerminator(op, exitBlock, rewriter); + return mlir::WalkResult::skip(); + }); + + // Track fallthrough in cases. + for (auto &blk : region.getBlocks()) { + if (blk.getNumSuccessors()) + continue; + + if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) { + mlir::Operation *nextOp = caseOp->getNextNode(); + assert(nextOp && "caseOp is not expected to be the last op"); + mlir::Block *oldBlock = nextOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, nextOp->getIterator()); + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), + newBlock); + rewriteYieldOp(rewriter, yieldOp, newBlock); + } + } + + mlir::Block *oldBlock = caseOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, caseOp->getIterator()); + + mlir::Block &entryBlock = caseOp.getCaseRegion().front(); + rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock); + + // Create a branch to the entry of the inlined region. + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); + } + + // Remove all cases since we've inlined the regions. + for (auto caseOp : cases) { + mlir::Block *caseBlock = caseOp->getBlock(); + // Erase the block with no predecessors here to make the generated code + // simpler a little bit. + if (caseBlock->hasNoPredecessors()) + rewriter.eraseBlock(caseBlock); + else + rewriter.eraseOp(caseOp); + } + + for (size_t index = 0; index < rangeValues.size(); ++index) { + APInt lowerBound = rangeValues[index].first; + APInt upperBound = rangeValues[index].second; + + // The case range is unreachable, skip it. + if (lowerBound.sgt(upperBound)) + continue; + + // If range is small, add multiple switch instruction cases. + // This magical number is from the original CGStmt code. + constexpr int kSmallRangeThreshold = 64; + if ((upperBound - lowerBound) ---------------- andykaylor wrote:
Can you add a test case for a range that is within this threshold? https://github.com/llvm/llvm-project/pull/139154 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits