Author: Morris Hafner Date: 2025-04-30T18:15:51+02:00 New Revision: d7f096e3fe611ae2cc7403c3cf2f88255a47b61d
URL: https://github.com/llvm/llvm-project/commit/d7f096e3fe611ae2cc7403c3cf2f88255a47b61d DIFF: https://github.com/llvm/llvm-project/commit/d7f096e3fe611ae2cc7403c3cf2f88255a47b61d.diff LOG: [CIR] Upstream TernaryOp (#137184) This patch adds TernaryOp to CIR plus a pass that flattens the operator in FlattenCFG. This is the first PR out of (probably) 3 wrt. TernaryOp. I split the patches up to make reviewing them easier. As such, this PR is only about adding the CIR operation. The next PR will be about the CodeGen bits from the C++ conditional operator and the final one will add the cir-simplify transform for TernaryOp and SelectOp. --------- Co-authored-by: Morris Hafner <mhaf...@nvidia.com> Co-authored-by: Andy Kaylor <akay...@nvidia.com> Added: clang/test/CIR/IR/ternary.cir clang/test/CIR/Lowering/ternary.cir clang/test/CIR/Transforms/ternary.cir Modified: clang/include/clang/CIR/Dialect/IR/CIROps.td clang/lib/CIR/Dialect/IR/CIRDialect.cpp clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp Removed: ################################################################################ diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 0492f0bc931e7..8319364b9e5e3 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -610,9 +610,9 @@ def ConditionOp : CIR_Op<"condition", [ //===----------------------------------------------------------------------===// def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator, - ParentOneOf<["IfOp", "ScopeOp", "SwitchOp", - "WhileOp", "ForOp", "CaseOp", - "DoWhileOp"]>]> { + ParentOneOf<["CaseOp", "DoWhileOp", "ForOp", + "IfOp", "ScopeOp", "SwitchOp", + "TernaryOp", "WhileOp"]>]> { let summary = "Represents the default branching behaviour of a region"; let description = [{ The `cir.yield` operation terminates regions on different CIR operations, @@ -1462,6 +1462,63 @@ def SelectOp : CIR_Op<"select", [Pure, }]; }
+//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +def TernaryOp : CIR_Op<"ternary", + [DeclareOpInterfaceMethods<RegionBranchOpInterface>, + RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> { + let summary = "The `cond ? a : b` C/C++ ternary operation"; + let description = [{ + The `cir.ternary` operation represents C/C++ ternary, much like a `select` + operation. The first argument is a `cir.bool` condition to evaluate, followed + by two regions to execute (true or false). This is different from `cir.if` + since each region is one block sized and the `cir.yield` closing the block + scope should have one argument, or none if the operation has no result. + + `cir.ternary` also represents the GNU binary conditional operator ?: which + reuses the parent operation for both the condition and the true branch to + evaluate it only once. + + Example: + + ```mlir + // cond = a && b; + + %x = cir.ternary(%cond, true { + ... + cir.yield %a : !s32i + }, false { + ... + cir.yield %b : !s32i + }) : (!cir.bool) -> !s32i + ``` + }]; + let arguments = (ins CIR_BoolType:$cond); + let regions = (region AnyRegion:$trueRegion, + AnyRegion:$falseRegion); + let results = (outs Optional<CIR_AnyType>:$result); + + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<(ins "mlir::Value":$cond, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder, + "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder) + > + ]; + + // All constraints already verified elsewhere.
+ let hasVerifier = 0; + + let assemblyFormat = [{ + `(` $cond `,` + `true` $trueRegion `,` + `false` $falseRegion + `)` `:` functional-type(operands, results) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // GlobalOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 25993063ee7fd..21b77b5327ca7 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1187,6 +1187,51 @@ LogicalResult cir::BinOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// TernaryOp +//===----------------------------------------------------------------------===// + +/// Report the possible control-flow successors of `point`. From the parent +/// operation, control may enter either the true or the false region (the +/// condition is not inspected here). From either region, control returns to +/// the parent operation, and the value yielded by that region (if any) becomes +/// the operation's result. +void cir::TernaryOp::getSuccessorRegions( + mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) { + // The `true` and the `false` region branch back to the parent operation.
+ if (!point.isParent()) { + regions.push_back(RegionSuccessor(this->getODSResults(0))); + return; + } + + // When branching from the parent operation, both the true and false + // regions are considered possible successors. + regions.push_back(RegionSuccessor(&getTrueRegion())); + regions.push_back(RegionSuccessor(&getFalseRegion())); +} + +void cir::TernaryOp::build( + OpBuilder &builder, OperationState &result, Value cond, + function_ref<void(OpBuilder &, Location)> trueBuilder, + function_ref<void(OpBuilder &, Location)> falseBuilder) { + result.addOperands(cond); + OpBuilder::InsertionGuard guard(builder); + Region *trueRegion = result.addRegion(); + Block *block = builder.createBlock(trueRegion); + trueBuilder(builder, result.location); + Region *falseRegion = result.addRegion(); + builder.createBlock(falseRegion); + falseBuilder(builder, result.location); + + // The result type, if any, is taken from the true region's cir.yield. The + // assert is compiled out in release builds, so guard the use below as well. + auto yield = dyn_cast<YieldOp>(block->getTerminator()); + assert((yield && yield.getNumOperands() <= 1) && + "expected zero or one result type"); + if (yield && yield.getNumOperands() == 1) + result.addTypes(TypeRange{yield.getOperandTypes().front()}); +} + //===----------------------------------------------------------------------===// // ShiftOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp index 72ccfa8d4e14e..4a936d33b022a 100644 --- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp +++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp @@ -254,10 +254,69 @@ class CIRLoopOpInterfaceFlattening } }; +class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> { +public: + using OpRewritePattern<cir::TernaryOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cir::TernaryOp op, + mlir::PatternRewriter &rewriter) const override { + // Both regions must end in a cir.yield. Check this before touching the + // IR: a pattern must not modify the IR and then report failure. + auto getYield = [](Region &region) -> cir::YieldOp { + if (region.empty() || !region.back().mightHaveTerminator()) + return {}; + return dyn_cast<cir::YieldOp>(region.back().getTerminator()); + }; + cir::YieldOp trueYieldOp = getYield(op.getTrueRegion()); + cir::YieldOp falseYieldOp = getYield(op.getFalseRegion()); + if (!trueYieldOp || !falseYieldOp) + return rewriter.notifyMatchFailure( + op, "expected both regions to be terminated by cir.yield"); + + Location loc = op->getLoc(); + Block *condBlock = rewriter.getInsertionBlock(); + Block::iterator opPosition = rewriter.getInsertionPoint();
+ Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition); + llvm::SmallVector<mlir::Location, 2> locs; + // Ternary result is optional, make sure to populate the location only + // when relevant. + if (!op->getResultTypes().empty()) + locs.push_back(loc); + Block *continueBlock = + rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs); + rewriter.create<cir::BrOp>(loc, remainingOpsBlock); + + // Replace each region's cir.yield with a branch to the continuation block + // that forwards the yielded value, then inline the region before it. + Region &trueRegion = op.getTrueRegion(); + Block *trueBlock = &trueRegion.front(); + rewriter.setInsertionPointToEnd(&trueRegion.back()); + rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(trueRegion, continueBlock); + + Region &falseRegion = op.getFalseRegion(); + Block *falseBlock = &falseRegion.front(); + rewriter.setInsertionPointToEnd(&falseRegion.back()); + rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp, falseYieldOp.getArgs(), + continueBlock); + rewriter.inlineRegionBefore(falseRegion, continueBlock); + + rewriter.setInsertionPointToEnd(condBlock); + rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock); + + rewriter.replaceOp(op, continueBlock->getArguments()); + + // Ok, we're done!
+ return mlir::success(); + } +}; + void populateFlattenCFGPatterns(RewritePatternSet &patterns) { - patterns - .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>( - patterns.getContext()); + patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, + CIRScopeOpFlattening, CIRTernaryOpFlattening>( + patterns.getContext()); } void CIRFlattenCFGPass::runOnOperation() { @@ -269,9 +328,8 @@ void CIRFlattenCFGPass::runOnOperation() { getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) { assert(!cir::MissingFeatures::ifOp()); assert(!cir::MissingFeatures::switchOp()); - assert(!cir::MissingFeatures::ternaryOp()); assert(!cir::MissingFeatures::tryOp()); - if (isa<IfOp, ScopeOp, LoopOpInterface>(op)) + if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op)) ops.push_back(op); }); diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir new file mode 100644 index 0000000000000..3827dc77726df --- /dev/null +++ b/clang/test/CIR/IR/ternary.cir @@ -0,0 +1,30 @@ +// RUN: cir-opt %s | cir-opt | FileCheck %s +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// CHECK: module { + +// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i { +// CHECK: %0 = cir.ternary(%arg0, true { +// CHECK: %1 = cir.const #cir.int<0> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }, false { +// CHECK: %1 = cir.const #cir.int<1> : !u32i +// CHECK: cir.yield %1 : !u32i +// CHECK: }) : (!cir.bool) -> !u32i +// CHECK: cir.return %0 : !u32i +// CHECK: } + +// CHECK: } diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir new file mode 100644 index 0000000000000..247c6ae3a1e17 --- /dev/null +++ b/clang/test/CIR/Lowering/ternary.cir @@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s +// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s + +!u32i = !cir.int<u, 32> + +module { + cir.func @blue(%arg0: !cir.bool) -> !u32i { + %0 = cir.ternary(%arg0, true { + %a = cir.const #cir.int<0> : !u32i + cir.yield %a : !u32i + }, false { + %b = cir.const #cir.int<1> : !u32i + cir.yield %b : !u32i + }) : (!cir.bool) -> !u32i + cir.return %0 : !u32i + } +} + +// LLVM-LABEL: define i32 {{.*}}@blue( +// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]]) +// LLVM: br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]] +// LLVM: [[B1]]: +// LLVM: br label %[[M:[[:alnum:]]+]] +// LLVM: [[B2]]: +// LLVM: br label %[[M]] +// LLVM: [[M]]: +// LLVM: [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ] +// LLVM: br label %[[B3:[[:alnum:]]+]] +// LLVM: [[B3]]: +// LLVM: ret i32 [[R]] diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir new file mode 100644 index 0000000000000..67ef7f95a6b52 --- /dev/null +++ b/clang/test/CIR/Transforms/ternary.cir @@ -0,0 +1,68 @@ +// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s + +!s32i = !cir.int<s, 32> + +module { + cir.func @foo(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> + %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i + %3 = cir.const #cir.int<0> : !s32i + %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool + %5 = cir.ternary(%4, true { + %7 = cir.const #cir.int<3> : !s32i + cir.yield %7 : !s32i + }, false { + %7 = cir.const #cir.int<5> : !s32i + cir.yield %7 : !s32i + }) : (!cir.bool) -> !s32i + cir.store %5, %1 : !s32i, !cir.ptr<!s32i> + %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i + cir.return %6 : !s32i + } + +// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i { +// CHECK: %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64} +// CHECK: cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i> +// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i +// CHECK: %3 = cir.const #cir.int<0> : !s32i +// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool +// CHECK: cir.brcond %4 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: %5 = cir.const #cir.int<3> : !s32i +// CHECK: cir.br ^bb3(%5 : !s32i) +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: %6 = cir.const #cir.int<5> : !s32i +// CHECK: cir.br ^bb3(%6 : !s32i) +// CHECK: ^bb3(%7: !s32i): // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i> +// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i +// CHECK: cir.return %8 : !s32i +// CHECK: } + + cir.func @foo2(%arg0: !cir.bool) { + cir.ternary(%arg0, true { + cir.yield + }, false { + cir.yield + }) : (!cir.bool) -> () + cir.return + } + +// CHECK: cir.func @foo2(%arg0: !cir.bool) { +// CHECK: cir.brcond %arg0 ^bb1, ^bb2 +// CHECK: ^bb1: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb2: // pred: ^bb0 +// CHECK: cir.br ^bb3 +// CHECK: ^bb3: // 2 preds: ^bb1, ^bb2 +// CHECK: cir.br ^bb4 +// CHECK: ^bb4: // pred: ^bb3 +// CHECK: cir.return +// CHECK: } + +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits