Author: Andy Kaylor Date: 2026-01-16T16:38:49-08:00 New Revision: f53c2e69d919ee3b44dd30f7a75ee25d55eab8a9
URL: https://github.com/llvm/llvm-project/commit/f53c2e69d919ee3b44dd30f7a75ee25d55eab8a9 DIFF: https://github.com/llvm/llvm-project/commit/f53c2e69d919ee3b44dd30f7a75ee25d55eab8a9.diff LOG: [CIR] Upstream support for calling through method pointers (#176063) This adds support to CIR for calling functions through pointer to method pointers with the Itanium ABI for x86_64 targets. The ARM-specific handling of method pointers is not-yet implemented. Added: Modified: clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h clang/include/clang/CIR/Dialect/IR/CIROps.td clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td clang/include/clang/CIR/MissingFeatures.h clang/lib/CIR/CodeGen/CIRGenCall.h clang/lib/CIR/CodeGen/CIRGenExpr.cpp clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp clang/lib/CIR/CodeGen/CIRGenFunction.h clang/lib/CIR/Dialect/IR/CIRDialect.cpp clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp clang/test/CIR/CodeGen/pointer-to-member-func.cpp Removed: ################################################################################ diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index eadf3dd6ee0f0..2aaae86240cf2 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -710,6 +710,36 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { cir::YieldOp createYield(mlir::Location loc, mlir::ValueRange value = {}) { return cir::YieldOp::create(*this, loc, value); } + + struct GetMethodResults { + mlir::Value callee; + mlir::Value adjustedThis; + }; + + GetMethodResults createGetMethod(mlir::Location loc, mlir::Value method, + mlir::Value objectPtr) { + // Build the callee function type. 
+ auto methodFuncTy = + mlir::cast<cir::MethodType>(method.getType()).getMemberFuncTy(); + auto methodFuncInputTypes = methodFuncTy.getInputs(); + + auto objectPtrTy = mlir::cast<cir::PointerType>(objectPtr.getType()); + mlir::Type adjustedThisTy = getVoidPtrTy(objectPtrTy.getAddrSpace()); + + llvm::SmallVector<mlir::Type> calleeFuncInputTypes{adjustedThisTy}; + calleeFuncInputTypes.insert(calleeFuncInputTypes.end(), + methodFuncInputTypes.begin(), + methodFuncInputTypes.end()); + cir::FuncType calleeFuncTy = + methodFuncTy.clone(calleeFuncInputTypes, methodFuncTy.getReturnType()); + // TODO(cir): consider the address space of the callee. + assert(!cir::MissingFeatures::addressSpace()); + cir::PointerType calleeTy = getPointerTo(calleeFuncTy); + + auto op = cir::GetMethodOp::create(*this, loc, calleeTy, adjustedThisTy, + method, objectPtr); + return {op.getCallee(), op.getAdjustedThis()}; + } }; } // namespace cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index db1696d50848c..e31024f7dfa84 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -3809,6 +3809,60 @@ def CIR_GetRuntimeMemberOp : CIR_Op<"get_runtime_member"> { let hasLLVMLowering = false; } +//===----------------------------------------------------------------------===// +// GetMethodOp +//===----------------------------------------------------------------------===// + +def CIR_GetMethodOp : CIR_Op<"get_method"> { + let summary = "Resolve a method to a function pointer as callee"; + let description = [{ + The `cir.get_method` operation takes a pointer to method (!cir.method) and + a pointer to a class object (!cir.ptr<!cir.record>) as input, and + yields a function pointer that points to the actual function corresponding + to the input method. The operation also applies any necessary adjustments to + the input object pointer for calling the method and yields the adjusted + pointer. 
+ + This operation is generated when calling a method through a pointer-to- + member-function in C++: + + ```cpp + // Foo *object; + // int arg; + // void (Foo::*method)(int); + + (object->*method)(arg); + ``` + + The code above will generate CIR similar to: + + ```mlir + %callee, %this = cir.get_method %method, %object + cir.call %callee(%this, %arg) + ``` + + The method type must match the callee type. That is: + - The return type of the method must match the return type of the callee. + - The first parameter of the callee must have type `!cir.ptr<!cir.void>`. + - Types of other parameters of the callee must match the parameters of the + method. + }]; + + let arguments = (ins CIR_MethodType:$method, CIR_PtrToRecordType:$object); + let results = (outs CIR_PtrToFunc:$callee, CIR_VoidPtrType:$adjusted_this); + + let assemblyFormat = [{ + $method `,` $object + `:` `(` qualified(type($method)) `,` qualified(type($object)) `)` + `->` `(` qualified(type($callee)) `,` qualified(type($adjusted_this)) `)` + attr-dict + }]; + + let hasVerifier = 1; + let hasLLVMLowering = false; + let hasCXXABILowering = true; +} + //===----------------------------------------------------------------------===// // VecCreate //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td index 3b2ec5276a677..1a5bae13c96df 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td +++ b/clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td @@ -191,6 +191,12 @@ def CIR_AnyComplexOrIntOrBoolOrFloatType def CIR_AnyRecordType : CIR_TypeBase<"::cir::RecordType", "record type">; +//===----------------------------------------------------------------------===// +// Function Type predicates +//===----------------------------------------------------------------------===// + +def CIR_AnyFuncType : CIR_TypeBase<"::cir::FuncType", "function type">; 
+ //===----------------------------------------------------------------------===// // Array Type predicates //===----------------------------------------------------------------------===// @@ -253,6 +259,8 @@ def CIR_PtrToComplexType : CIR_PtrToType<CIR_AnyComplexType>; def CIR_PtrToRecordType : CIR_PtrToType<CIR_AnyRecordType>; +def CIR_PtrToFunc : CIR_PtrToType<CIR_AnyFuncType>; + def CIR_PtrToArray : CIR_PtrToType<CIR_AnyArrayType>; //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index 39818417fc3d0..359d813171294 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -194,6 +194,11 @@ struct MissingFeatures { static bool lowerModuleLangOpts() { return false; } static bool targetLoweringInfo() { return false; } + // Extra checks for lowerGetMethod in ItaniumCXXABI + static bool emitCFICheck() { return false; } + static bool emitVFEInfo() { return false; } + static bool emitWPDInfo() { return false; } + // Misc static bool aarch64SIMDIntrinsics() { return false; } static bool aarch64SMEIntrinsics() { return false; } @@ -211,6 +216,7 @@ struct MissingFeatures { static bool aggValueSlotVolatile() { return false; } static bool alignCXXRecordDecl() { return false; } static bool allocToken() { return false; } + static bool appleArm64CXXABI() { return false; } static bool appleKext() { return false; } static bool armComputeVolatileBitfields() { return false; } static bool asmGoto() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h index 55b3d9765c5c5..347bd4a7c8266 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.h +++ b/clang/lib/CIR/CodeGen/CIRGenCall.h @@ -33,6 +33,8 @@ class CIRGenCalleeInfo { CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy, clang::GlobalDecl calleeDecl) : calleeProtoTy(calleeProtoTy), 
calleeDecl(calleeDecl) {} + CIRGenCalleeInfo(const clang::FunctionProtoType *calleeProtoTy) + : calleeProtoTy(calleeProtoTy) {} CIRGenCalleeInfo(clang::GlobalDecl calleeDecl) : calleeProtoTy(nullptr), calleeDecl(calleeDecl) {} diff --git a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp index 3219bc9a808ad..504f18e1a9f31 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExpr.cpp @@ -2350,11 +2350,8 @@ RValue CIRGenFunction::emitCXXMemberCallExpr(const CXXMemberCallExpr *ce, ReturnValueSlot returnValue) { const Expr *callee = ce->getCallee()->IgnoreParens(); - if (isa<BinaryOperator>(callee)) { - cgm.errorNYI(ce->getSourceRange(), - "emitCXXMemberCallExpr: C++ binary operator"); - return RValue::get(nullptr); - } + if (isa<BinaryOperator>(callee)) + return emitCXXMemberPointerCallExpr(ce, returnValue); const auto *me = cast<MemberExpr>(callee); const auto *md = cast<CXXMethodDecl>(me->getMemberDecl()); diff --git a/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp b/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp index eb894c2fb30ee..98cf75f0d69e0 100644 --- a/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenExprCXX.cpp @@ -75,6 +75,50 @@ static MemberCallInfo commonBuildCXXMemberOrOperatorCall( return {required, prefixSize}; } +RValue +CIRGenFunction::emitCXXMemberPointerCallExpr(const CXXMemberCallExpr *ce, + ReturnValueSlot returnValue) { + const BinaryOperator *bo = + cast<BinaryOperator>(ce->getCallee()->IgnoreParens()); + const Expr *baseExpr = bo->getLHS(); + const Expr *memFnExpr = bo->getRHS(); + + const auto *mpt = memFnExpr->getType()->castAs<MemberPointerType>(); + const auto *fpt = mpt->getPointeeType()->castAs<FunctionProtoType>(); + + // Emit the 'this' pointer. 
+ Address thisAddr = Address::invalid(); + if (bo->getOpcode() == BO_PtrMemI) + thisAddr = emitPointerWithAlignment(baseExpr); + else + thisAddr = emitLValue(baseExpr).getAddress(); + + assert(!cir::MissingFeatures::emitTypeCheck()); + + // Get the member function pointer. + mlir::Value memFnPtr = emitScalarExpr(memFnExpr); + + // Resolve the member function pointer to the actual callee and adjust the + // "this" pointer for call. + mlir::Location loc = getLoc(ce->getExprLoc()); + auto [/*mlir::Value*/ calleePtr, /*mlir::Value*/ adjustedThis] = + builder.createGetMethod(loc, memFnPtr, thisAddr.getPointer()); + + // Prepare the call arguments. + CallArgList argsList; + argsList.add(RValue::get(adjustedThis), getContext().VoidPtrTy); + emitCallArgs(argsList, fpt, ce->arguments()); + + RequiredArgs required = RequiredArgs::getFromProtoWithExtraSlots(fpt, 1); + + // Build the call. + CIRGenCallee callee(fpt, calleePtr.getDefiningOp()); + assert(!cir::MissingFeatures::opCallMustTail()); + return emitCall(cgm.getTypes().arrangeCXXMethodCall(argsList, fpt, required, + /*PrefixSize=*/0), + callee, returnValue, argsList, nullptr, loc); +} + RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr( const CallExpr *ce, const CXXMethodDecl *md, ReturnValueSlot returnValue, bool hasQualifier, NestedNameSpecifier qualifier, bool isArrow, diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 049d15562f835..5aa977477f30e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -1553,6 +1553,9 @@ class CIRGenFunction : public CIRGenTypeCache { clang::NestedNameSpecifier qualifier, bool isArrow, const clang::Expr *base); + RValue emitCXXMemberPointerCallExpr(const CXXMemberCallExpr *ce, + ReturnValueSlot returnValue); + mlir::Value emitCXXNewExpr(const CXXNewExpr *e); void emitNewArrayInitializer(const CXXNewExpr *e, QualType elementType, diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp 
b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index 91c3050c46806..302a5ae1255fd 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -2703,6 +2703,53 @@ LogicalResult cir::GetRuntimeMemberOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// GetMethodOp Definitions +//===----------------------------------------------------------------------===// + +LogicalResult cir::GetMethodOp::verify() { + cir::MethodType methodTy = getMethod().getType(); + + // Assume objectPtrTy is !cir.ptr<!T> + cir::PointerType objectPtrTy = getObject().getType(); + mlir::Type objectTy = objectPtrTy.getPointee(); + + if (methodTy.getClassTy() != objectTy) + return emitError() << "method class type and object type do not match"; + + // Assume methodFuncTy is !cir.func<!Ret (!Args)> + auto calleeTy = mlir::cast<cir::FuncType>(getCallee().getType().getPointee()); + cir::FuncType methodFuncTy = methodTy.getMemberFuncTy(); + + // We verify here that calleeTy is !cir.func<!Ret (!cir.ptr<!void>, !Args)> + // Note that the first parameter type of the callee is !cir.ptr<!void> instead + // of !cir.ptr<!T> because the "this" pointer may be adjusted before calling + // the callee. 
+ + if (methodFuncTy.getReturnType() != calleeTy.getReturnType()) + return emitError() + << "method return type and callee return type do not match"; + + llvm::ArrayRef<mlir::Type> calleeArgsTy = calleeTy.getInputs(); + llvm::ArrayRef<mlir::Type> methodFuncArgsTy = methodFuncTy.getInputs(); + + if (calleeArgsTy.empty()) + return emitError() << "callee parameter list lacks receiver object ptr"; + + auto calleeThisArgPtrTy = mlir::dyn_cast<cir::PointerType>(calleeArgsTy[0]); + if (!calleeThisArgPtrTy || + !mlir::isa<cir::VoidType>(calleeThisArgPtrTy.getPointee())) { + return emitError() + << "the first parameter of callee must be a void pointer"; + } + + if (calleeArgsTy.slice(1) != methodFuncArgsTy) + return emitError() + << "callee parameters and method parameters do not match"; + + return mlir::success(); +} + //===----------------------------------------------------------------------===// // GetMemberOp Definitions //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp b/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp index dbe656ac011d8..4429ca10415d8 100644 --- a/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CXXABILowering.cpp @@ -59,7 +59,7 @@ class CIRGenericCXXABILoweringPattern : public mlir::ConversionPattern { // Do not match on operations that have dedicated ABI lowering rewrite rules if (llvm::isa<cir::AllocaOp, cir::BaseDataMemberOp, cir::ConstantOp, cir::CmpOp, cir::DerivedDataMemberOp, cir::FuncOp, - cir::GetRuntimeMemberOp, cir::GlobalOp>(op)) + cir::GetMethodOp, cir::GetRuntimeMemberOp, cir::GlobalOp>(op)) return mlir::failure(); const mlir::TypeConverter *typeConverter = getTypeConverter(); @@ -256,6 +256,17 @@ mlir::LogicalResult CIRDerivedDataMemberOpABILowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRGetMethodOpABILowering::matchAndRewrite( + cir::GetMethodOp op, OpAdaptor 
adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Value callee; + mlir::Value thisArg; + lowerModule->getCXXABI().lowerGetMethod( + op, callee, thisArg, adaptor.getMethod(), adaptor.getObject(), rewriter); + rewriter.replaceOp(op, {callee, thisArg}); + return mlir::success(); +} + mlir::LogicalResult CIRGetRuntimeMemberOpABILowering::matchAndRewrite( cir::GetRuntimeMemberOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h index f4d608cdbad03..a7d733afd18c6 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/CIRCXXABI.h @@ -66,6 +66,14 @@ class CIRCXXABI { mlir::Value loweredAddr, mlir::Value loweredMember, mlir::OpBuilder &builder) const = 0; + /// Lower the given cir.get_method op to a sequence of more "primitive" CIR + /// operations that act on the ABI types. The lowered result values are + /// stored in the `callee` and `thisArg` output parameters. + virtual void + lowerGetMethod(cir::GetMethodOp op, mlir::Value &callee, mlir::Value &thisArg, + mlir::Value loweredMethod, mlir::Value loweredObjectPtr, + mlir::ConversionPatternRewriter &rewriter) const = 0; + /// Lower the given cir.base_data_member op to a sequence of more "primitive" /// CIR operations that act on the ABI types. 
virtual mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op, diff --git a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp index 39fcbcdf49f3e..94342f864fca6 100644 --- a/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp +++ b/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerItaniumCXXABI.cpp @@ -31,8 +31,12 @@ namespace cir { namespace { class LowerItaniumCXXABI : public CIRCXXABI { +protected: + bool useARMMethodPtrABI; + public: - LowerItaniumCXXABI(LowerModule &lm) : CIRCXXABI(lm) {} + LowerItaniumCXXABI(LowerModule &lm, bool useARMMethodPtrABI = false) + : CIRCXXABI(lm), useARMMethodPtrABI(useARMMethodPtrABI) {} /// Lower the given data member pointer type to its ABI type. The returned /// type is also a CIR type. @@ -57,6 +61,11 @@ class LowerItaniumCXXABI : public CIRCXXABI { mlir::Value loweredAddr, mlir::Value loweredMember, mlir::OpBuilder &builder) const override; + void lowerGetMethod(cir::GetMethodOp op, mlir::Value &callee, + mlir::Value &thisArg, mlir::Value loweredMethod, + mlir::Value loweredObjectPtr, + mlir::ConversionPatternRewriter &rewriter) const override; + mlir::Value lowerBaseDataMember(cir::BaseDataMemberOp op, mlir::Value loweredSrc, mlir::OpBuilder &builder) const override; @@ -77,7 +86,26 @@ class LowerItaniumCXXABI : public CIRCXXABI { } // namespace std::unique_ptr<CIRCXXABI> createItaniumCXXABI(LowerModule &lm) { - return std::make_unique<LowerItaniumCXXABI>(lm); + switch (lm.getCXXABIKind()) { + // Note that AArch64 uses the generic ItaniumCXXABI class since it doesn't + // include the other 32-bit ARM oddities: constructor/destructor return values + // and array cookies. + case clang::TargetCXXABI::GenericAArch64: + case clang::TargetCXXABI::AppleARM64: + // TODO: this isn't quite right, clang uses AppleARM64CXXABI which inherits + // from ARMCXXABI. We'll have to follow suit. 
+ assert(!cir::MissingFeatures::appleArm64CXXABI()); + return std::make_unique<LowerItaniumCXXABI>(lm, + /*useARMMethodPtrABI=*/true); + + case clang::TargetCXXABI::GenericItanium: + return std::make_unique<LowerItaniumCXXABI>(lm); + + case clang::TargetCXXABI::Microsoft: + llvm_unreachable("Microsoft ABI is not Itanium-based"); + default: + llvm_unreachable("Other Itanium ABI?"); + } } static cir::IntType getPtrDiffCIRTy(LowerModule &lm) { @@ -202,6 +230,127 @@ mlir::Operation *LowerItaniumCXXABI::lowerGetRuntimeMember( cir::CastKind::bitcast, memberBytesPtr); } +void LowerItaniumCXXABI::lowerGetMethod( + cir::GetMethodOp op, mlir::Value &callee, mlir::Value &thisArg, + mlir::Value loweredMethod, mlir::Value loweredObjectPtr, + mlir::ConversionPatternRewriter &rewriter) const { + // In the Itanium and ARM ABIs, method pointers have the form: + // struct { ptrdiff_t ptr; ptrdiff_t adj; } memptr; + // + // In the Itanium ABI: + // - method pointers are virtual if (memptr.ptr & 1) is nonzero + // - the this-adjustment is (memptr.adj) + // - the virtual offset is (memptr.ptr - 1) + // + // In the ARM ABI: + // - method pointers are virtual if (memptr.adj & 1) is nonzero + // - the this-adjustment is (memptr.adj >> 1) + // - the virtual offset is (memptr.ptr) + // ARM uses 'adj' for the virtual flag because Thumb functions + // may be only single-byte aligned. + // + // If the member is virtual, the adjusted 'this' pointer points + // to a vtable pointer from which the virtual offset is applied. + // + // If the member is non-virtual, memptr.ptr is the address of + // the function to call. 
+ + mlir::ImplicitLocOpBuilder locBuilder(op.getLoc(), rewriter); + mlir::Type calleePtrTy = op.getCallee().getType(); + + cir::IntType ptrdiffCIRTy = getPtrDiffCIRTy(lm); + mlir::Value ptrdiffOne = + cir::ConstantOp::create(locBuilder, cir::IntAttr::get(ptrdiffCIRTy, 1)); + + mlir::Value adj = + cir::ExtractMemberOp::create(locBuilder, ptrdiffCIRTy, loweredMethod, 1); + if (useARMMethodPtrABI) { + op.emitError("ARM method ptr abi NYI"); + return; + } + + // Apply the adjustment to the 'this' pointer. + mlir::Type thisVoidPtrTy = + cir::PointerType::get(cir::VoidType::get(locBuilder.getContext()), + op.getObject().getType().getAddrSpace()); + mlir::Value thisVoidPtr = cir::CastOp::create( + locBuilder, thisVoidPtrTy, cir::CastKind::bitcast, loweredObjectPtr); + thisArg = + cir::PtrStrideOp::create(locBuilder, thisVoidPtrTy, thisVoidPtr, adj); + + // Load the "ptr" field of the member function pointer and determine if it + // points to a virtual function. + mlir::Value methodPtrField = + cir::ExtractMemberOp::create(locBuilder, ptrdiffCIRTy, loweredMethod, 0); + mlir::Value virtualBit = cir::BinOp::create( + rewriter, op.getLoc(), cir::BinOpKind::And, methodPtrField, ptrdiffOne); + mlir::Value isVirtual; + if (useARMMethodPtrABI) + llvm_unreachable("ARM method ptr abi NYI"); + else + isVirtual = cir::CmpOp::create(locBuilder, cir::CmpOpKind::eq, virtualBit, + ptrdiffOne); + + assert(!cir::MissingFeatures::emitCFICheck()); + assert(!cir::MissingFeatures::emitVFEInfo()); + assert(!cir::MissingFeatures::emitWPDInfo()); + + auto buildVirtualCallee = [&](mlir::OpBuilder &b, mlir::Location loc) { + // Load vtable pointer. + // Note that the vtable pointer always points to the global address space. 
+ auto vtablePtrTy = + cir::PointerType::get(cir::IntType::get(b.getContext(), 8, true)); + auto vtablePtrPtrTy = cir::PointerType::get( + vtablePtrTy, op.getObject().getType().getAddrSpace()); + auto vtablePtrPtr = cir::CastOp::create(b, loc, vtablePtrPtrTy, + cir::CastKind::bitcast, thisArg); + assert(!cir::MissingFeatures::opTBAA()); + mlir::Value vtablePtr = + cir::LoadOp::create(b, loc, vtablePtrPtr, /*isDeref=*/false, + /*isVolatile=*/false, + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr{}, + /*mem_order=*/cir::MemOrderAttr()); + + // Get the vtable offset. + mlir::Value vtableOffset = methodPtrField; + assert(!useARMMethodPtrABI && "ARM method ptr abi NYI"); + vtableOffset = cir::BinOp::create(b, loc, cir::BinOpKind::Sub, vtableOffset, + ptrdiffOne); + + assert(!cir::MissingFeatures::emitCFICheck()); + assert(!cir::MissingFeatures::emitVFEInfo()); + assert(!cir::MissingFeatures::emitWPDInfo()); + + // Apply the offset to the vtable pointer and get the pointer to the target + // virtual function. Then load that pointer to get the callee. 
+ mlir::Value vfpAddr = cir::PtrStrideOp::create(locBuilder, vtablePtrTy, + vtablePtr, vtableOffset); + auto vfpPtrTy = cir::PointerType::get(calleePtrTy); + mlir::Value vfpPtr = cir::CastOp::create(locBuilder, vfpPtrTy, + cir::CastKind::bitcast, vfpAddr); + auto fnPtr = cir::LoadOp::create(b, loc, vfpPtr, + /*isDeref=*/false, /*isVolatile=*/false, + /*alignment=*/mlir::IntegerAttr(), + /*sync_scope=*/cir::SyncScopeKindAttr{}, + /*mem_order=*/cir::MemOrderAttr()); + + cir::YieldOp::create(b, loc, fnPtr.getResult()); + assert(!cir::MissingFeatures::emitCFICheck()); + }; + + callee = cir::TernaryOp::create( + locBuilder, isVirtual, /*thenBuilder=*/buildVirtualCallee, + /*elseBuilder=*/ + [&](mlir::OpBuilder &b, mlir::Location loc) { + auto fnPtr = cir::CastOp::create(b, loc, calleePtrTy, + cir::CastKind::int_to_ptr, + methodPtrField); + cir::YieldOp::create(b, loc, fnPtr.getResult()); + }) + .getResult(); +} + static mlir::Value lowerDataMemberCast(mlir::Operation *op, mlir::Value loweredSrc, std::int64_t offset, diff --git a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp index 47c5871e72290..ad081d0d06dbc 100644 --- a/clang/test/CIR/CodeGen/pointer-to-member-func.cpp +++ b/clang/test/CIR/CodeGen/pointer-to-member-func.cpp @@ -38,3 +38,87 @@ auto make_non_virtual() -> void (Foo::*)(int) { // OGCG: define {{.*}} { i64, i64 } @_Z16make_non_virtualv() // OGCG: ret { i64, i64 } { i64 ptrtoint (ptr @_ZN3Foo2m1Ei to i64), i64 0 } + +void call(Foo *obj, void (Foo::*func)(int), int arg) { + (obj->*func)(arg); +} + +// CIR-BEFORE: cir.func {{.*}} @_Z4callP3FooMS_FviEi +// CIR-BEFORE: %[[OBJ:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!cir.ptr<!rec_Foo>>, !cir.ptr<!rec_Foo> +// CIR-BEFORE: %[[FUNC:.*]] = cir.load{{.*}} : !cir.ptr<!cir.method<!cir.func<(!s32i)> in !rec_Foo>>, !cir.method<!cir.func<(!s32i)> in !rec_Foo> +// CIR-BEFORE: %[[CALLEE:.*]], %[[THIS:.*]] = cir.get_method %[[FUNC]], %[[OBJ]] : 
(!cir.method<!cir.func<(!s32i)> in !rec_Foo>, !cir.ptr<!rec_Foo>) -> (!cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>>, !cir.ptr<!void>) +// CIR-BEFORE: %[[ARG:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i +// CIR-BEFORE: cir.call %[[CALLEE]](%[[THIS]], %[[ARG]]) : (!cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>>, !cir.ptr<!void>, !s32i) -> () + +// CIR-AFTER: cir.func {{.*}} @_Z4callP3FooMS_FviEi +// CIR-AFTER: %[[OBJ:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!cir.ptr<!rec_Foo>>, !cir.ptr<!rec_Foo> +// CIR-AFTER: %[[FUNC:.*]] = cir.load{{.*}} : !cir.ptr<!rec_anon_struct>, !rec_anon_struct +// CIR-AFTER: %[[VIRT_BIT:.*]] = cir.const #cir.int<1> : !s64i +// CIR-AFTER: %[[ADJ:.*]] = cir.extract_member %[[FUNC]][1] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[THIS:.*]] = cir.cast bitcast %[[OBJ]] : !cir.ptr<!rec_Foo> -> !cir.ptr<!void> +// CIR-AFTER: %[[ADJUSTED_THIS:.*]] = cir.ptr_stride %[[THIS]], %[[ADJ]] : (!cir.ptr<!void>, !s64i) -> !cir.ptr<!void> +// CIR-AFTER: %[[METHOD_PTR:.*]] = cir.extract_member %[[FUNC]][0] : !rec_anon_struct -> !s64i +// CIR-AFTER: %[[VIRT_BIT_TEST:.*]] = cir.binop(and, %[[METHOD_PTR]], %[[VIRT_BIT]]) : !s64i +// CIR-AFTER: %[[IS_VIRTUAL:.*]] = cir.cmp(eq, %[[VIRT_BIT_TEST]], %[[VIRT_BIT]]) : !s64i, !cir.bool +// CIR-AFTER: %[[CALLEE:.*]] = cir.ternary(%[[IS_VIRTUAL]], true { +// CIR-AFTER: %[[VTABLE_PTR:.*]] = cir.cast bitcast %[[ADJUSTED_THIS]] : !cir.ptr<!void> -> !cir.ptr<!cir.ptr<!s8i>> +// CIR-AFTER: %[[VTABLE:.*]] = cir.load %[[VTABLE_PTR]] : !cir.ptr<!cir.ptr<!s8i>>, !cir.ptr<!s8i> +// CIR-AFTER: %[[OFFSET:.*]] = cir.binop(sub, %[[METHOD_PTR]], %[[VIRT_BIT]]) : !s64i +// CIR-AFTER: %[[VTABLE_SLOT:.*]] = cir.ptr_stride %[[VTABLE]], %[[OFFSET]] : (!cir.ptr<!s8i>, !s64i) -> !cir.ptr<!s8i> +// CIR-AFTER: %[[VIRTUAL_FN_PTR:.*]] = cir.cast bitcast %[[VTABLE_SLOT]] : !cir.ptr<!s8i> -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>>> +// CIR-AFTER: %[[VIRTUAL_FN_PTR_LOAD:.*]] = cir.load %[[VIRTUAL_FN_PTR]] : 
!cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>>>, !cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>> +// CIR-AFTER: cir.yield %[[VIRTUAL_FN_PTR_LOAD]] : !cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>> +// CIR-AFTER: }, false { +// CIR-AFTER: %[[CALLEE_PTR:.*]] = cir.cast int_to_ptr %[[METHOD_PTR]] : !s64i -> !cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>> +// CIR-AFTER: cir.yield %[[CALLEE_PTR]] : !cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>> +// CIR-AFTER: }) : (!cir.bool) -> !cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>> +// CIR-AFTER: %[[ARG:.*]] = cir.load{{.*}} %{{.*}} : !cir.ptr<!s32i>, !s32i +// CIR-AFTER: cir.call %[[CALLEE]](%[[ADJUSTED_THIS]], %[[ARG]]) : (!cir.ptr<!cir.func<(!cir.ptr<!void>, !s32i)>>, !cir.ptr<!void>, !s32i) -> () + +// LLVM: define {{.*}} @_Z4callP3FooMS_FviEi +// LLVM: %[[OBJ:.*]] = load ptr, ptr %{{.*}} +// LLVM: %[[MEMFN_PTR:.*]] = load { i64, i64 }, ptr %{{.*}} +// LLVM: %[[THIS_ADJ:.*]] = extractvalue { i64, i64 } %[[MEMFN_PTR]], 1 +// LLVM: %[[ADJUSTED_THIS:.*]] = getelementptr i8, ptr %[[OBJ]], i64 %[[THIS_ADJ]] +// LLVM: %[[PTR_FIELD:.*]] = extractvalue { i64, i64 } %[[MEMFN_PTR]], 0 +// LLVM: %[[VIRT_BIT:.*]] = and i64 %[[PTR_FIELD]], 1 +// LLVM: %[[IS_VIRTUAL:.*]] = icmp eq i64 %[[VIRT_BIT]], 1 +// LLVM: br i1 %[[IS_VIRTUAL]], label %[[HANDLE_VIRTUAL:.*]], label %[[HANDLE_NON_VIRTUAL:.*]] +// LLVM: [[HANDLE_VIRTUAL]]: +// LLVM: %[[VTABLE:.*]] = load ptr, ptr %[[ADJUSTED_THIS]] +// LLVM: %[[OFFSET:.*]] = sub i64 %[[PTR_FIELD]], 1 +// LLVM: %[[VTABLE_SLOT:.*]] = getelementptr i8, ptr %[[VTABLE]], i64 %[[OFFSET]] +// LLVM: %[[VIRTUAL_FN_PTR:.*]] = load ptr, ptr %[[VTABLE_SLOT]] +// LLVM: br label %[[CONTINUE:.*]] +// LLVM: [[HANDLE_NON_VIRTUAL]]: +// LLVM: %[[FUNC_PTR:.*]] = inttoptr i64 %[[PTR_FIELD]] to ptr +// LLVM: br label %[[CONTINUE]] +// LLVM: [[CONTINUE]]: +// LLVM: %[[CALLEE_PTR:.*]] = phi ptr [ %[[FUNC_PTR]], %[[HANDLE_NON_VIRTUAL]] ], [ %[[VIRTUAL_FN_PTR]], %[[HANDLE_VIRTUAL]] ] +// LLVM: %[[ARG:.*]] = load i32, ptr 
%{{.+}} +// LLVM: call void %[[CALLEE_PTR]](ptr %[[ADJUSTED_THIS]], i32 %[[ARG]]) +// LLVM: } + +// OGCG: define {{.*}} @_Z4callP3FooMS_FviEi +// OGCG: %[[OBJ:.*]] = load ptr, ptr %{{.*}} +// OGCG: %[[MEMFN_PTR:.*]] = load { i64, i64 }, ptr %{{.*}} +// OGCG: %[[THIS_ADJ:.*]] = extractvalue { i64, i64 } %[[MEMFN_PTR]], 1 +// OGCG: %[[ADJUSTED_THIS:.*]] = getelementptr inbounds i8, ptr %[[OBJ]], i64 %[[THIS_ADJ]] +// OGCG: %[[PTR_FIELD:.*]] = extractvalue { i64, i64 } %[[MEMFN_PTR]], 0 +// OGCG: %[[VIRT_BIT:.*]] = and i64 %[[PTR_FIELD]], 1 +// OGCG: %[[IS_VIRTUAL:.*]] = icmp ne i64 %[[VIRT_BIT]], 0 +// OGCG: br i1 %[[IS_VIRTUAL]], label %[[HANDLE_VIRTUAL:.*]], label %[[HANDLE_NON_VIRTUAL:.*]] +// OGCG: [[HANDLE_VIRTUAL]]: +// OGCG: %[[VTABLE:.*]] = load ptr, ptr %[[ADJUSTED_THIS]] +// OGCG: %[[OFFSET:.*]] = sub i64 %[[PTR_FIELD]], 1 +// OGCG: %[[VTABLE_SLOT:.*]] = getelementptr i8, ptr %[[VTABLE]], i64 %[[OFFSET]] +// OGCG: %[[VIRTUAL_FN_PTR:.*]] = load ptr, ptr %[[VTABLE_SLOT]] +// OGCG: br label %[[CONTINUE:.*]] +// OGCG: [[HANDLE_NON_VIRTUAL]]: +// OGCG: %[[FUNC_PTR:.*]] = inttoptr i64 %[[PTR_FIELD]] to ptr +// OGCG: br label %[[CONTINUE]] +// OGCG: [[CONTINUE]]: +// OGCG: %[[CALLEE_PTR:.*]] = phi ptr [ %[[VIRTUAL_FN_PTR]], %[[HANDLE_VIRTUAL]] ], [ %[[FUNC_PTR]], %[[HANDLE_NON_VIRTUAL]] ] +// OGCG: %[[ARG:.*]] = load i32, ptr %{{.+}} +// OGCG: call void %[[CALLEE_PTR]](ptr {{.*}} %[[ADJUSTED_THIS]], i32 {{.*}} %[[ARG]]) +// OGCG: } _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
