llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-clang

Author: Yeongu Choe (YeonguChoe)

<details>
<summary>Changes</summary>

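This patch adds ClangIR support for the C++20 three-way comparison operator
(`<=>`), which determines the relative ordering of two operands in a single
operation. It introduces the `cir.cmp3way` operation as specified in the CIR
dialect documentation, emits it from CodeGen, and lowers it to the LLVM
dialect. The changes were formatted with `clang-format`.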

Reference
- https://llvm.github.io/clangir/Dialect/ops.html#circmp3way-circmpthreewayop
- https://en.wikipedia.org/wiki/Three-way_comparison
- https://clang.llvm.org/docs/ClangFormatStyleOptions.html

---
Full diff: https://github.com/llvm/llvm-project/pull/186294.diff


6 Files Affected:

- (modified) clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h (+7) 
- (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+34) 
- (modified) clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp (+42-4) 
- (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+82) 
- (added) clang/test/CIR/CodeGenCXX/three-way-comparison.cpp (+22) 
- (added) clang/test/CIR/Lowering/cmp3way.cir (+32) 


``````````diff
diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index e60288c40132f..2a81da408121c 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -669,6 +669,16 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
     return cir::CmpOp::create(*this, loc, kind, lhs, rhs);
   }
 
+  /// Create a `cir.cmp3way` operation that compares \p lhs with \p rhs.
+  /// \p info is a `#cir.cmp3way_strong_info` or `#cir.cmp3way_partial_info`
+  /// attribute holding the value produced for each ordering outcome.
+  cir::CmpThreeWayOp createThreeWayComparison(mlir::Location loc,
+                                              mlir::Type resultTy,
+                                              mlir::Value lhs, mlir::Value rhs,
+                                              mlir::Attribute info) {
+    return cir::CmpThreeWayOp::create(*this, loc, resultTy, lhs, rhs, info);
+  }
+
   cir::VecCmpOp createVecCompare(mlir::Location loc, cir::CmpOpKind kind,
                                  mlir::Value lhs, mlir::Value rhs) {
     VectorType vecCast = mlir::cast<VectorType>(lhs.getType());
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index a9b98b1f43b3f..24758a0c946f6 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -2071,6 +2071,58 @@ def CIR_CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> {
   let hasCXXABILowering = true;
 }
 
+//===----------------------------------------------------------------------===//
+// CmpThreeWayOp
+//===----------------------------------------------------------------------===//
+
+def CIR_CmpThreeWayStrongInfoAttr
+    : CIR_Attr<"CmpThreeWayStrongInfo", "cmp3way_strong_info"> {
+  let summary = "Result values of a strongly ordered three-way comparison";
+  let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt);
+  let assemblyFormat =
+      [{`<` `strong` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `>`}];
+}
+
+def CIR_CmpThreeWayPartialInfoAttr
+    : CIR_Attr<"CmpThreeWayPartialInfo", "cmp3way_partial_info"> {
+  let summary = "Result values of a partially ordered three-way comparison";
+  let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt,
+                        "int64_t":$unordered);
+  let assemblyFormat = [{
+    `<` `partial` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `,`
+    `unordered` `=` $unordered `>`
+  }];
+}
+
+def CIR_CmpThreeWayInfoAttr : AnyAttrOf<[CIR_CmpThreeWayStrongInfoAttr,
+                                         CIR_CmpThreeWayPartialInfoAttr]>;
+
+def CIR_CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> {
+  let summary = "Three-way comparison of two values";
+  let description = [{
+    `cir.cmp3way` compares two operands of the same type and produces a
+    signed integer that encodes their ordering. The `info` attribute gives
+    the value produced for each outcome: `lt`, `eq` and `gt`, plus
+    `unordered` for partially ordered comparisons (for example, when a
+    floating-point operand is NaN).
+
+    Example:
+
+    ```mlir
+    %r = cir.cmp3way(%a : !s32i, %b,
+                     #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>)
+                     : !s8i
+    ```
+  }];
+  let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs,
+                       CIR_CmpThreeWayInfoAttr:$info);
+  let results = (outs CIR_AnySIntType:$result);
+  let assemblyFormat = [{
+    `(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)` `:` type($result)
+    attr-dict
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BinOpOverflowOp
 //===----------------------------------------------------------------------===//
diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
index 9f390fec97613..19e844664f53c 100644
--- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp
@@ -301,9 +301,7 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
     cgf.cgm.errorNYI(e->getSourceRange(),
                      "AggExprEmitter: VisitSubstNonTypeTemplateParmExpr");
   }
-  void VisitConstantExpr(ConstantExpr *e) {
-    cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitConstantExpr");
-  }
+  void VisitConstantExpr(ConstantExpr *e) { Visit(e->getSubExpr()); }
   void VisitMemberExpr(MemberExpr *e) { emitAggLoadOfLValue(e); }
   void VisitUnaryDeref(UnaryOperator *e) { emitAggLoadOfLValue(e); }
   void VisitStringLiteral(StringLiteral *e) { emitAggLoadOfLValue(e); }
@@ -326,7 +324,56 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> {
     Visit(e->getRHS());
   }
   void VisitBinCmp(const BinaryOperator *e) {
-    cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp");
+    const ComparisonCategoryInfo &cmpInfo =
+        cgf.getContext().CompCategories.getInfoForType(e->getType());
+
+    // Only accept operand types that the cir.cmp3way lowering handles:
+    // integers and enumerations, pointers, and real floating-point values.
+    QualType argTy = e->getLHS()->getType();
+    if (!argTy->isIntegralOrEnumerationType() && !argTy->isPointerType() &&
+        !argTy->isRealFloatingType()) {
+      cgf.cgm.errorNYI(e->getSourceRange(),
+                       "AggExprEmitter: unsupported operand type");
+      return;
+    }
+
+    mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
+    mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
+
+    CIRGenBuilderTy &builder = cgf.getBuilder();
+    auto valueOf = [](const ComparisonCategoryInfo::ValueInfo *valueInfo) {
+      return valueInfo->getIntValue().getSExtValue();
+    };
+
+    // Only partial orderings have an "unordered" result; query it only for
+    // them so that other categories never dereference a missing value.
+    mlir::Attribute info;
+    if (cmpInfo.isPartial()) {
+      info = cir::CmpThreeWayPartialInfoAttr::get(
+          builder.getContext(), valueOf(cmpInfo.getLess()),
+          valueOf(cmpInfo.getEqualOrEquiv()), valueOf(cmpInfo.getGreater()),
+          valueOf(cmpInfo.getUnordered()));
+    } else {
+      info = cir::CmpThreeWayStrongInfoAttr::get(
+          builder.getContext(), valueOf(cmpInfo.getLess()),
+          valueOf(cmpInfo.getEqualOrEquiv()), valueOf(cmpInfo.getGreater()));
+    }
+
+    // The comparison category object wraps a single integer field. Give the
+    // cir.cmp3way result that field's type so it can be stored directly,
+    // rather than assuming 'int', which may be wider than the field.
+    const FieldDecl *field = *cmpInfo.Record->field_begin();
+    mlir::Type resultTy = cgf.convertType(field->getType());
+    mlir::Location loc = cgf.getLoc(e->getSourceRange());
+    mlir::Value result =
+        builder.createThreeWayComparison(loc, resultTy, lhs, rhs, info);
+
+    ensureDest(loc, e->getType());
+    LValue destLValue = cgf.makeAddrLValue(dest.getAddress(), e->getType());
+    LValue fieldLValue = cgf.emitLValueForFieldInitialization(
+        destLValue, field, field->getName());
+    cgf.emitStoreThroughLValue(RValue::get(result), fieldLValue,
+                               /*isInit=*/true);
   }
   void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) {
     cgf.cgm.errorNYI(e->getSourceRange(),
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 9c68248d5dede..53086ecb4b669 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -3155,6 +3155,90 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite(
   return cmpOp.emitError() << "unsupported type for CmpOp: " << type;
 }
 
+mlir::LogicalResult CIRToLLVMCmpThreeWayOpLowering::matchAndRewrite(
+    cir::CmpThreeWayOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Location loc = op.getLoc();
+  mlir::Type resultTy = getTypeConverter()->convertType(op.getType());
+  mlir::Value lhs = adaptor.getLhs();
+  mlir::Value rhs = adaptor.getRhs();
+  mlir::Type operandTy = lhs.getType();
+
+  // Read the result value for each outcome; only a partial ordering has an
+  // "unordered" value.
+  int64_t ltResult = 0, eqResult = 0, gtResult = 0;
+  std::optional<int64_t> unorderedResult;
+  mlir::Attribute info = op.getInfo();
+  if (auto strongInfo = mlir::dyn_cast<cir::CmpThreeWayStrongInfoAttr>(info)) {
+    ltResult = strongInfo.getLt();
+    eqResult = strongInfo.getEq();
+    gtResult = strongInfo.getGt();
+  } else if (auto partialInfo =
+                 mlir::dyn_cast<cir::CmpThreeWayPartialInfoAttr>(info)) {
+    ltResult = partialInfo.getLt();
+    eqResult = partialInfo.getEq();
+    gtResult = partialInfo.getGt();
+    unorderedResult = partialInfo.getUnordered();
+  } else {
+    return op.emitError("unsupported comparison info attribute");
+  }
+
+  // Each constant's attribute must carry the converted result type itself;
+  // an i64 attribute does not match a narrower result such as i32 or i8.
+  auto makeConstant = [&](int64_t value) -> mlir::Value {
+    return mlir::LLVM::ConstantOp::create(
+        rewriter, loc, resultTy, rewriter.getIntegerAttr(resultTy, value));
+  };
+
+  // Builds `isLess ? lt : (isEqual ? eq : gt)`.
+  auto emitOrderedSelect = [&](mlir::Value isLess,
+                               mlir::Value isEqual) -> mlir::Value {
+    mlir::Value ltValue = makeConstant(ltResult);
+    mlir::Value eqValue = makeConstant(eqResult);
+    mlir::Value gtValue = makeConstant(gtResult);
+    mlir::Value eqOrGt =
+        mlir::LLVM::SelectOp::create(rewriter, loc, isEqual, eqValue, gtValue);
+    return mlir::LLVM::SelectOp::create(rewriter, loc, isLess, ltValue, eqOrGt);
+  };
+
+  if (mlir::isa<mlir::IntegerType, mlir::LLVM::LLVMPointerType>(operandTy)) {
+    // Only signed CIR integers compare as signed; unsigned integers, booleans
+    // and pointers compare as unsigned values.
+    auto cirIntTy = mlir::dyn_cast<cir::IntType>(op.getLhs().getType());
+    mlir::LLVM::ICmpPredicate ltPred = cirIntTy && cirIntTy.isSigned()
+                                           ? mlir::LLVM::ICmpPredicate::slt
+                                           : mlir::LLVM::ICmpPredicate::ult;
+    mlir::Value isLess =
+        mlir::LLVM::ICmpOp::create(rewriter, loc, ltPred, lhs, rhs);
+    mlir::Value isEqual = mlir::LLVM::ICmpOp::create(
+        rewriter, loc, mlir::LLVM::ICmpPredicate::eq, lhs, rhs);
+    rewriter.replaceOp(op, emitOrderedSelect(isLess, isEqual));
+    return mlir::success();
+  }
+
+  if (mlir::isa<mlir::FloatType>(operandTy)) {
+    // Floating-point operands may be unordered (NaN), which needs the
+    // partial-ordering "unordered" result value.
+    if (!unorderedResult)
+      return op.emitError("strong ordering not supported for float operands");
+
+    mlir::Value isLess = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::olt, lhs, rhs);
+    mlir::Value isEqual = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::oeq, lhs, rhs);
+    mlir::Value orderedResult = emitOrderedSelect(isLess, isEqual);
+    mlir::Value isUnordered = mlir::LLVM::FCmpOp::create(
+        rewriter, loc, mlir::LLVM::FCmpPredicate::uno, lhs, rhs);
+    mlir::Value unorderedValue = makeConstant(*unorderedResult);
+    mlir::Value result = mlir::LLVM::SelectOp::create(
+        rewriter, loc, isUnordered, unorderedValue, orderedResult);
+    rewriter.replaceOp(op, result);
+    return mlir::success();
+  }
+
+  return op.emitError("unsupported operand type for three-way comparison");
+}
+
 mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite(
     cir::BinOpOverflowOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp
new file mode 100644
index 0000000000000..80c408597f0d2
--- /dev/null
+++ b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp
@@ -0,0 +1,34 @@
+// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir
+// RUN: FileCheck --input-file=%t.cir %s
+
+#include "Inputs/std-compare.h"
+
+// The cir.cmp3way result has the type of the comparison category's field.
+
+// CHECK-LABEL: cir.func{{.*}} @{{.*}}test_int_spaceship
+int test_int_spaceship(int a, int b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !s32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s8i
+  return (result < 0) ? -1 : (result > 0) ? 1 : 0;
+}
+
+// CHECK-LABEL: cir.func{{.*}} @{{.*}}test_uint_spaceship
+unsigned int test_uint_spaceship(unsigned int a, unsigned int b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !u32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s8i
+  return (result < 0) ? 0 : (result > 0) ? 2 : 1;
+}
+
+// CHECK-LABEL: cir.func{{.*}} @{{.*}}test_ptr_spaceship
+int test_ptr_spaceship(int *a, int *b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !cir.ptr<!s32i>, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s8i
+  return (result < 0) ? -1 : (result > 0) ? 1 : 0;
+}
+
+// CHECK-LABEL: cir.func{{.*}} @{{.*}}test_float_spaceship
+float test_float_spaceship(float a, float b) {
+  auto result = a <=> b;
+  // CHECK: cir.cmp3way(%{{.*}} : !cir.float, %{{.*}}, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = -127>) : !s8i
+  return (result < 0) ? -1.0f : (result > 0) ? 1.0f : 0.0f;
+}
diff --git a/clang/test/CIR/Lowering/cmp3way.cir b/clang/test/CIR/Lowering/cmp3way.cir
new file mode 100644
index 0000000000000..fe3740b6fd30d
--- /dev/null
+++ b/clang/test/CIR/Lowering/cmp3way.cir
@@ -0,0 +1,52 @@
+// RUN: cir-opt %s -cir-to-llvm -o %t.mlir
+// RUN: FileCheck --input-file=%t.mlir %s
+
+!s8i = !cir.int<s, 8>
+!s32i = !cir.int<s, 32>
+!u32i = !cir.int<u, 32>
+module {
+  // CHECK-LABEL: llvm.func @test_signed
+  // CHECK:         %[[LT:.*]] = llvm.icmp "slt" %{{.*}}, %{{.*}} : i32
+  // CHECK:         %[[EQ:.*]] = llvm.icmp "eq" %{{.*}}, %{{.*}} : i32
+  // CHECK:         %[[LTV:.*]] = llvm.mlir.constant(-1 : i32) : i32
+  // CHECK:         %[[EQV:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK:         %[[GTV:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK:         %[[EQGT:.*]] = llvm.select %[[EQ]], %[[EQV]], %[[GTV]] : i1, i32
+  // CHECK:         %[[RES:.*]] = llvm.select %[[LT]], %[[LTV]], %[[EQGT]] : i1, i32
+  // CHECK:         llvm.return %[[RES]] : i32
+  cir.func @test_signed(%arg0: !s32i, %arg1: !s32i) -> !s32i {
+    %0 = cir.cmp3way(%arg0 : !s32i, %arg1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i
+    cir.return %0 : !s32i
+  }
+
+  // CHECK-LABEL: llvm.func @test_unsigned
+  // CHECK:         llvm.icmp "ult" %{{.*}}, %{{.*}} : i32
+  // CHECK:         llvm.icmp "eq" %{{.*}}, %{{.*}} : i32
+  cir.func @test_unsigned(%arg0: !u32i, %arg1: !u32i) -> !s32i {
+    %0 = cir.cmp3way(%arg0 : !u32i, %arg1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i
+    cir.return %0 : !s32i
+  }
+
+  // Pointers compare as unsigned; result constants use the i8 result type.
+  // CHECK-LABEL: llvm.func @test_pointer
+  // CHECK:         llvm.icmp "ult" %{{.*}}, %{{.*}} : !llvm.ptr
+  // CHECK:         llvm.icmp "eq" %{{.*}}, %{{.*}} : !llvm.ptr
+  // CHECK:         llvm.mlir.constant(-1 : i8) : i8
+  cir.func @test_pointer(%arg0: !cir.ptr<!s32i>, %arg1: !cir.ptr<!s32i>) -> !s8i {
+    %0 = cir.cmp3way(%arg0 : !cir.ptr<!s32i>, %arg1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s8i
+    cir.return %0 : !s8i
+  }
+
+  // CHECK-LABEL: llvm.func @test_float
+  // CHECK:         %[[LT:.*]] = llvm.fcmp "olt" %{{.*}}, %{{.*}} : f32
+  // CHECK:         %[[EQ:.*]] = llvm.fcmp "oeq" %{{.*}}, %{{.*}} : f32
+  // CHECK:         %[[ORD:.*]] = llvm.select %[[LT]], %{{.*}}, %{{.*}} : i1, i32
+  // CHECK:         %[[UNO:.*]] = llvm.fcmp "uno" %{{.*}}, %{{.*}} : f32
+  // CHECK:         %[[UNOV:.*]] = llvm.mlir.constant(2 : i32) : i32
+  // CHECK:         %[[RES:.*]] = llvm.select %[[UNO]], %[[UNOV]], %[[ORD]] : i1, i32
+  // CHECK:         llvm.return %[[RES]] : i32
+  cir.func @test_float(%arg0: !cir.float, %arg1: !cir.float) -> !s32i {
+    %0 = cir.cmp3way(%arg0 : !cir.float, %arg1, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = 2>) : !s32i
+    cir.return %0 : !s32i
+  }
+}

``````````

</details>


https://github.com/llvm/llvm-project/pull/186294
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to