llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: None (Andres-Salamanca) <details> <summary>Changes</summary> This PR adds support for the `FlattenCFG` transformation on `switch` statements. It also introduces the `SwitchFlatOp`, which is necessary for subsequent lowering to LLVM. --- Patch is 28.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/139154.diff 6 Files Affected: - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+46) - (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+97) - (modified) clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp (+14-1) - (modified) clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp (+231-4) - (added) clang/test/CIR/IR/switch-flat.cir (+68) - (added) clang/test/CIR/Transforms/switch.cir (+278) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 7ffa10464dcd3..914af6d1dc6bd 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -971,6 +971,52 @@ def SwitchOp : CIR_Op<"switch", }]; } +//===----------------------------------------------------------------------===// +// SwitchFlatOp +//===----------------------------------------------------------------------===// + +def SwitchFlatOp : CIR_Op<"switch.flat", [AttrSizedOperandSegments, + Terminator]> { + + let description = [{ + The `cir.switch.flat` operation is a region-less and simplified + version of the `cir.switch`. + Its representation is closer to the LLVM IR dialect + than to the C/C++ language feature. 
+ }]; + + let arguments = (ins + CIR_IntType:$condition, + Variadic<AnyType>:$defaultOperands, + VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands, + ArrayAttr:$case_values, + DenseI32ArrayAttr:$case_operand_segments + ); + + let successors = (successor + AnySuccessor:$defaultDestination, + VariadicSuccessor<AnySuccessor>:$caseDestinations + ); + + let assemblyFormat = [{ + $condition `:` type($condition) `,` + $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? + custom<SwitchFlatOpCases>(ref(type($condition)), $case_values, + $caseDestinations, $caseOperands, + type($caseOperands)) + attr-dict + }]; + + let builders = [ + OpBuilder<(ins "mlir::Value":$condition, + "mlir::Block *":$defaultDestination, + "mlir::ValueRange":$defaultOperands, + CArg<"llvm::ArrayRef<llvm::APInt>", "{}">:$caseValues, + CArg<"mlir::BlockRange", "{}">:$caseDestinations, + CArg<"llvm::ArrayRef<mlir::ValueRange>", "{}">:$caseOperands)> + ]; +} + //===----------------------------------------------------------------------===// // BrOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index b131edaf403ed..ca03013edb485 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -22,6 +22,7 @@ #include "clang/CIR/Dialect/IR/CIROpsDialect.cpp.inc" #include "clang/CIR/Dialect/IR/CIROpsEnums.cpp.inc" #include "clang/CIR/MissingFeatures.h" +#include <numeric> using namespace mlir; using namespace cir; @@ -962,6 +963,102 @@ bool cir::SwitchOp::isSimpleForm(llvm::SmallVectorImpl<CaseOp> &cases) { }); } +//===----------------------------------------------------------------------===// +// SwitchFlatOp +//===----------------------------------------------------------------------===// + +void cir::SwitchFlatOp::build(OpBuilder &builder, OperationState &result, + Value value, Block 
*defaultDestination, + ValueRange defaultOperands, + ArrayRef<APInt> caseValues, + BlockRange caseDestinations, + ArrayRef<ValueRange> caseOperands) { + + std::vector<mlir::Attribute> caseValuesAttrs; + for (auto &val : caseValues) { + caseValuesAttrs.push_back(cir::IntAttr::get(value.getType(), val)); + } + mlir::ArrayAttr attrs = ArrayAttr::get(builder.getContext(), caseValuesAttrs); + + build(builder, result, value, defaultOperands, caseOperands, attrs, + defaultDestination, caseDestinations); +} + +/// <cases> ::= `[` (case (`,` case )* )? `]` +/// <case> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)? +static ParseResult parseSwitchFlatOpCases( + OpAsmParser &parser, Type flagType, mlir::ArrayAttr &caseValues, + SmallVectorImpl<Block *> &caseDestinations, + SmallVectorImpl<llvm::SmallVector<OpAsmParser::UnresolvedOperand>> + &caseOperands, + SmallVectorImpl<llvm::SmallVector<Type>> &caseOperandTypes) { + if (failed(parser.parseLSquare())) + return failure(); + if (succeeded(parser.parseOptionalRSquare())) + return success(); + llvm::SmallVector<mlir::Attribute> values; + + auto parseCase = [&]() { + int64_t value = 0; + if (failed(parser.parseInteger(value))) + return failure(); + + values.push_back(cir::IntAttr::get(flagType, value)); + + Block *destination; + llvm::SmallVector<OpAsmParser::UnresolvedOperand> operands; + llvm::SmallVector<Type> operandTypes; + if (parser.parseColon() || parser.parseSuccessor(destination)) + return failure(); + if (!parser.parseOptionalLParen()) { + if (parser.parseOperandList(operands, OpAsmParser::Delimiter::None, + /*allowResultNumber=*/false) || + parser.parseColonTypeList(operandTypes) || parser.parseRParen()) + return failure(); + } + caseDestinations.push_back(destination); + caseOperands.emplace_back(operands); + caseOperandTypes.emplace_back(operandTypes); + return success(); + }; + if (failed(parser.parseCommaSeparatedList(parseCase))) + return failure(); + + caseValues = ArrayAttr::get(flagType.getContext(), 
values); + + return parser.parseRSquare(); +} + +static void printSwitchFlatOpCases(OpAsmPrinter &p, cir::SwitchFlatOp op, + Type flagType, mlir::ArrayAttr caseValues, + SuccessorRange caseDestinations, + OperandRangeRange caseOperands, + const TypeRangeRange &caseOperandTypes) { + p << '['; + p.printNewline(); + if (!caseValues) { + p << ']'; + return; + } + + size_t index = 0; + llvm::interleave( + llvm::zip(caseValues, caseDestinations), + [&](auto i) { + p << " "; + mlir::Attribute a = std::get<0>(i); + p << mlir::cast<cir::IntAttr>(a).getValue(); + p << ": "; + p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]); + }, + [&] { + p << ','; + p.printNewline(); + }); + p.printNewline(); + p << ']'; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp index 3b4c7bc613133..edbb848322d41 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp @@ -84,6 +84,19 @@ struct RemoveEmptyScope : public OpRewritePattern<ScopeOp> { } }; +struct RemoveEmptySwitch : public OpRewritePattern<SwitchOp> { + using OpRewritePattern<SwitchOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(SwitchOp op, + PatternRewriter &rewriter) const final { + if (!(op.getBody().empty() || isa<YieldOp>(op.getBody().front().front()))) + return failure(); + + rewriter.eraseOp(op); + return success(); + } +}; + //===----------------------------------------------------------------------===// // CIRCanonicalizePass //===----------------------------------------------------------------------===// @@ -127,7 +140,7 @@ void CIRCanonicalizePass::runOnOperation() { assert(!cir::MissingFeatures::callOp()); // CastOp and UnaryOp are here to perform a manual `fold` in // 
applyOpPatternsGreedily. - if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SelectOp, UnaryOp>(op)) + if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp>(op)) ops.push_back(op); }); diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 4a936d33b022a..70f383b556567 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -171,6 +171,232 @@ class CIRScopeOpFlattening : public mlir::OpRewritePattern<cir::ScopeOp> { } }; +class CIRSwitchOpFlattening : public mlir::OpRewritePattern<cir::SwitchOp> { +public: + using OpRewritePattern<cir::SwitchOp>::OpRewritePattern; + + inline void rewriteYieldOp(mlir::PatternRewriter &rewriter, + cir::YieldOp yieldOp, + mlir::Block *destination) const { + rewriter.setInsertionPoint(yieldOp); + rewriter.replaceOpWithNewOp<cir::BrOp>(yieldOp, yieldOp.getOperands(), + destination); + } + + // Return the new defaultDestination block. + Block *condBrToRangeDestination(cir::SwitchOp op, + mlir::PatternRewriter &rewriter, + mlir::Block *rangeDestination, + mlir::Block *defaultDestination, + const APInt &lowerBound, + const APInt &upperBound) const { + assert(lowerBound.sle(upperBound) && "Invalid range"); + mlir::Block *resBlock = rewriter.createBlock(defaultDestination); + cir::IntType sIntType = cir::IntType::get(op.getContext(), 32, true); + cir::IntType uIntType = cir::IntType::get(op.getContext(), 32, false); + + cir::ConstantOp rangeLength = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, upperBound - lowerBound)); + + cir::ConstantOp lowerBoundValue = rewriter.create<cir::ConstantOp>( + op.getLoc(), cir::IntAttr::get(sIntType, lowerBound)); + cir::BinOp diffValue = + rewriter.create<cir::BinOp>(op.getLoc(), sIntType, cir::BinOpKind::Sub, + op.getCondition(), lowerBoundValue); + + // Use unsigned comparison to check if the condition is in the range. 
+ cir::CastOp uDiffValue = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, diffValue); + cir::CastOp uRangeLength = rewriter.create<cir::CastOp>( + op.getLoc(), uIntType, CastKind::integral, rangeLength); + + cir::CmpOp cmpResult = rewriter.create<cir::CmpOp>( + op.getLoc(), cir::BoolType::get(op.getContext()), cir::CmpOpKind::le, + uDiffValue, uRangeLength); + rewriter.create<cir::BrCondOp>(op.getLoc(), cmpResult, rangeDestination, + defaultDestination); + return resBlock; + } + + mlir::LogicalResult + matchAndRewrite(cir::SwitchOp op, + mlir::PatternRewriter &rewriter) const override { + llvm::SmallVector<CaseOp> cases; + op.collectCases(cases); + + // Empty switch statement: just erase it. + if (cases.empty()) { + rewriter.eraseOp(op); + return mlir::success(); + } + + // Create exit block from the next node of cir.switch op. + mlir::Block *exitBlock = rewriter.splitBlock( + rewriter.getBlock(), op->getNextNode()->getIterator()); + + // We lower cir.switch op in the following process: + // 1. Inline the region from the switch op after switch op. + // 2. Traverse each cir.case op: + // a. Record the entry block, block arguments and condition for every + // case. b. Inline the case region after the case op. + // 3. Replace the empty cir.switch.op with the new cir.switchflat op by the + // recorded block and conditions. + + // inline everything from switch body between the switch op and the exit + // block. + { + cir::YieldOp switchYield = nullptr; + // Clear switch operation. 
+ for (auto &block : llvm::make_early_inc_range(op.getBody().getBlocks())) + if (auto yieldOp = dyn_cast<cir::YieldOp>(block.getTerminator())) + switchYield = yieldOp; + + assert(!op.getBody().empty()); + mlir::Block *originalBlock = op->getBlock(); + mlir::Block *swopBlock = + rewriter.splitBlock(originalBlock, op->getIterator()); + rewriter.inlineRegionBefore(op.getBody(), exitBlock); + + if (switchYield) + rewriteYieldOp(rewriter, switchYield, exitBlock); + + rewriter.setInsertionPointToEnd(originalBlock); + rewriter.create<cir::BrOp>(op.getLoc(), swopBlock); + } + + // Allocate required data structures (disconsider default case in + // vectors). + llvm::SmallVector<mlir::APInt, 8> caseValues; + llvm::SmallVector<mlir::Block *, 8> caseDestinations; + llvm::SmallVector<mlir::ValueRange, 8> caseOperands; + + llvm::SmallVector<std::pair<APInt, APInt>> rangeValues; + llvm::SmallVector<mlir::Block *> rangeDestinations; + llvm::SmallVector<mlir::ValueRange> rangeOperands; + + // Initialize default case as optional. + mlir::Block *defaultDestination = exitBlock; + mlir::ValueRange defaultOperands = exitBlock->getArguments(); + + // Digest the case statements values and bodies. + for (auto caseOp : cases) { + mlir::Region ®ion = caseOp.getCaseRegion(); + + // Found default case: save destination and operands. + switch (caseOp.getKind()) { + case cir::CaseOpKind::Default: + defaultDestination = ®ion.front(); + defaultOperands = defaultDestination->getArguments(); + break; + case cir::CaseOpKind::Range: + assert(caseOp.getValue().size() == 2 && + "Case range should have 2 case value"); + rangeValues.push_back( + {cast<cir::IntAttr>(caseOp.getValue()[0]).getValue(), + cast<cir::IntAttr>(caseOp.getValue()[1]).getValue()}); + rangeDestinations.push_back(®ion.front()); + rangeOperands.push_back(rangeDestinations.back()->getArguments()); + break; + case cir::CaseOpKind::Anyof: + case cir::CaseOpKind::Equal: + // AnyOf cases kind can have multiple values, hence the loop below. 
+ for (auto &value : caseOp.getValue()) { + caseValues.push_back(cast<cir::IntAttr>(value).getValue()); + caseDestinations.push_back(®ion.front()); + caseOperands.push_back(caseDestinations.back()->getArguments()); + } + break; + } + + // Handle break statements. + walkRegionSkipping<cir::LoopOpInterface, cir::SwitchOp>( + region, [&](mlir::Operation *op) { + if (!isa<cir::BreakOp>(op)) + return mlir::WalkResult::advance(); + + lowerTerminator(op, exitBlock, rewriter); + return mlir::WalkResult::skip(); + }); + + // Track fallthrough in cases. + for (auto &blk : region.getBlocks()) { + if (blk.getNumSuccessors()) + continue; + + if (auto yieldOp = dyn_cast<cir::YieldOp>(blk.getTerminator())) { + mlir::Operation *nextOp = caseOp->getNextNode(); + assert(nextOp && "caseOp is not expected to be the last op"); + mlir::Block *oldBlock = nextOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, nextOp->getIterator()); + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(nextOp->getLoc(), mlir::ValueRange(), + newBlock); + rewriteYieldOp(rewriter, yieldOp, newBlock); + } + } + + mlir::Block *oldBlock = caseOp->getBlock(); + mlir::Block *newBlock = + rewriter.splitBlock(oldBlock, caseOp->getIterator()); + + mlir::Block &entryBlock = caseOp.getCaseRegion().front(); + rewriter.inlineRegionBefore(caseOp.getCaseRegion(), newBlock); + + // Create a branch to the entry of the inlined region. + rewriter.setInsertionPointToEnd(oldBlock); + rewriter.create<cir::BrOp>(caseOp.getLoc(), &entryBlock); + } + + // Remove all cases since we've inlined the regions. + for (auto caseOp : cases) { + mlir::Block *caseBlock = caseOp->getBlock(); + // Erase the block with no predecessors here to make the generated code + // simpler a little bit. 
+ if (caseBlock->hasNoPredecessors()) + rewriter.eraseBlock(caseBlock); + else + rewriter.eraseOp(caseOp); + } + + for (size_t index = 0; index < rangeValues.size(); ++index) { + APInt lowerBound = rangeValues[index].first; + APInt upperBound = rangeValues[index].second; + + // The case range is unreachable, skip it. + if (lowerBound.sgt(upperBound)) + continue; + + // If range is small, add multiple switch instruction cases. + // This magical number is from the original CGStmt code. + constexpr int kSmallRangeThreshold = 64; + if ((upperBound - lowerBound) + .ult(llvm::APInt(32, kSmallRangeThreshold))) { + for (APInt iValue = lowerBound; iValue.sle(upperBound); + (void)iValue++) { + caseValues.push_back(iValue); + caseOperands.push_back(rangeOperands[index]); + caseDestinations.push_back(rangeDestinations[index]); + } + continue; + } + + defaultDestination = + condBrToRangeDestination(op, rewriter, rangeDestinations[index], + defaultDestination, lowerBound, upperBound); + defaultOperands = rangeOperands[index]; + } + + // Set switch op to branch to the newly created blocks. 
+ rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp<cir::SwitchFlatOp>( + op, op.getCondition(), defaultDestination, defaultOperands, caseValues, + caseDestinations, caseOperands); + + return mlir::success(); + } +}; + class CIRLoopOpInterfaceFlattening : public mlir::OpInterfaceRewritePattern<cir::LoopOpInterface> { public: @@ -306,9 +532,10 @@ class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { }; void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, - CIRScopeOpFlattening, CIRTernaryOpFlattening>( - patterns.getContext()); + patterns + .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening, + CIRSwitchOpFlattening, CIRTernaryOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -321,7 +548,7 @@ void CIRFlattenCFGPass::runOnOperation() { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op)) + if (isa<IfOp, ScopeOp, SwitchOp, LoopOpInterface, TernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/IR/switch-flat.cir b/clang/test/CIR/IR/switch-flat.cir new file mode 100644 index 0000000000000..b072c224b4a2c --- /dev/null +++ b/clang/test/CIR/IR/switch-flat.cir @@ -0,0 +1,68 @@ +// RUN: cir-opt %s | FileCheck %s +!s32i = !cir.int<s, 32> + +cir.func @FlatSwitchWithoutDefault(%arg0: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb2 [ + 1: ^bb1 + ] + ^bb1: + cir.br ^bb2 + ^bb2: + cir.return +} + +// CHECK: cir.switch.flat %arg0 : !s32i, ^bb2 [ +// CHECK-NEXT: 1: ^bb1 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb2 +// CHECK-NEXT: ^bb2: +//CHECK-NEXT: cir.return + +cir.func @FlatSwitchWithDefault(%arg0: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb2 [ + 1: ^bb1 + ] + ^bb1: + cir.br ^bb3 + ^bb2: + cir.br ^bb3 + ^bb3: + cir.return +} + +// CHECK: 
cir.switch.flat %arg0 : !s32i, ^bb2 [ +// CHECK-NEXT: 1: ^bb1 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb2: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb3: +// CHECK-NEXT: cir.return + +cir.func @switchWithOperands(%arg0: !s32i, %arg1: !s32i, %arg2: !s32i) { + cir.switch.flat %arg0 : !s32i, ^bb3 [ + 0: ^bb1(%arg1, %arg2 : !s32i, !s32i), + 1: ^bb2(%arg2, %arg1 : !s32i, !s32i) + ] +^bb1: + cir.br ^bb3 + +^bb2: + cir.br ^bb3 + +^bb3: + cir.return +} + +// CHECK: cir.switch.flat %arg0 : !s32i, ^bb3 [ +// CHECK-NEXT: 0: ^bb1(%arg1, %arg2 : !s32i, !s32i), +// CHECK-NEXT: 1: ^bb2(%arg2, %arg1 : !s32i, !s32i) +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb1: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb2: +// CHECK-NEXT: cir.br ^bb3 +// CHECK-NEXT: ^bb3: +// CHECK-NEXT: cir.return diff --git a/clang/test/CIR/Transforms/switch.cir b/clang/test/CIR/Transforms/switch.cir new file mode 100644 index 0000000000000..a05cf37e39728 --- /dev/null +++ b/clang/te... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/139154 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits