llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clangir Author: Hendrik Hübner (HendrikHuebner) <details> <summary>Changes</summary> This PR upstreams the three way compare op from the incubator repo --- Patch is 36.96 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/169963.diff 10 Files Affected: - (modified) clang/include/clang/CIR/Dialect/IR/CIRAttrs.td (+62) - (modified) clang/include/clang/CIR/Dialect/IR/CIROps.td (+65) - (modified) clang/lib/CIR/CodeGen/CIRGenBuilder.h (+31) - (modified) clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp (+57-1) - (modified) clang/lib/CIR/Dialect/IR/CIRAttrs.cpp (+54) - (modified) clang/lib/CIR/Dialect/IR/CIRDialect.cpp (+20) - (modified) clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp (+38-1) - (modified) clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp (+53) - (added) clang/test/CIR/CodeGen/Inputs/std-compare.h (+307) - (added) clang/test/CIR/CodeGen/three-way-cmp.cpp (+72) ``````````diff diff --git a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td index 12bc9cf7b5b04..0fce5b8bc6b72 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRAttrs.td @@ -447,6 +447,68 @@ def CIR_ConstPtrAttr : CIR_Attr<"ConstPtr", "ptr", [TypedAttrInterface]> { }]; } +//===----------------------------------------------------------------------===// +// CmpThreeWayInfoAttr +//===----------------------------------------------------------------------===// + +def CIR_CmpOrdering : CIR_I32EnumAttr< + "CmpOrdering", "three-way comparison ordering kind", [ + I32EnumAttrCase<"Strong", 0, "strong">, + I32EnumAttrCase<"Partial", 1, "partial"> +]> { + let genSpecializedAttr = 0; +} + +def CIR_CmpThreeWayInfoAttr : CIR_Attr<"CmpThreeWayInfo", "cmp3way_info"> { + let summary = "Holds information about a three-way comparison operation"; + let description = [{ + The `#cmp3way_info` attribute contains information about 
a three-way + comparison operation `cir.cmp3way`. + + The `ordering` parameter gives the ordering kind of the three-way comparison + operation. It may be either strong ordering or partial ordering. + + Given the two input operands of the three-way comparison operation `lhs` and + `rhs`, the `lt`, `eq`, `gt`, and `unordered` parameters give the result + value that should be produced by the three-way comparison operation when the + ordering between `lhs` and `rhs` is `lhs < rhs`, `lhs == rhs`, `lhs > rhs`, + or none of these, respectively. + }]; + + let parameters = (ins + EnumParameter<CIR_CmpOrdering>:$ordering, + "int64_t":$lt, "int64_t":$eq, "int64_t":$gt, + OptionalParameter<"std::optional<int64_t>">:$unordered + ); + + let builders = [ + AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt), [{ + return $_get($_ctxt, CmpOrdering::Strong, lt, eq, gt, std::nullopt); + }]>, + AttrBuilder<(ins "int64_t":$lt, "int64_t":$eq, "int64_t":$gt, + "int64_t":$unordered), [{ + return $_get($_ctxt, CmpOrdering::Partial, lt, eq, gt, unordered); + }]>, + ]; + + let extraClassDeclaration = [{ + /// Get attribute alias name for this attribute. + std::string getAlias() const; + }]; + + let assemblyFormat = [{ + `<` + $ordering `,` + `lt` `=` $lt `,` + `eq` `=` $eq `,` + `gt` `=` $gt + (`,` `unordered` `=` $unordered^)?
+ `>` + }]; + + let genVerifyDecl = 1; +} + //===----------------------------------------------------------------------===// // GlobalViewAttr //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index a19c4f951fff9..411ee8dc5984a 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -958,6 +958,71 @@ def CIR_ScopeOp : CIR_Op<"scope", [ let hasLLVMLowering = false; } +//===----------------------------------------------------------------------===// +// CmpThreeWayOp +//===----------------------------------------------------------------------===// + +def CIR_CmpThreeWayOp : CIR_Op<"cmp3way", [Pure, SameTypeOperands]> { + let summary = "Compare two values with C++ three-way comparison semantics"; + let description = [{ + The `cir.cmp3way` operation models the `<=>` operator in C++20. It takes two + operands with the same type and produces a result indicating the ordering + between the two input operands. + + The result of the operation is a signed integer that indicates the ordering + between the two input operands. + + There are two kinds of ordering: strong ordering and partial ordering. + Comparing different types of values yields different kinds of orderings. + The `info` parameter gives the ordering kind and other necessary information + about the comparison. 
+ + Example: + + ```mlir + !s8i = !cir.int<s, 8> + !s32i = !cir.int<s, 32> + #cmp3way_strong = #cir.cmp3way_info<strong, lt = -1, eq = 0, gt = 1> + #cmp3way_partial = #cir.cmp3way_info<partial, lt = -1, eq = 0, gt = 1, unordered = 2> + + %0 = cir.const #cir.int<0> : !s32i + %1 = cir.const #cir.int<1> : !s32i + %2 = cir.cmp3way(%0 : !s32i, %1, #cmp3way_strong) : !s8i + + %3 = cir.const #cir.fp<0.0> : !cir.float + %4 = cir.const #cir.fp<1.0> : !cir.float + %5 = cir.cmp3way(%3 : !cir.float, %4, #cmp3way_partial) : !s8i + ``` + }]; + + let arguments = (ins + CIR_AnyType:$lhs, + CIR_AnyType:$rhs, + CIR_CmpThreeWayInfoAttr:$info + ); + + let results = (outs CIR_AnySIntType:$result); + + let assemblyFormat = [{ + `(` $lhs `:` type($lhs) `,` $rhs `,` qualified($info) `)` + `:` type($result) attr-dict + }]; + + let extraClassDeclaration = [{ + /// Determine whether this three-way comparison produces a strong ordering. + bool isStrongOrdering() { + return getInfo().getOrdering() == cir::CmpOrdering::Strong; + } + + /// Determine whether this three-way comparison compares integral operands.
+ bool isIntegralComparison() { + return mlir::isa<cir::IntType>(getLhs().getType()); + } + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // SwitchOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenBuilder.h b/clang/lib/CIR/CodeGen/CIRGenBuilder.h index 85b38120169fd..c296449084614 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuilder.h +++ b/clang/lib/CIR/CodeGen/CIRGenBuilder.h @@ -567,6 +567,37 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy { return cir::StackRestoreOp::create(*this, loc, v); } + cir::CmpThreeWayOp createThreeWayCmpStrong(mlir::Location loc, + mlir::Value lhs, mlir::Value rhs, + const llvm::APSInt <Res, + const llvm::APSInt &eqRes, + const llvm::APSInt >Res) { + assert(ltRes.getBitWidth() == eqRes.getBitWidth() && + ltRes.getBitWidth() == gtRes.getBitWidth() && + "the three comparison results must have the same bit width"); + cir::IntType cmpResultTy = getSIntNTy(ltRes.getBitWidth()); + auto infoAttr = cir::CmpThreeWayInfoAttr::get( + getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), gtRes.getSExtValue()); + return cir::CmpThreeWayOp::create(*this, loc, cmpResultTy, lhs, rhs, + infoAttr); + } + + cir::CmpThreeWayOp + createThreeWayCmpPartial(mlir::Location loc, mlir::Value lhs, mlir::Value rhs, + const llvm::APSInt <Res, const llvm::APSInt &eqRes, + const llvm::APSInt >Res, + const llvm::APSInt &unorderedRes) { + assert(ltRes.getBitWidth() == eqRes.getBitWidth() && + ltRes.getBitWidth() == gtRes.getBitWidth() && + ltRes.getBitWidth() == unorderedRes.getBitWidth() && + "the four comparison results must have the same bit width"); + auto cmpResultTy = getSIntNTy(ltRes.getBitWidth()); + auto infoAttr = cir::CmpThreeWayInfoAttr::get( + getContext(), ltRes.getSExtValue(), eqRes.getSExtValue(), gtRes.getSExtValue(), unorderedRes.getSExtValue()); + return cir::CmpThreeWayOp::create(*this, loc, 
cmpResultTy, lhs, rhs, + infoAttr); + } + + mlir::Value createSetBitfield(mlir::Location loc, mlir::Type resultType, Address dstAddr, mlir::Type storageType, mlir::Value src, const CIRGenBitFieldInfo &info, diff --git a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp index 872fc8d14ad95..985da8283baac 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprAggregate.cpp @@ -18,6 +18,7 @@ #include "clang/AST/Expr.h" #include "clang/AST/RecordLayout.h" #include "clang/AST/StmtVisitor.h" +#include "llvm/IR/Value.h" #include <cstdint> using namespace clang; @@ -298,8 +299,63 @@ class AggExprEmitter : public StmtVisitor<AggExprEmitter> { Visit(e->getRHS()); } void VisitBinCmp(const BinaryOperator *e) { - cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitBinCmp"); + assert(cgf.getContext().hasSameType(e->getLHS()->getType(), + e->getRHS()->getType())); + const ComparisonCategoryInfo &cmpInfo = + cgf.getContext().CompCategories.getInfoForType(e->getType()); + assert(cmpInfo.Record->isTriviallyCopyable() && + "cannot copy non-trivially copyable aggregate"); + + QualType argTy = e->getLHS()->getType(); + + if (!argTy->isIntegralOrEnumerationType() && !argTy->isRealFloatingType() && + !argTy->isNullPtrType() && !argTy->isPointerType() && + !argTy->isMemberPointerType() && !argTy->isAnyComplexType()) + llvm_unreachable("aggregate three-way comparison"); + + mlir::Location loc = cgf.getLoc(e->getSourceRange()); + CIRGenBuilderTy &builder = cgf.getBuilder(); + + if (argTy->isAnyComplexType()) + llvm_unreachable("NYI"); + + mlir::Value lhs = cgf.emitAnyExpr(e->getLHS()).getValue(); + mlir::Value rhs = cgf.emitAnyExpr(e->getRHS()).getValue(); + + mlir::Value resultScalar; + if (argTy->isNullPtrType()) { + resultScalar = + builder.getConstInt(loc, cmpInfo.getEqualOrEquiv()->getIntValue()); + } else { + llvm::APSInt ltRes = cmpInfo.getLess()->getIntValue(); + llvm::APSInt eqRes =
cmpInfo.getEqualOrEquiv()->getIntValue(); + llvm::APSInt gtRes = cmpInfo.getGreater()->getIntValue(); + if (!cmpInfo.isPartial()) { + // Strong ordering. + resultScalar = builder.createThreeWayCmpStrong(loc, lhs, rhs, ltRes, + eqRes, gtRes); + } else { + // Partial ordering. + llvm::APSInt unorderedRes = cmpInfo.getUnordered()->getIntValue(); + resultScalar = builder.createThreeWayCmpPartial( + loc, lhs, rhs, ltRes, eqRes, gtRes, unorderedRes); + } + } + + // Create the return value in the destination slot. + ensureDest(loc, e->getType()); + LValue destLVal = cgf.makeAddrLValue(dest.getAddress(), e->getType()); + + // Emit the address of the first (and only) field in the comparison category + // type, and initialize it from the constant integer value produced above. + const FieldDecl *resultField = *cmpInfo.Record->field_begin(); + LValue fieldLVal = cgf.emitLValueForFieldInitialization(destLVal, resultField, + resultField->getName()); + cgf.emitStoreThroughLValue(RValue::get(resultScalar), fieldLVal); + + // All done! The result is in the dest slot. 
} + void VisitCXXRewrittenBinaryOperator(CXXRewrittenBinaryOperator *e) { cgf.cgm.errorNYI(e->getSourceRange(), "AggExprEmitter: VisitCXXRewrittenBinaryOperator"); diff --git a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp index 64ac97025e7c7..70e483de91d8d 100644 --- a/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRAttrs.cpp @@ -267,6 +267,60 @@ LogicalResult FPAttr::verify(function_ref<InFlightDiagnostic()> emitError, return success(); } + +//===----------------------------------------------------------------------===// +// CmpThreeWayInfoAttr definitions +//===----------------------------------------------------------------------===// + +std::string CmpThreeWayInfoAttr::getAlias() const { + std::string alias = "cmp3way_info"; + + if (getOrdering() == CmpOrdering::Strong) + alias.append("_strong_"); + else + alias.append("_partial_"); + + auto appendInt = [&](int64_t value) { + if (value < 0) { + alias.push_back('n'); + value = -value; + } + alias.append(std::to_string(value)); + }; + + alias.append("lt"); + appendInt(getLt()); + alias.append("eq"); + appendInt(getEq()); + alias.append("gt"); + appendInt(getGt()); + + if (std::optional<int64_t> unordered = getUnordered()) { + alias.append("un"); + appendInt(unordered.value()); + } + + return alias; +} + +LogicalResult +CmpThreeWayInfoAttr::verify(function_ref<InFlightDiagnostic()> emitError, + CmpOrdering ordering, int64_t lt, int64_t eq, + int64_t gt, std::optional<int64_t> unordered) { + // The presence of unordered must match the value of ordering.
+ if (ordering == CmpOrdering::Strong && unordered) { + emitError() << "strong ordering does not include unordered ordering"; + return failure(); + } + if (ordering == CmpOrdering::Partial && !unordered) { + emitError() << "partial ordering lacks unordered ordering"; + return failure(); + } + + return success(); +} + + //===----------------------------------------------------------------------===// // ConstComplexAttr definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 6bf543cf794b7..cb90dbeda7abf 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -78,6 +78,11 @@ struct CIROpAsmDialectInterface : public OpAsmDialectInterface { os << dynCastInfoAttr.getAlias(); return AliasResult::FinalAlias; } + if (auto cmpThreeWayInfoAttr = + mlir::dyn_cast<cir::CmpThreeWayInfoAttr>(attr)) { + os << cmpThreeWayInfoAttr.getAlias(); + return AliasResult::FinalAlias; + } return AliasResult::NoAlias; } }; @@ -1132,6 +1137,21 @@ Block *cir::BrCondOp::getSuccessorForOperands(ArrayRef<Attribute> operands) { return nullptr; } + +//===----------------------------------------------------------------------===// +// CmpThreeWayOp +//===----------------------------------------------------------------------===// + +mlir::LogicalResult CmpThreeWayOp::verify() { + // Type of the result must be a signed integer type. 
+ if (!getType().isSigned()) { + emitOpError() << "result type of cir.cmp3way must be a signed integer type"; + return failure(); + } + + return success(); +} + //===----------------------------------------------------------------------===// // CaseOp //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp index f8f354c2d1072..14ecd7013afeb 100644 --- a/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp +++ b/clang/lib/CIR/Dialect/Transforms/LoweringPrepare.cpp @@ -71,6 +71,7 @@ struct LoweringPreparePass void lowerComplexMulOp(cir::ComplexMulOp op); void lowerUnaryOp(cir::UnaryOp op); void lowerGlobalOp(cir::GlobalOp op); + void lowerThreeWayCmpOp(cir::CmpThreeWayOp op); void lowerDynamicCastOp(cir::DynamicCastOp op); void lowerArrayDtor(cir::ArrayDtor op); void lowerArrayCtor(cir::ArrayCtor op); @@ -911,6 +912,40 @@ void LoweringPreparePass::lowerGlobalOp(GlobalOp op) { assert(!cir::MissingFeatures::opGlobalAnnotations()); } +void LoweringPreparePass::lowerThreeWayCmpOp(CmpThreeWayOp op) { + CIRBaseBuilderTy builder(getContext()); + builder.setInsertionPointAfter(op); + + mlir::Location loc = op->getLoc(); + cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo(); + + mlir::Value ltRes = builder.getConstantInt(loc, op.getType(), cmpInfo.getLt()); + mlir::Value eqRes = builder.getConstantInt(loc, op.getType(), cmpInfo.getEq()); + mlir::Value gtRes = builder.getConstantInt(loc, op.getType(), cmpInfo.getGt()); + + mlir::Value lt = builder.createCompare(loc, CmpOpKind::lt, op.getLhs(), op.getRhs()); + mlir::Value eq = builder.createCompare(loc, CmpOpKind::eq, op.getLhs(), op.getRhs()); + + mlir::Value transformedResult; + if (cmpInfo.getOrdering() == CmpOrdering::Strong) { + // Strong ordering.
+ mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, gtRes); + transformedResult = builder.createSelect(loc, lt, ltRes, selectOnEq); + } else { + // Partial ordering. + cir::ConstantOp unorderedRes = + builder.getConstantInt(loc, op.getType(), cmpInfo.getUnordered().value()); + + mlir::Value gt = builder.createCompare(loc, CmpOpKind::gt, op.getLhs(), op.getRhs()); + mlir::Value selectOnEq = builder.createSelect(loc, eq, eqRes, unorderedRes); + mlir::Value selectOnGt = builder.createSelect(loc, gt, gtRes, selectOnEq); + transformedResult = builder.createSelect(loc, lt, ltRes, selectOnGt); + } + + op.replaceAllUsesWith(transformedResult); + op.erase(); +} + template <typename AttributeTy> static llvm::SmallVector<mlir::Attribute> prepareCtorDtorAttrList(mlir::MLIRContext *context, @@ -1107,6 +1142,8 @@ void LoweringPreparePass::runOnOp(mlir::Operation *op) { globalCtorList.emplace_back(fnOp.getName(), globalCtor.value()); else if (auto globalDtor = fnOp.getGlobalDtorPriority()) globalDtorList.emplace_back(fnOp.getName(), globalDtor.value()); + } else if (auto threeWayCmp = dyn_cast<CmpThreeWayOp>(op)) { + lowerThreeWayCmpOp(threeWayCmp); } } @@ -1120,7 +1157,7 @@ void LoweringPreparePass::runOnOperation() { op->walk([&](mlir::Operation *op) { if (mlir::isa<cir::ArrayCtor, cir::ArrayDtor, cir::CastOp, cir::ComplexMulOp, cir::ComplexDivOp, cir::DynamicCastOp, - cir::FuncOp, cir::GlobalOp, cir::UnaryOp>(op)) + cir::FuncOp, cir::GlobalOp, cir::UnaryOp, cir::CmpThreeWayOp>(op)) opsToTransform.push_back(op); }); diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 6136d48204e0c..e67349074f87f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1346,6 +1346,59 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite( return mlir::success(); } +static std::string getThreeWayCmpIntrinsicName( + bool
signedCmp, unsigned operandWidth, unsigned resultWidth) { + // The intrinsic's name takes the form: + // `llvm.<scmp|ucmp>.i<resultWidth>.i<operandWidth>` + + std::string result = "llvm."; + + if (signedCmp) + result.append("scmp."); + else + result.append("ucmp."); + + // Result type part. + result.push_back('i'); + result.append(std::to_string(resultWidth)); + result.push_back('.'); + + // Operand type part. + result.push_back('i'); + result.append(std::to_string(operandWidth)); + + return result; +} + +mlir::LogicalResult CIRToLLVMCmpThreeWayOpLowering::matchAndRewrite( + cir::CmpThreeWayOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + if (!op.isIntegralComparison() || !op.isStrongOrdering()) { + op.emitError() << "unsupported three-way comparison type"; + return mlir::failure(); + } + + cir::CmpThreeWayInfoAttr cmpInfo = op.getInfo(); + assert(cmpInfo.getLt() == -1 && cmpInfo.getEq() == 0 && cmpInfo.getGt() == 1); + + auto operandTy = mlir::cast<cir::IntType>(op.getLhs().getType()); + cir::IntType resultTy = op.getType(); + std::string llvmIntrinsicName = getThreeWayCmpIntrinsicName( + operandTy.isSigned(), operandTy.getWidth(), resultTy.getWidth()); + + rewriter.setInsertionPoint(op); + + mlir::Value llvmLhs = adaptor.getLhs(); + mlir::Value llvmRhs = adaptor.getRhs(); + mlir::Type llvmResultTy = getTypeConverter()->convertType(resultTy); + mlir::LLVM::CallIntrinsicOp callIntrinsicOp = + createCallLLVMIntrinsicOp(rewriter, op.getLoc(), llvmIntrinsicName, + llvmResultTy, {llvmLhs, llvmRhs}); + + rewriter.replaceOp(op, callIntrinsicOp); + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite( cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/test/CIR/CodeGen/Inputs/std-compare.h b/clang/test/CIR/CodeGen/Inputs/std-compare.h new file mode 100644 index 0000000000000..e... 
[truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/169963 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
