Author: Morris Hafner
Date: 2025-04-30T18:15:51+02:00
New Revision: d7f096e3fe611ae2cc7403c3cf2f88255a47b61d

URL: 
https://github.com/llvm/llvm-project/commit/d7f096e3fe611ae2cc7403c3cf2f88255a47b61d
DIFF: 
https://github.com/llvm/llvm-project/commit/d7f096e3fe611ae2cc7403c3cf2f88255a47b61d.diff

LOG: [CIR] Upstream TernaryOp (#137184)

This patch adds TernaryOp to CIR plus a pass that flattens the operator
in FlattenCFG.

This is the first PR out of (probably) 3 wrt. TernaryOp. I split the
patches up to make reviewing them easier. As such, this PR is only about
adding the CIR operation. The next PR will be about the CodeGen bits
from the C++ conditional operator and the final one will add the
cir-simplify transform for TernaryOp and SelectOp.

---------

Co-authored-by: Morris Hafner <mhaf...@nvidia.com>
Co-authored-by: Andy Kaylor <akay...@nvidia.com>

Added: 
    clang/test/CIR/IR/ternary.cir
    clang/test/CIR/Lowering/ternary.cir
    clang/test/CIR/Transforms/ternary.cir

Modified: 
    clang/include/clang/CIR/Dialect/IR/CIROps.td
    clang/lib/CIR/Dialect/IR/CIRDialect.cpp
    clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp

Removed: 
    


################################################################################
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index 0492f0bc931e7..8319364b9e5e3 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -610,9 +610,9 @@ def ConditionOp : CIR_Op<"condition", [
 
//===----------------------------------------------------------------------===//
 
 def YieldOp : CIR_Op<"yield", [ReturnLike, Terminator,
-                               ParentOneOf<["IfOp", "ScopeOp", "SwitchOp",
-                                            "WhileOp", "ForOp", "CaseOp",
-                                            "DoWhileOp"]>]> {
+                               ParentOneOf<["CaseOp", "DoWhileOp", "ForOp",
+                                            "IfOp", "ScopeOp", "SwitchOp",
+                                            "TernaryOp", "WhileOp"]>]> {
   let summary = "Represents the default branching behaviour of a region";
   let description = [{
     The `cir.yield` operation terminates regions on different CIR operations,
@@ -1462,6 +1462,63 @@ def SelectOp : CIR_Op<"select", [Pure,
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+def TernaryOp : CIR_Op<"ternary",
+      [DeclareOpInterfaceMethods<RegionBranchOpInterface>,
+       RecursivelySpeculatable, AutomaticAllocationScope, NoRegionArguments]> {
+  let summary = "The `cond ? a : b` C/C++ ternary operation";
+  let description = [{
+    The `cir.ternary` operation represents C/C++ ternary, much like a `select`
+    operation. The first argument is a `cir.bool` condition to evaluate, followed
+    by two regions to execute (true or false). This is different from `cir.if`
+    since each region is one block sized and the `cir.yield` closing the block
+    scope should have at most one argument (none when the operation produces
+    no result).
+
+    `cir.ternary` also represents the GNU binary conditional operator ?: which
+    reuses the parent operation for both the condition and the true branch to
+    evaluate it only once.
+
+    Example:
+
+    ```mlir
+    // cond = a && b;
+
+    %x = cir.ternary(%cond, true {
+      ...
+      cir.yield %a : !s32i
+    }, false {
+      ...
+      cir.yield %b : !s32i
+    }) : (!cir.bool) -> !s32i
+    ```
+  }];
+  let arguments = (ins CIR_BoolType:$cond);
+  let regions = (region AnyRegion:$trueRegion,
+                        AnyRegion:$falseRegion);
+  // The result is optional: a ternary whose regions yield nothing has none.
+  let results = (outs Optional<CIR_AnyType>:$result);
+
+  // Only the callback-based builder below is exposed. Each callback receives
+  // the builder and the op location and populates one region.
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<(ins "mlir::Value":$cond,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$trueBuilder,
+      "llvm::function_ref<void(mlir::OpBuilder &, mlir::Location)>":$falseBuilder)
+      >
+  ];
+
+  // All constraints already verified elsewhere.
+  let hasVerifier = 0;
+
+  let assemblyFormat = [{
+    `(` $cond `,`
+      `true` $trueRegion `,`
+      `false` $falseRegion
+    `)` `:` functional-type(operands, results) attr-dict
+  }];
+}
+
 
//===----------------------------------------------------------------------===//
 // GlobalOp
 
//===----------------------------------------------------------------------===//

diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
index 25993063ee7fd..21b77b5327ca7 100644
--- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
+++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp
@@ -1187,6 +1187,49 @@ LogicalResult cir::BinOp::verify() {
   return mlir::success();
 }
 
+//===----------------------------------------------------------------------===//
+// TernaryOp
+//===----------------------------------------------------------------------===//
+
+/// Report the regions that control may transfer to from `point`.
+///
+/// When `point` is the parent operation (i.e. on entry to `cir.ternary`),
+/// either the true or the false region may be selected, so both are appended
+/// to `regions`. When `point` is one of the two regions, control returns to
+/// the parent operation, recorded as a successor built from the op's results.
+void cir::TernaryOp::getSuccessorRegions(
+    mlir::RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
+  if (point.isParent()) {
+    // Entering the op: both the true and the false region are possible
+    // successors.
+    regions.push_back(RegionSuccessor(&getTrueRegion()));
+    regions.push_back(RegionSuccessor(&getFalseRegion()));
+    return;
+  }
+
+  // Leaving the `true` or the `false` region: branch back to the parent
+  // operation.
+  regions.push_back(RegionSuccessor(this->getODSResults(0)));
+}
+
+/// Build a `cir.ternary` on condition `cond`, populating its two regions
+/// through the `trueBuilder` and `falseBuilder` callbacks.
+///
+/// Each callback is invoked with the builder positioned inside a freshly
+/// created block of its region, together with the op location. The block
+/// created for the true region must end in a `cir.yield`; the op's result
+/// type, if any, is taken from that yield's single operand. The false
+/// region's terminator is not inspected here.
+void cir::TernaryOp::build(
+    OpBuilder &builder, OperationState &result, Value cond,
+    function_ref<void(OpBuilder &, Location)> trueBuilder,
+    function_ref<void(OpBuilder &, Location)> falseBuilder) {
+  result.addOperands(cond);
+
+  // Restore the caller's insertion point once both regions are populated.
+  OpBuilder::InsertionGuard guard(builder);
+
+  Region *trueRegion = result.addRegion();
+  Block *trueBlock = builder.createBlock(trueRegion);
+  trueBuilder(builder, result.location);
+
+  Region *falseRegion = result.addRegion();
+  builder.createBlock(falseRegion);
+  falseBuilder(builder, result.location);
+
+  // The terminator is required to be a yield, so use `cast` rather than
+  // `dyn_cast`: it asserts on a mismatch in assertion-enabled builds and never
+  // produces a null op that would be dereferenced below.
+  auto yield = cast<YieldOp>(trueBlock->getTerminator());
+  assert(yield.getNumOperands() <= 1 && "expected zero or one result type");
+  if (yield.getNumOperands() == 1)
+    result.addTypes(TypeRange{yield.getOperandTypes().front()});
+}
+
 
//===----------------------------------------------------------------------===//
 // ShiftOp
 
//===----------------------------------------------------------------------===//

diff --git a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
index 72ccfa8d4e14e..4a936d33b022a 100644
--- a/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
+++ b/clang/lib/CIR/Dialect/Transforms/FlattenCFG.cpp
@@ -254,10 +254,61 @@ class CIRLoopOpInterfaceFlattening
   }
 };
 
+/// Flattens `cir.ternary` into explicit control flow.
+///
+/// The condition becomes a `cir.brcond` targeting the entry blocks of the
+/// inlined true and false regions. The `cir.yield` terminating each region is
+/// replaced by a `cir.br` to a common continuation block, forwarding the
+/// yielded value (if any) as a block argument. That block argument replaces
+/// the ternary's result.
+class CIRTernaryOpFlattening : public mlir::OpRewritePattern<cir::TernaryOp> {
+public:
+  using OpRewritePattern<cir::TernaryOp>::OpRewritePattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::TernaryOp op,
+                  mlir::PatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+
+    // Split the current block at the rewriter's insertion point; the
+    // operations from that point on end up in `remainingOpsBlock`.
+    // NOTE(review): this relies on the insertion point being at `op` when
+    // the pattern is invoked -- confirm against the pattern driver.
+    Block *condBlock = rewriter.getInsertionBlock();
+    Block::iterator opPosition = rewriter.getInsertionPoint();
+    Block *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
+
+    // Ternary result is optional, make sure to populate the location only
+    // when relevant.
+    llvm::SmallVector<mlir::Location, 2> locs;
+    if (op->getResultTypes().size())
+      locs.push_back(loc);
+
+    // The continuation block takes the ternary's value (if any) as a block
+    // argument and branches on to the remaining operations.
+    Block *continueBlock =
+        rewriter.createBlock(remainingOpsBlock, op->getResultTypes(), locs);
+    rewriter.create<cir::BrOp>(loc, remainingOpsBlock);
+
+    // True region: turn its yield into a branch to the continuation block,
+    // then inline the region in front of that block.
+    Region &trueRegion = op.getTrueRegion();
+    Block *trueBlock = &trueRegion.front();
+    rewriter.setInsertionPointToEnd(&trueRegion.back());
+    // The terminator must be a yield; `cast` asserts on a mismatch (in
+    // assertion-enabled builds) instead of yielding a null op that is then
+    // dereferenced.
+    auto trueYieldOp = cast<cir::YieldOp>(trueRegion.back().getTerminator());
+    rewriter.replaceOpWithNewOp<cir::BrOp>(trueYieldOp, trueYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(trueRegion, continueBlock);
+
+    // False region: same treatment.
+    Region &falseRegion = op.getFalseRegion();
+    Block *falseBlock = &falseRegion.front();
+    rewriter.setInsertionPointToEnd(&falseRegion.back());
+    auto falseYieldOp = cast<cir::YieldOp>(falseRegion.back().getTerminator());
+    rewriter.replaceOpWithNewOp<cir::BrOp>(falseYieldOp,
+                                           falseYieldOp.getArgs(),
+                                           continueBlock);
+    rewriter.inlineRegionBefore(falseRegion, continueBlock);
+
+    // Terminate the original block with the conditional branch that selects
+    // between the two inlined regions.
+    rewriter.setInsertionPointToEnd(condBlock);
+    rewriter.create<cir::BrCondOp>(loc, op.getCond(), trueBlock, falseBlock);
+
+    // Uses of the ternary's result now read the continuation block argument.
+    rewriter.replaceOp(op, continueBlock->getArguments());
+
+    // Ok, we're done!
+    return mlir::success();
+  }
+};
+
 void populateFlattenCFGPatterns(RewritePatternSet &patterns) {
-  patterns
-      .add<CIRIfFlattening, CIRLoopOpInterfaceFlattening, CIRScopeOpFlattening>(
-          patterns.getContext());
+  patterns.add<CIRIfFlattening, CIRLoopOpInterfaceFlattening,
+               CIRScopeOpFlattening, CIRTernaryOpFlattening>(
+      patterns.getContext());
 }
 
 void CIRFlattenCFGPass::runOnOperation() {
@@ -269,9 +320,8 @@ void CIRFlattenCFGPass::runOnOperation() {
   getOperation()->walk<mlir::WalkOrder::PostOrder>([&](Operation *op) {
     assert(!cir::MissingFeatures::ifOp());
     assert(!cir::MissingFeatures::switchOp());
-    assert(!cir::MissingFeatures::ternaryOp());
     assert(!cir::MissingFeatures::tryOp());
-    if (isa<IfOp, ScopeOp, LoopOpInterface>(op))
+    if (isa<IfOp, ScopeOp, LoopOpInterface, TernaryOp>(op))
       ops.push_back(op);
   });
 

diff --git a/clang/test/CIR/IR/ternary.cir b/clang/test/CIR/IR/ternary.cir
new file mode 100644
index 0000000000000..3827dc77726df
--- /dev/null
+++ b/clang/test/CIR/IR/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-opt %s | cir-opt | FileCheck %s
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// CHECK: module  {
+
+// CHECK: cir.func @blue(%arg0: !cir.bool) -> !u32i {
+// CHECK:   %0 = cir.ternary(%arg0, true {
+// CHECK:     %1 = cir.const #cir.int<0> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }, false {
+// CHECK:     %1 = cir.const #cir.int<1> : !u32i
+// CHECK:     cir.yield %1 : !u32i
+// CHECK:   }) : (!cir.bool) -> !u32i
+// CHECK:   cir.return %0 : !u32i
+// CHECK: }
+
+// CHECK: }

diff --git a/clang/test/CIR/Lowering/ternary.cir b/clang/test/CIR/Lowering/ternary.cir
new file mode 100644
index 0000000000000..247c6ae3a1e17
--- /dev/null
+++ b/clang/test/CIR/Lowering/ternary.cir
@@ -0,0 +1,30 @@
+// RUN: cir-translate -cir-to-llvmir --disable-cc-lowering -o %t.ll %s
+// RUN: FileCheck --input-file=%t.ll -check-prefix=LLVM %s
+
+!u32i = !cir.int<u, 32>
+
+module  {
+  cir.func @blue(%arg0: !cir.bool) -> !u32i {
+    %0 = cir.ternary(%arg0, true {
+      %a = cir.const #cir.int<0> : !u32i
+      cir.yield %a : !u32i
+    }, false {
+      %b = cir.const #cir.int<1> : !u32i
+      cir.yield %b : !u32i
+    }) : (!cir.bool) -> !u32i
+    cir.return %0 : !u32i
+  }
+}
+
+// LLVM-LABEL: define i32 {{.*}}@blue(
+// LLVM-SAME: i1 [[PRED:%[[:alnum:]]+]])
+// LLVM:   br i1 [[PRED]], label %[[B1:[[:alnum:]]+]], label %[[B2:[[:alnum:]]+]]
+// LLVM: [[B1]]:
+// LLVM:   br label %[[M:[[:alnum:]]+]]
+// LLVM: [[B2]]:
+// LLVM:   br label %[[M]]
+// LLVM: [[M]]:
+// LLVM:   [[R:%[[:alnum:]]+]] = phi i32 [ 1, %[[B2]] ], [ 0, %[[B1]] ]
+// LLVM:   br label %[[B3:[[:alnum:]]+]]
+// LLVM: [[B3]]:
+// LLVM:   ret i32 [[R]]

diff --git a/clang/test/CIR/Transforms/ternary.cir b/clang/test/CIR/Transforms/ternary.cir
new file mode 100644
index 0000000000000..67ef7f95a6b52
--- /dev/null
+++ b/clang/test/CIR/Transforms/ternary.cir
@@ -0,0 +1,68 @@
+// RUN: cir-opt %s -cir-flatten-cfg -o - | FileCheck %s
+
+!s32i = !cir.int<s, 32>
+
+module {
+  cir.func @foo(%arg0: !s32i) -> !s32i {
+    %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+    %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+    cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+    %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+    %3 = cir.const #cir.int<0> : !s32i
+    %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+    %5 = cir.ternary(%4, true {
+      %7 = cir.const #cir.int<3> : !s32i
+      cir.yield %7 : !s32i
+    }, false {
+      %7 = cir.const #cir.int<5> : !s32i
+      cir.yield %7 : !s32i
+    }) : (!cir.bool) -> !s32i
+    cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
+    %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+    cir.return %6 : !s32i
+  }
+
+// CHECK: cir.func @foo(%arg0: !s32i) -> !s32i {
+// CHECK:   %0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["y", init] {alignment = 4 : i64}
+// CHECK:   %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["__retval"] {alignment = 4 : i64}
+// CHECK:   cir.store %arg0, %0 : !s32i, !cir.ptr<!s32i>
+// CHECK:   %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
+// CHECK:   %3 = cir.const #cir.int<0> : !s32i
+// CHECK:   %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
+// CHECK:    cir.brcond %4 ^bb1, ^bb2
+// CHECK:  ^bb1:  // pred: ^bb0
+// CHECK:    %5 = cir.const #cir.int<3> : !s32i
+// CHECK:    cir.br ^bb3(%5 : !s32i)
+// CHECK:  ^bb2:  // pred: ^bb0
+// CHECK:    %6 = cir.const #cir.int<5> : !s32i
+// CHECK:    cir.br ^bb3(%6 : !s32i)
+// CHECK:  ^bb3(%7: !s32i):  // 2 preds: ^bb1, ^bb2
+// CHECK:    cir.br ^bb4
+// CHECK:  ^bb4:  // pred: ^bb3
+// CHECK:    cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
+// CHECK:    %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
+// CHECK:    cir.return %8 : !s32i
+// CHECK:  }
+
+  cir.func @foo2(%arg0: !cir.bool) {
+    cir.ternary(%arg0, true {
+      cir.yield
+    }, false {
+      cir.yield
+    }) : (!cir.bool) -> ()
+    cir.return
+  }
+
+// CHECK: cir.func @foo2(%arg0: !cir.bool) {
+// CHECK:   cir.brcond %arg0 ^bb1, ^bb2
+// CHECK: ^bb1:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb2:  // pred: ^bb0
+// CHECK:   cir.br ^bb3
+// CHECK: ^bb3:  // 2 preds: ^bb1, ^bb2
+// CHECK:   cir.br ^bb4
+// CHECK: ^bb4:  // pred: ^bb3
+// CHECK:   cir.return
+// CHECK: }
+
+}


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to