llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: Yeongu Choe (YeonguChoe) <details> <summary>Changes</summary> A three-way comparison determines the ordering of two values (less than, equal to, or greater than, plus "unordered" for partial orderings such as floating-point comparisons) in a single operation. I implemented it as specified in the CIR documentation and formatted the code with `clang-format`. References - https://llvm.github.io/clangir/Dialect/ops.html#circmp3way-circmpthreewayop - https://en.wikipedia.org/wiki/Three-way_comparison - https://clang.llvm.org/docs/ClangFormatStyleOptions.html --- Full diff: https://github.com/llvm/llvm-project/pull/186294.diff 6 Files Affected: - (modified) clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h (+7) - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+34) - (modified) clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp (+42-4) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+82) - (added) clang/test/CIR/CodeGenCXX/three-way-comparison.cpp (+22) - (added) clang/test/CIR/Lowering/cmp3way.cir (+32) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index e60288c40132f..2a81da408121c 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -669,6 +669,13 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { return cir::CmpOp::create(*this, loc, kind, lhs, rhs); } + cir::CmpThreeWayOp createThreeWayComparison(mlir::Location loc, + mlir::Type resultTy, + mlir::Value lhs, mlir::Value rhs, + mlir::Attribute info) { + return cir::CmpThreeWayOp::create(*this, loc, resultTy, lhs, rhs, info); + } + cir::VecCmpOp createVecCompare(mlir::Location loc, cir::CmpOpKind kind, mlir::Value lhs, mlir::Value rhs) { VectorType vecCast = mlir::cast<VectorType>(lhs.getType()); diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 
a9b98b1f43b3f..24758a0c946f6 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2071,6 +2071,40 @@ def CIR_CmpOp : CIR_Op<"cmp", [Pure, SameTypeOperands]> { let hasCXXABILowering = true; } +//===----------------------------------------------------------------------===// +// CmpThreeWayOp +//===----------------------------------------------------------------------===// + +def CIR_CmpThreeWayStrongInfoAttr + : CIR_Attr<"CmpThreeWayStrongInfo", "cmp3way_strong_info"> { + let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt); + let assemblyFormat = + [{`<` `strong` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `>`}]; +} +def CIR_CmpThreeWayPartialInfoAttr + : CIR_Attr<"CmpThreeWayPartialInfo", "cmp3way_partial_info"> { + let parameters = (ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt, + "int64_t":$unordered); + let assemblyFormat = + [{`<` `partial` `,` `lt` `=` $lt `,` `eq` `=` $eq `,` `gt` `=` $gt `,` `unordered` `=` $unordered `>` }]; +} +def CIR_CmpThreeWayInfoAttr : AnyAttrOf<[CIR_CmpThreeWayStrongInfoAttr, + CIR_CmpThreeWayPartialInfoAttr]>; + +def CIR_CmpThreeWayOp + : CIR_Op<"cmp3way", [Pure, SameTypeOperands, ConditionallySpeculatable]> { + let summary = "Performs three-way comparison."; + let description = [{ + Three-way comparison takes two operands of the same type and determines ordering. 
+ }]; + let arguments = (ins CIR_AnyType:$lhs, CIR_AnyType:$rhs, + CIR_CmpThreeWayInfoAttr:$info); + let results = (outs CIR_AnySIntType:$result); + let assemblyFormat = [{ + `(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)` `:` type($result) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // BinOpOverflowOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp index 9f390fec97613..19e844664f53c 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp @@ -301,9 +301,7 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> { cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitSubstNonTypeTemplateParmExpr"); } - void VisitConstantExpr(ConstantExpr *e) { - cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitConstantExpr"); - } + void VisitConstantExpr(ConstantExpr *e) { return Visit(e->getSubExpr()); } void VisitMemberExpr(MemberExpr *e) { emitAggLoadOfLValue(e); } void VisitUnaryDeref(UnaryOperator *e) { emitAggLoadOfLValue(e); } void VisitStringLiteral(StringLiteral *e) { emitAggLoadOfLValue(e); } @@ -326,7 +324,47 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> { Visit(e->getRHS()); } void VisitBinCmp(const BinaryOperator *e) { - cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp"); + const ComparisonCategoryInfo &CompCategoryInfo = + cgf.getContext().CompCategories.getInfoForType(e->getType()); + + QualType ArgTy = e->getLHS()->getType(); + if (ArgTy->isIntegralOrEnumerationType() || ArgTy->isRealFloatingType() || + ArgTy->isNullPtrType() || ArgTy->isPointerType() || + ArgTy->isMemberPointerType()) { + mlir::Value lhs = cgf.emitScalarExpr(e->getLHS()); + mlir::Value rhs = cgf.emitScalarExpr(e->getRHS()); + + mlir::Attribute info; + if (CompCategoryInfo.isStrong()) { + 
info = cir::CmpThreeWayStrongInfoAttr::get( + cgf.getBuilder().getContext(), + CompCategoryInfo.getLess()->getIntValue().getSExtValue(), + CompCategoryInfo.getEqualOrEquiv()->getIntValue().getSExtValue(), + CompCategoryInfo.getGreater()->getIntValue().getSExtValue()); + } else { + info = cir::CmpThreeWayPartialInfoAttr::get( + cgf.getBuilder().getContext(), + CompCategoryInfo.getLess()->getIntValue().getSExtValue(), + CompCategoryInfo.getEqualOrEquiv()->getIntValue().getSExtValue(), + CompCategoryInfo.getGreater()->getIntValue().getSExtValue(), + CompCategoryInfo.getUnordered()->getIntValue().getSExtValue()); + } + mlir::Type resultTy = cgf.convertType(cgf.getContext().IntTy); + mlir::Value result = cgf.getBuilder().createThreeWayComparison( + cgf.getLoc(e->getSourceRange()), resultTy, lhs, rhs, info); + + ensureDest(cgf.getLoc(e->getSourceRange()), e->getType()); + LValue destLValue = cgf.makeAddrLValue(dest.getAddress(), e->getType()); + + const FieldDecl *field = *CompCategoryInfo.Record->field_begin(); + LValue fieldLValue = cgf.emitLValueForFieldInitialization( + destLValue, field, field->getName()); + cgf.emitStoreThroughLValue(RValue::get(result), fieldLValue, true); + } else { + cgf.cgm.errorNYI(e->getSourceRange(), + "AggExprEmitter: unsupported operand type"); + return; + } } void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) { cgf.cgm.errorNYI(e->getSourceRange(), diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9c68248d5dede..53086ecb4b669 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -3155,6 +3155,88 @@ mlir::LogicalResult CIRToLLVMCmpOpLowering::matchAndRewrite( return cmpOp.emitError() << "unsupported type for CmpOp: " << type; } +mlir::LogicalResult CIRToLLVMCmpThreeWayOpLowering::matchAndRewrite( + cir::CmpThreeWayOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter 
&rewriter) const { + mlir::Location loc = op.getLoc(); + auto info = op.getInfo(); + mlir::Type resultTy = getTypeConverter()->convertType(op.getType()); + mlir::Value lhs = adaptor.getLhs(); + mlir::Value rhs = adaptor.getRhs(); + mlir::Type operandTy = lhs.getType(); + + mlir::Value ltValue, eqValue, gtValue, unorderedValue; + if (auto strongInfo = mlir::dyn_cast<cir::CmpThreeWayStrongInfoAttr>(info)) { + ltValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(strongInfo.getLt())); + eqValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(strongInfo.getEq())); + gtValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(strongInfo.getGt())); + } else if (auto partialInfo = + mlir::dyn_cast<cir::CmpThreeWayPartialInfoAttr>(info)) { + ltValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(partialInfo.getLt())); + eqValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(partialInfo.getEq())); + gtValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(partialInfo.getGt())); + unorderedValue = mlir::LLVM::ConstantOp::create( + rewriter, loc, resultTy, + rewriter.getI64IntegerAttr(partialInfo.getUnordered())); + } else { + return op.emitError("unsupported comparison info attribute"); + } + + if (mlir::isa<mlir::IntegerType>(operandTy)) { + bool isSigned = true; + if (auto cirIntTy = mlir::dyn_cast<cir::IntType>(op.getLhs().getType())) { + isSigned = cirIntTy.isSigned(); + } + auto ltPred = isSigned ? 
mlir::LLVM::ICmpPredicate::slt + : mlir::LLVM::ICmpPredicate::ult; + + mlir::Value ltCmp = + mlir::LLVM::ICmpOp::create(rewriter, loc, ltPred, lhs, rhs); + mlir::Value eqCmp = mlir::LLVM::ICmpOp::create( + rewriter, loc, mlir::LLVM::ICmpPredicate::eq, lhs, rhs); + + mlir::Value result = mlir::LLVM::SelectOp::create( + rewriter, loc, ltCmp, ltValue, + mlir::LLVM::SelectOp::create(rewriter, loc, eqCmp, eqValue, gtValue)); + rewriter.replaceOp(op, result); + return mlir::success(); + } else if (mlir::isa<mlir::FloatType>(operandTy)) { + if (!unorderedValue) { + return op.emitError("strong ordering not supported for float operands"); + } + + mlir::Value ltCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::olt, lhs, rhs); + mlir::Value eqCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::oeq, lhs, rhs); + mlir::Value orderedResult = mlir::LLVM::SelectOp::create( + rewriter, loc, ltCmp, ltValue, + mlir::LLVM::SelectOp::create(rewriter, loc, eqCmp, eqValue, gtValue)); + + mlir::Value unorderedCmp = mlir::LLVM::FCmpOp::create( + rewriter, loc, mlir::LLVM::FCmpPredicate::uno, lhs, rhs); + + mlir::Value result = mlir::LLVM::SelectOp::create( + rewriter, loc, unorderedCmp, unorderedValue, orderedResult); + rewriter.replaceOp(op, result); + return mlir::success(); + } else { + return op.emitError("unsupported operand type for three-way comparison"); + } +} + mlir::LogicalResult CIRToLLVMBinOpOverflowOpLowering::matchAndRewrite( cir::BinOpOverflowOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp new file mode 100644 index 0000000000000..80c408597f0d2 --- /dev/null +++ b/clang/test/CIR/CodeGenCXX/three-way-comparison.cpp @@ -0,0 +1,22 @@ +// RUN: %clang_cc1 -std=c++20 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir %s -o %t.cir +// RUN: FileCheck --input-file=%t.cir %s + 
+#include "Inputs/std-compare.h" + +int test_int_spaceship(int a, int b) { + auto result = a <=> b; + // CHECK: cir.cmp3way(%{{.*}} : !s32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i + return (result < 0) ? -1 : (result > 0) ? 1 : 0; +} + +unsigned int test_uint_spaceship(unsigned int a, unsigned int b) { + auto result = a <=> b; + // CHECK: cir.cmp3way(%{{.*}} : !u32i, %{{.*}}, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i + return (result < 0) ? 0 : (result > 0) ? 2 : 1; +} + +float test_float_spaceship(float a, float b) { + auto result = a <=> b; + // CHECK: cir.cmp3way(%{{.*}} : !cir.float, %{{.*}}, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = -127>) + return (result < 0) ? -1.0f : (result > 0) ? 1.0f : 0.0f; +} \ No newline at end of file diff --git a/clang/test/CIR/Lowering/cmp3way.cir b/clang/test/CIR/Lowering/cmp3way.cir new file mode 100644 index 0000000000000..fe3740b6fd30d --- /dev/null +++ b/clang/test/CIR/Lowering/cmp3way.cir @@ -0,0 +1,32 @@ +// RUN: cir-opt %s -cir-to-llvm -o %t.mlir +// RUN: FileCheck --input-file=%t.mlir %s + +!s32i = !cir.int<s, 32> +!u32i = !cir.int<u, 32> +module { + cir.func @test_signed() -> !s32i { + %0 = cir.const #cir.int<5> : !s32i + %1 = cir.const #cir.int<3> : !s32i + %2 = cir.cmp3way(%0 : !s32i, %1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i + // CHECK: llvm.icmp "slt" + cir.return %2 : !s32i + } + + cir.func @test_unsigned() -> !s32i { + %0 = cir.const #cir.int<5> : !u32i + %1 = cir.const #cir.int<3> : !u32i + %2 = cir.cmp3way(%0 : !u32i, %1, #cir.cmp3way_strong_info<strong, lt = -1, eq = 0, gt = 1>) : !s32i + // CHECK: llvm.icmp "ult" + cir.return %2 : !s32i + } + + cir.func @test_float() -> !s32i { + %0 = cir.const #cir.fp<1.5> : !cir.float + %1 = cir.const #cir.fp<2.5> : !cir.float + %2 = cir.cmp3way(%0 : !cir.float, %1, #cir.cmp3way_partial_info<partial, lt = -1, eq = 0, gt = 1, unordered = 2>) : !s32i + // 
CHECK: llvm.fcmp "olt" + // CHECK: llvm.fcmp "oeq" + // CHECK: llvm.fcmp "uno" + cir.return %2 : !s32i + } +} \ No newline at end of file `````````` </details> https://github.com/llvm/llvm-project/pull/186294 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
