https://github.com/andykaylor created https://github.com/llvm/llvm-project/pull/153893
This change adds support for calling virtual functions. This includes adding the cir.vtable.get_virtual_fn_addr operation to lookup the address of the function being called from an object's vtable. >From 470cc01f64d2740dfb865e2d4a1171b87994c8d0 Mon Sep 17 00:00:00 2001 From: Andy Kaylor <akay...@nvidia.com> Date: Thu, 7 Aug 2025 16:05:10 -0700 Subject: [PATCH] [CIR] Add support for calling virtual functions This change adds support for calling virtual functions. This includes adding the cir.vtable.get_virtual_fn_addr operation to lookup the address of the function being called from an object's vtable. --- .../CIR/Dialect/Builder/CIRBaseBuilder.h | 13 +++++ clang/include/clang/CIR/Dialect/IR/CIROps.td | 48 ++++++++++++++++++ clang/include/clang/CIR/MissingFeatures.h | 3 +- clang/lib/CIR/CodeGen/CIRGenCXXABI.h | 17 +++++++ clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp | 49 +++++++++++++------ clang/lib/CIR/CodeGen/CIRGenCall.cpp | 7 ++- clang/lib/CIR/CodeGen/CIRGenCall.h | 46 ++++++++++++++++- clang/lib/CIR/CodeGen/CIRGenClass.cpp | 14 ++++++ clang/lib/CIR/CodeGen/CIRGenFunction.h | 5 ++ clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp | 48 ++++++++++++++++++ clang/lib/CIR/CodeGen/CIRGenValue.h | 8 +++ .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 22 +++++++-- .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h | 11 +++++ .../CIR/CodeGen/virtual-function-calls.cpp | 33 +++++++++++++ 14 files changed, 303 insertions(+), 21 deletions(-) diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h index 0bf3cb26be850..6181b64fe6d0e 100644 --- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h +++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h @@ -157,6 +157,19 @@ class CIRBaseBuilderTy : public mlir::OpBuilder { return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), operand); } + cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr, + uint64_t alignment = 0) { + mlir::IntegerAttr 
alignmentAttr = getAlignmentAttr(alignment); + assert(!cir::MissingFeatures::opLoadStoreVolatile()); + assert(!cir::MissingFeatures::opLoadStoreMemOrder()); + return create<cir::LoadOp>(loc, ptr, /*isDeref=*/false, alignmentAttr); + } + + mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr, + uint64_t alignment) { + return createLoad(loc, ptr, alignment); + } + mlir::Value createNot(mlir::Value value) { return create<cir::UnaryOp>(value.getLoc(), value.getType(), cir::UnaryOpKind::Not, value); diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index a181c95494eff..9f98a654b796c 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -1782,6 +1782,54 @@ def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", [Pure]> { }]; } +//===----------------------------------------------------------------------===// +// VTableGetVirtualFnAddrOp +//===----------------------------------------------------------------------===// + +def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [ + Pure +]> { + let summary = "Get the address of a virtual function pointer"; + let description = [{ + The `vtable.get_virtual_fn_addr` operation retrieves the address of a + virtual function pointer from an object's vtable (__vptr). + This is an abstraction to perform the basic pointer arithmetic to get + the address of the virtual function pointer, which can then be loaded and + called. + + The `vptr` operand must be a `!cir.vptr` value, loaded from the address + returned by a previous call to `cir.vtable.get_vptr`. The `index` + operand is the index of the virtual function in the vtable. + + The return type is a pointer-to-pointer to the function type.
+ + Example: + ```mlir + %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C> + %3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr> + %4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr + %5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr + -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>> + %6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) + -> !s32i>>>, + !cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>> + %7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>, + !cir.ptr<!rec_C>) -> !s32i + ``` + }]; + + let arguments = (ins + Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr, + I64Attr:$index); + + let results = (outs CIR_PointerType:$result); + + let assemblyFormat = [{ + $vptr `[` $index `]` attr-dict + `:` qualified(type($vptr)) `->` qualified(type($result)) + }]; +} + //===----------------------------------------------------------------------===// // SetBitfieldOp //===----------------------------------------------------------------------===// diff --git a/clang/include/clang/CIR/MissingFeatures.h b/clang/include/clang/CIR/MissingFeatures.h index baab62f572b98..c5d7006bfabf7 100644 --- a/clang/include/clang/CIR/MissingFeatures.h +++ b/clang/include/clang/CIR/MissingFeatures.h @@ -95,7 +95,6 @@ struct MissingFeatures { static bool opCallArgEvaluationOrder() { return false; } static bool opCallCallConv() { return false; } static bool opCallMustTail() { return false; } - static bool opCallVirtual() { return false; } static bool opCallInAlloca() { return false; } static bool opCallAttrs() { return false; } static bool opCallSurroundingTry() { return false; } @@ -204,6 +203,7 @@ struct MissingFeatures { static bool dataLayoutTypeAllocSize() { return false; } static bool dataLayoutTypeStoreSize() { return false; } static bool deferredCXXGlobalInit() { return false; } + static bool devirtualizeMemberFunction() { return false; } static bool ehCleanupFlags() { return false; } static bool 
ehCleanupScope() { return false; } static bool ehCleanupScopeRequiresEHCleanup() { return false; } @@ -215,6 +215,7 @@ struct MissingFeatures { static bool emitLValueAlignmentAssumption() { return false; } static bool emitNullabilityCheck() { return false; } static bool emitTypeCheck() { return false; } + static bool emitTypeMetadataCodeForVCall() { return false; } static bool fastMathFlags() { return false; } static bool fpConstraints() { return false; } static bool generateDebugInfo() { return false; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h index abde1a7687a90..3f1cb8363a556 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h +++ b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h @@ -63,6 +63,16 @@ class CIRGenCXXABI { /// parameter. virtual bool needsVTTParameter(clang::GlobalDecl gd) { return false; } + /// Perform ABI-specific "this" argument adjustment required prior to + /// a call of a virtual function. + /// The "VirtualCall" argument is true iff the call itself is virtual. + virtual Address adjustThisArgumentForVirtualFunctionCall(CIRGenFunction &cgf, + clang::GlobalDecl gd, + Address thisPtr, + bool virtualCall) { + return thisPtr; + } + /// Build a parameter variable suitable for 'this'. void buildThisParam(CIRGenFunction &cgf, FunctionArgList ¶ms); @@ -100,6 +110,13 @@ class CIRGenCXXABI { virtual cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd, CharUnits vptrOffset) = 0; + /// Build a virtual function pointer in the ABI-specific way. + virtual CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf, + clang::GlobalDecl gd, + Address thisAddr, + mlir::Type ty, + SourceLocation loc) = 0; + /// Get the address point of the vtable for the given base subobject. 
virtual mlir::Value getVTableAddressPoint(BaseSubobject base, diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp b/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp index bc30c7bd130af..c9e4ed92d16bb 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp @@ -79,11 +79,10 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr( const Expr *base) { assert(isa<CXXMemberCallExpr>(ce) || isa<CXXOperatorCallExpr>(ce)); - if (md->isVirtual()) { - cgm.errorNYI(ce->getSourceRange(), - "emitCXXMemberOrOperatorMemberCallExpr: virtual call"); - return RValue::get(nullptr); - } + // Compute the object pointer. + bool canUseVirtualCall = md->isVirtual() && !hasQualifier; + const CXXMethodDecl *devirtualizedMethod = nullptr; + assert(!cir::MissingFeatures::devirtualizeMemberFunction()); // Note on trivial assignment // -------------------------- @@ -127,7 +126,8 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr( return RValue::get(nullptr); // Compute the function type we're calling - const CXXMethodDecl *calleeDecl = md; + const CXXMethodDecl *calleeDecl = + devirtualizedMethod ? devirtualizedMethod : md; const CIRGenFunctionInfo *fInfo = nullptr; if (isa<CXXDestructorDecl>(calleeDecl)) { cgm.errorNYI(ce->getSourceRange(), @@ -137,25 +137,46 @@ RValue CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr( fInfo = &cgm.getTypes().arrangeCXXMethodDeclaration(calleeDecl); - mlir::Type ty = cgm.getTypes().getFunctionType(*fInfo); + cir::FuncType ty = cgm.getTypes().getFunctionType(*fInfo); assert(!cir::MissingFeatures::sanitizers()); assert(!cir::MissingFeatures::emitTypeCheck()); + // C++ [class.virtual]p12: + // Explicit qualification with the scope operator (5.1) suppresses the + // virtual call mechanism. + // + // We also don't emit a virtual call if the base expression has a record type + // because then we know what the type is. 
+ bool useVirtualCall = canUseVirtualCall && !devirtualizedMethod; + if (isa<CXXDestructorDecl>(calleeDecl)) { cgm.errorNYI(ce->getSourceRange(), "emitCXXMemberOrOperatorMemberCallExpr: destructor call"); return RValue::get(nullptr); } - assert(!cir::MissingFeatures::sanitizers()); - if (getLangOpts().AppleKext) { - cgm.errorNYI(ce->getSourceRange(), - "emitCXXMemberOrOperatorMemberCallExpr: AppleKext"); - return RValue::get(nullptr); + CIRGenCallee callee; + if (useVirtualCall) { + callee = CIRGenCallee::forVirtual(ce, md, thisPtr.getAddress(), ty); + } else { + assert(!cir::MissingFeatures::sanitizers()); + if (getLangOpts().AppleKext) { + cgm.errorNYI(ce->getSourceRange(), + "emitCXXMemberOrOperatorMemberCallExpr: AppleKext"); + return RValue::get(nullptr); + } + + callee = CIRGenCallee::forDirect(cgm.getAddrOfFunction(calleeDecl, ty), + GlobalDecl(calleeDecl)); + } + + if (md->isVirtual()) { + Address newThisAddr = + cgm.getCXXABI().adjustThisArgumentForVirtualFunctionCall( + *this, calleeDecl, thisPtr.getAddress(), useVirtualCall); + thisPtr.setAddress(newThisAddr); } - CIRGenCallee callee = - CIRGenCallee::forDirect(cgm.getAddrOfFunction(md, ty), GlobalDecl(md)); return emitCXXMemberOrOperatorCall( calleeDecl, callee, returnValue, thisPtr.getPointer(), diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp b/clang/lib/CIR/CodeGen/CIRGenCall.cpp index e3fe3ca1c30c9..6d749940fa128 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp @@ -56,7 +56,12 @@ cir::FuncType CIRGenTypes::getFunctionType(const CIRGenFunctionInfo &info) { } CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const { - assert(!cir::MissingFeatures::opCallVirtual()); + if (isVirtual()) { + const CallExpr *ce = getVirtualCallExpr(); + return cgf.cgm.getCXXABI().getVirtualFunctionPointer( + cgf, getVirtualMethodDecl(), getThisAddress(), getVirtualFunctionType(), + ce ? 
ce->getBeginLoc() : SourceLocation()); + } return *this; } diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h b/clang/lib/CIR/CodeGen/CIRGenCall.h index 47d998ae25838..81cbb854f3b7d 100644 --- a/clang/lib/CIR/CodeGen/CIRGenCall.h +++ b/clang/lib/CIR/CodeGen/CIRGenCall.h @@ -47,8 +47,9 @@ class CIRGenCallee { Invalid, Builtin, PseudoDestructor, + Virtual, - Last = Builtin, + Last = Virtual }; struct BuiltinInfoStorage { @@ -58,6 +59,12 @@ class CIRGenCallee { struct PseudoDestructorInfoStorage { const clang::CXXPseudoDestructorExpr *expr; }; + struct VirtualInfoStorage { + const clang::CallExpr *ce; + clang::GlobalDecl md; + Address addr; + cir::FuncType fTy; + }; SpecialKind kindOrFunctionPtr; @@ -65,6 +72,7 @@ class CIRGenCallee { CIRGenCalleeInfo abstractInfo; BuiltinInfoStorage builtinInfo; PseudoDestructorInfoStorage pseudoDestructorInfo; + VirtualInfoStorage virtualInfo; }; explicit CIRGenCallee(SpecialKind kind) : kindOrFunctionPtr(kind) {} @@ -128,7 +136,8 @@ class CIRGenCallee { CIRGenCallee prepareConcreteCallee(CIRGenFunction &cgf) const; CIRGenCalleeInfo getAbstractInfo() const { - assert(!cir::MissingFeatures::opCallVirtual()); + if (isVirtual()) + return virtualInfo.md; assert(isOrdinary()); return abstractInfo; } @@ -138,6 +147,39 @@ class CIRGenCallee { return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr); } + bool isVirtual() const { return kindOrFunctionPtr == SpecialKind::Virtual; } + + static CIRGenCallee forVirtual(const clang::CallExpr *ce, + clang::GlobalDecl md, Address addr, + cir::FuncType fTy) { + CIRGenCallee result(SpecialKind::Virtual); + result.virtualInfo.ce = ce; + result.virtualInfo.md = md; + result.virtualInfo.addr = addr; + result.virtualInfo.fTy = fTy; + return result; + } + + const clang::CallExpr *getVirtualCallExpr() const { + assert(isVirtual()); + return virtualInfo.ce; + } + + clang::GlobalDecl getVirtualMethodDecl() const { + assert(isVirtual()); + return virtualInfo.md; + } + + Address getThisAddress() const { + 
assert(isVirtual()); + return virtualInfo.addr; + } + + cir::FuncType getVirtualFunctionType() const { + assert(isVirtual()); + return virtualInfo.fTy; + } + void setFunctionPointer(mlir::Operation *functionPtr) { assert(isOrdinary()); kindOrFunctionPtr = SpecialKind(reinterpret_cast<uintptr_t>(functionPtr)); diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp b/clang/lib/CIR/CodeGen/CIRGenClass.cpp index a3947047de079..3e5dc22426d8e 100644 --- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp @@ -657,6 +657,20 @@ Address CIRGenFunction::getAddressOfBaseClass( return value; } +// TODO(cir): this can be shared with LLVM codegen. +bool CIRGenFunction::shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd) { + assert(!cir::MissingFeatures::hiddenVisibility()); + if (!cgm.getCodeGenOpts().WholeProgramVTables) + return false; + + if (cgm.getCodeGenOpts().VirtualFunctionElimination) + return true; + + assert(!cir::MissingFeatures::sanitizers()); + + return false; +} + mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr, const CXXRecordDecl *rd) { auto vtablePtr = cir::VTableGetVPtrOp::create( diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h b/clang/lib/CIR/CodeGen/CIRGenFunction.h index 9a887ec047f86..7ad3ad3559de8 100644 --- a/clang/lib/CIR/CodeGen/CIRGenFunction.h +++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h @@ -552,6 +552,11 @@ class CIRGenFunction : public CIRGenTypeCache { mlir::Value getVTablePtr(mlir::Location loc, Address thisAddr, const clang::CXXRecordDecl *vtableClass); + /// Returns whether we should perform a type checked load when loading a + /// virtual function for virtual calls to members of RD. This is generally + /// true when both vcall CFI and whole-program-vtables are enabled. + bool shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd); + /// A scope within which we are constructing the fields of an object which /// might use a CXXDefaultInitExpr. 
This stashes away a 'this' value to use if /// we need to evaluate the CXXDefaultInitExpr within the evaluation. diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp index 43dfd6468046d..347656b5f6488 100644 --- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp @@ -69,6 +69,10 @@ class CIRGenItaniumCXXABI : public CIRGenCXXABI { cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd, CharUnits vptrOffset) override; + CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf, + clang::GlobalDecl gd, Address thisAddr, + mlir::Type ty, + SourceLocation loc) override; mlir::Value getVTableAddressPoint(BaseSubobject base, const CXXRecordDecl *vtableClass) override; @@ -349,6 +353,50 @@ cir::GlobalOp CIRGenItaniumCXXABI::getAddrOfVTable(const CXXRecordDecl *rd, return vtable; } +CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer( + CIRGenFunction &cgf, clang::GlobalDecl gd, Address thisAddr, mlir::Type ty, + SourceLocation srcLoc) { + CIRGenBuilderTy &builder = cgm.getBuilder(); + mlir::Location loc = cgf.getLoc(srcLoc); + cir::PointerType tyPtr = builder.getPointerTo(ty); + auto *methodDecl = cast<CXXMethodDecl>(gd.getDecl()); + mlir::Value vtable = cgf.getVTablePtr(loc, thisAddr, methodDecl->getParent()); + + uint64_t vtableIndex = cgm.getItaniumVTableContext().getMethodVTableIndex(gd); + mlir::Value vfunc{}; + if (cgf.shouldEmitVTableTypeCheckedLoad(methodDecl->getParent())) { + cgm.errorNYI(loc, "getVirtualFunctionPointer: emitVTableTypeCheckedLoad"); + } else { + assert(!cir::MissingFeatures::emitTypeMetadataCodeForVCall()); + + mlir::Value vfuncLoad; + if (cgm.getItaniumVTableContext().isRelativeLayout()) { + assert(!cir::MissingFeatures::vtableRelativeLayout()); + cgm.errorNYI(loc, "getVirtualFunctionPointer: isRelativeLayout"); + } else { + auto vtableSlotPtr = cir::VTableGetVirtualFnAddrOp::create( + builder, loc, builder.getPointerTo(tyPtr), vtable, 
vtableIndex); + vfuncLoad = builder.createAlignedLoad( + loc, vtableSlotPtr, cgf.getPointerAlign().getQuantity()); + } + + // Add !invariant.load md to virtual function load to indicate that + // function didn't change inside vtable. + // It's safe to add it without -fstrict-vtable-pointers, but it would not + // help in devirtualization because it will only matter if we will have 2 + // the same virtual function loads from the same vtable load, which won't + // happen without enabled devirtualization with -fstrict-vtable-pointers. + if (cgm.getCodeGenOpts().OptimizationLevel > 0 && + cgm.getCodeGenOpts().StrictVTablePointers) { + cgm.errorNYI(loc, "getVirtualFunctionPointer: strictVTablePointers"); + } + vfunc = vfuncLoad; + } + + CIRGenCallee callee(gd, vfunc.getDefiningOp()); + return callee; +} + mlir::Value CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject base, const CXXRecordDecl *vtableClass) { diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h b/clang/lib/CIR/CodeGen/CIRGenValue.h index 661cecf8416b6..ac7e1cc1a1db6 100644 --- a/clang/lib/CIR/CodeGen/CIRGenValue.h +++ b/clang/lib/CIR/CodeGen/CIRGenValue.h @@ -212,6 +212,14 @@ class LValue { return Address(getPointer(), elementType, getAlignment()); } + void setAddress(Address address) { + assert(isSimple()); + v = address.getPointer(); + elementType = address.getElementType(); + alignment = address.getAlignment().getQuantity(); + assert(!cir::MissingFeatures::addressIsKnownNonNull()); + } + const clang::Qualifiers &getQuals() const { return quals; } clang::Qualifiers &getQuals() { return quals; } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 9f7521db78bec..1111e5f1b5b2f 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -1085,8 +1085,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange callOperands, auto calleeTy = 
op->getOperands().front().getType(); auto calleePtrTy = cast<cir::PointerType>(calleeTy); auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee()); - calleeFuncTy.dump(); - converter->convertType(calleeFuncTy).dump(); + llvm::append_range(adjustedCallOperands, callOperands); llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>( converter->convertType(calleeFuncTy)); } @@ -2193,6 +2192,9 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter &converter, return llvmStruct; }); + converter.addConversion([&](cir::VoidType type) -> mlir::Type { + return mlir::LLVM::LLVMVoidType::get(type.getContext()); + }); } // The applyPartialConversion function traverses blocks in the dominance order, @@ -2345,7 +2347,8 @@ void ConvertCIRToLLVMPass::runOnOperation() { CIRToLLVMVecSplatOpLowering, CIRToLLVMVecTernaryOpLowering, CIRToLLVMVTableAddrPointOpLowering, - CIRToLLVMVTableGetVPtrOpLowering + CIRToLLVMVTableGetVPtrOpLowering, + CIRToLLVMVTableGetVirtualFnAddrOpLowering // clang-format on >(converter, patterns.getContext()); @@ -2481,6 +2484,19 @@ mlir::LogicalResult CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite( return mlir::success(); } +mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite( + cir::VTableGetVirtualFnAddrOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const { + mlir::Type targetType = getTypeConverter()->convertType(op.getType()); + auto eltType = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); + llvm::SmallVector<mlir::LLVM::GEPArg> offsets = + llvm::SmallVector<mlir::LLVM::GEPArg>{op.getIndex()}; + rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>( + op, targetType, eltType, adaptor.getVptr(), offsets, + mlir::LLVM::GEPNoWrapFlags::inbounds); + return mlir::success(); +} + mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite( cir::StackSaveOp op, OpAdaptor adaptor, mlir::ConversionPatternRewriter &rewriter) const { diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h index 91e8505233379..be1d1c44bf9db 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h @@ -477,6 +477,17 @@ class CIRToLLVMVTableGetVPtrOpLowering mlir::ConversionPatternRewriter &) const override; }; +class CIRToLLVMVTableGetVirtualFnAddrOpLowering + : public mlir::OpConversionPattern<cir::VTableGetVirtualFnAddrOp> { +public: + using mlir::OpConversionPattern< + cir::VTableGetVirtualFnAddrOp>::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(cir::VTableGetVirtualFnAddrOp op, OpAdaptor, + mlir::ConversionPatternRewriter &) const override; +}; + class CIRToLLVMStackSaveOpLowering : public mlir::OpConversionPattern<cir::StackSaveOp> { public: diff --git a/clang/test/CIR/CodeGen/virtual-function-calls.cpp b/clang/test/CIR/CodeGen/virtual-function-calls.cpp index 4787d78aa0e35..e68b38fad3f84 100644 --- a/clang/test/CIR/CodeGen/virtual-function-calls.cpp +++ b/clang/test/CIR/CodeGen/virtual-function-calls.cpp @@ -46,3 +46,36 @@ A::A() {} // NOTE: The GEP in OGCG looks very different from the one generated with CIR, // but it is equivalent. The OGCG GEP indexes by base pointer, then // structure, then array, whereas the CIR GEP indexes by byte offset. 
+ +void f1(A *a) { + a->f('c'); +} + +// CIR: cir.func{{.*}} @_Z2f1P1A(%arg0: !cir.ptr<!rec_A> {{.*}}) +// CIR: %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A> +// CIR: cir.store %arg0, %[[A_ADDR]] +// CIR: %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]] +// CIR: %[[C_LITERAL:.*]] = cir.const #cir.int<99> : !s8i +// CIR: %[[VPTR_ADDR:.*]] = cir.vtable.get_vptr %[[A]] : !cir.ptr<!rec_A> -> !cir.ptr<!cir.vptr> +// CIR: %[[VPTR:.*]] = cir.load{{.*}} %[[VPTR_ADDR]] : !cir.ptr<!cir.vptr>, !cir.vptr +// CIR: %[[FN_PTR_PTR:.*]] = cir.vtable.get_virtual_fn_addr %[[VPTR]][0] : !cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>> +// CIR: %[[FN_PTR:.*]] = cir.load{{.*}} %[[FN_PTR_PTR]] : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>>, !cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>> +// CIR: cir.call %[[FN_PTR]](%[[A]], %[[C_LITERAL]]) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>, !cir.ptr<!rec_A>, !s8i) -> () + +// LLVM: define{{.*}} void @_Z2f1P1A(ptr %[[ARG0:.*]]) +// LLVM: %[[A_ADDR:.*]] = alloca ptr +// LLVM: store ptr %[[ARG0]], ptr %[[A_ADDR]] +// LLVM: %[[A:.*]] = load ptr, ptr %[[A_ADDR]] +// LLVM: %[[VPTR:.*]] = load ptr, ptr %[[A]] +// LLVM: %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i32 0 +// LLVM: %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]] +// LLVM: call void %[[FN_PTR]](ptr %[[A]], i8 99) + +// OGCG: define{{.*}} void @_Z2f1P1A(ptr {{.*}} %[[ARG0:.*]]) +// OGCG: %[[A_ADDR:.*]] = alloca ptr +// OGCG: store ptr %[[ARG0]], ptr %[[A_ADDR]] +// OGCG: %[[A:.*]] = load ptr, ptr %[[A_ADDR]] +// OGCG: %[[VPTR:.*]] = load ptr, ptr %[[A]] +// OGCG: %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i64 0 +// OGCG: %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]] +// OGCG: call void %[[FN_PTR]](ptr {{.*}} %[[A]], i8 {{.*}} 99) _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits