https://github.com/andykaylor created 
https://github.com/llvm/llvm-project/pull/153893

This change adds support for calling virtual functions. This includes adding
the cir.vtable.get_virtual_fn_addr operation to look up the address of the
function being called from an object's vtable.

>From 470cc01f64d2740dfb865e2d4a1171b87994c8d0 Mon Sep 17 00:00:00 2001
From: Andy Kaylor <akay...@nvidia.com>
Date: Thu, 7 Aug 2025 16:05:10 -0700
Subject: [PATCH] [CIR] Add support for calling virtual functions

This change adds support for calling virtual functions. This includes
adding the cir.vtable.get_virtual_fn_addr operation to look up the
address of the function being called from an object's vtable.
---
 .../CIR/Dialect/Builder/CIRBaseBuilder.h      | 13 +++++
 clang/include/clang/CIR/Dialect/IR/CIROps.td  | 48 ++++++++++++++++++
 clang/include/clang/CIR/MissingFeatures.h     |  3 +-
 clang/lib/CIR/CodeGen/CIRGenCXXABI.h          | 17 +++++++
 clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp       | 49 +++++++++++++------
 clang/lib/CIR/CodeGen/CIRGenCall.cpp          |  7 ++-
 clang/lib/CIR/CodeGen/CIRGenCall.h            | 46 ++++++++++++++++-
 clang/lib/CIR/CodeGen/CIRGenClass.cpp         | 14 ++++++
 clang/lib/CIR/CodeGen/CIRGenFunction.h        |  5 ++
 clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp | 48 ++++++++++++++++++
 clang/lib/CIR/CodeGen/CIRGenValue.h           |  8 +++
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 22 +++++++--
 .../CIR/Lowering/DirectToLLVM/LowerToLLVM.h   | 11 +++++
 .../CIR/CodeGen/virtual-function-calls.cpp    | 33 +++++++++++++
 14 files changed, 303 insertions(+), 21 deletions(-)

diff --git a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h 
b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
index 0bf3cb26be850..6181b64fe6d0e 100644
--- a/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
+++ b/clang/include/clang/CIR/Dialect/Builder/CIRBaseBuilder.h
@@ -157,6 +157,19 @@ class CIRBaseBuilderTy : public mlir::OpBuilder {
     return create<cir::ComplexImagOp>(loc, operandTy.getElementType(), 
operand);
   }
 
+  cir::LoadOp createLoad(mlir::Location loc, mlir::Value ptr,
+                         uint64_t alignment = 0) {
+    mlir::IntegerAttr alignmentAttr = getAlignmentAttr(alignment);
+    assert(!cir::MissingFeatures::opLoadStoreVolatile());
+    assert(!cir::MissingFeatures::opLoadStoreMemOrder());
+    return create<cir::LoadOp>(loc, ptr, /*isDeref=*/false, alignmentAttr);
+  }
+
+  mlir::Value createAlignedLoad(mlir::Location loc, mlir::Value ptr,
+                                uint64_t alignment) {
+    return createLoad(loc, ptr, alignment);
+  }
+
   mlir::Value createNot(mlir::Value value) {
     return create<cir::UnaryOp>(value.getLoc(), value.getType(),
                                 cir::UnaryOpKind::Not, value);
diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td 
b/clang/include/clang/CIR/Dialect/IR/CIROps.td
index a181c95494eff..9f98a654b796c 100644
--- a/clang/include/clang/CIR/Dialect/IR/CIROps.td
+++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td
@@ -1782,6 +1782,54 @@ def CIR_VTableGetVPtrOp : CIR_Op<"vtable.get_vptr", 
[Pure]> {
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// VTableGetVirtualFnAddrOp
+//===----------------------------------------------------------------------===//
+
+def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
+  Pure
+]> {
+  let summary = "Get the address of a virtual function pointer";
+  let description = [{
+    The `vtable.get_virtual_fn_addr` operation retrieves the address of a
+    virtual function pointer in the vtable referenced by an object's vptr.
+    This is an abstraction to perform the basic pointer arithmetic to get
+    the address of the virtual function pointer, which can then be loaded and
+    called.
+
+    The `vptr` operand must be a `!cir.vptr` value, typically obtained by
+    loading the address returned by `cir.vtable.get_vptr`. The `index`
+    attribute is the index of the virtual function in the vtable.
+
+    The return type is a pointer-to-pointer to the function type.
+
+    Example:
+    ```mlir
+    %2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
+    %3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
+    %4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
+    %5 = cir.vtable.get_virtual_fn_addr %4[2] : !cir.vptr
+                  -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
+    %6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
+                                                                 -> !s32i>>>,
+                                !cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> 
!s32i>>
+    %7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
+                            !cir.ptr<!rec_C>) -> !s32i
+    ```
+  }];
+
+  let arguments = (ins
+    Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
+    I64Attr:$index);
+
+  let results = (outs CIR_PointerType:$result);
+
+  let assemblyFormat = [{
+    $vptr `[` $index `]` attr-dict
+    `:` qualified(type($vptr)) `->` qualified(type($result))
+  }];
+}
+
 
//===----------------------------------------------------------------------===//
 // SetBitfieldOp
 
//===----------------------------------------------------------------------===//
diff --git a/clang/include/clang/CIR/MissingFeatures.h 
b/clang/include/clang/CIR/MissingFeatures.h
index baab62f572b98..c5d7006bfabf7 100644
--- a/clang/include/clang/CIR/MissingFeatures.h
+++ b/clang/include/clang/CIR/MissingFeatures.h
@@ -95,7 +95,6 @@ struct MissingFeatures {
   static bool opCallArgEvaluationOrder() { return false; }
   static bool opCallCallConv() { return false; }
   static bool opCallMustTail() { return false; }
-  static bool opCallVirtual() { return false; }
   static bool opCallInAlloca() { return false; }
   static bool opCallAttrs() { return false; }
   static bool opCallSurroundingTry() { return false; }
@@ -204,6 +203,7 @@ struct MissingFeatures {
   static bool dataLayoutTypeAllocSize() { return false; }
   static bool dataLayoutTypeStoreSize() { return false; }
   static bool deferredCXXGlobalInit() { return false; }
+  static bool devirtualizeMemberFunction() { return false; }
   static bool ehCleanupFlags() { return false; }
   static bool ehCleanupScope() { return false; }
   static bool ehCleanupScopeRequiresEHCleanup() { return false; }
@@ -215,6 +215,7 @@ struct MissingFeatures {
   static bool emitLValueAlignmentAssumption() { return false; }
   static bool emitNullabilityCheck() { return false; }
   static bool emitTypeCheck() { return false; }
+  static bool emitTypeMetadataCodeForVCall() { return false; }
   static bool fastMathFlags() { return false; }
   static bool fpConstraints() { return false; }
   static bool generateDebugInfo() { return false; }
diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h 
b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h
index abde1a7687a90..3f1cb8363a556 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCXXABI.h
+++ b/clang/lib/CIR/CodeGen/CIRGenCXXABI.h
@@ -63,6 +63,16 @@ class CIRGenCXXABI {
   /// parameter.
   virtual bool needsVTTParameter(clang::GlobalDecl gd) { return false; }
 
+  /// Perform ABI-specific "this" argument adjustment required prior to
+  /// a call of a virtual function. The default returns `thisPtr` unchanged.
+  /// The `virtualCall` argument is true iff the call itself is virtual.
+  virtual Address adjustThisArgumentForVirtualFunctionCall(CIRGenFunction &cgf,
+                                                           clang::GlobalDecl 
gd,
+                                                           Address thisPtr,
+                                                           bool virtualCall) {
+    return thisPtr;
+  }
+
   /// Build a parameter variable suitable for 'this'.
   void buildThisParam(CIRGenFunction &cgf, FunctionArgList &params);
 
@@ -100,6 +110,13 @@ class CIRGenCXXABI {
   virtual cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd,
                                         CharUnits vptrOffset) = 0;
 
+  /// Build a virtual function pointer in the ABI-specific way.
+  virtual CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf,
+                                                 clang::GlobalDecl gd,
+                                                 Address thisAddr,
+                                                 mlir::Type ty,
+                                                 SourceLocation loc) = 0;
+
   /// Get the address point of the vtable for the given base subobject.
   virtual mlir::Value
   getVTableAddressPoint(BaseSubobject base,
diff --git a/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp 
b/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp
index bc30c7bd130af..c9e4ed92d16bb 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCXXExpr.cpp
@@ -79,11 +79,10 @@ RValue 
CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
     const Expr *base) {
   assert(isa<CXXMemberCallExpr>(ce) || isa<CXXOperatorCallExpr>(ce));
 
-  if (md->isVirtual()) {
-    cgm.errorNYI(ce->getSourceRange(),
-                 "emitCXXMemberOrOperatorMemberCallExpr: virtual call");
-    return RValue::get(nullptr);
-  }
+  // Compute the object pointer.
+  bool canUseVirtualCall = md->isVirtual() && !hasQualifier;
+  const CXXMethodDecl *devirtualizedMethod = nullptr;
+  assert(!cir::MissingFeatures::devirtualizeMemberFunction());
 
   // Note on trivial assignment
   // --------------------------
@@ -127,7 +126,8 @@ RValue 
CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
     return RValue::get(nullptr);
 
   // Compute the function type we're calling
-  const CXXMethodDecl *calleeDecl = md;
+  const CXXMethodDecl *calleeDecl =
+      devirtualizedMethod ? devirtualizedMethod : md;
   const CIRGenFunctionInfo *fInfo = nullptr;
   if (isa<CXXDestructorDecl>(calleeDecl)) {
     cgm.errorNYI(ce->getSourceRange(),
@@ -137,25 +137,46 @@ RValue 
CIRGenFunction::emitCXXMemberOrOperatorMemberCallExpr(
 
   fInfo = &cgm.getTypes().arrangeCXXMethodDeclaration(calleeDecl);
 
-  mlir::Type ty = cgm.getTypes().getFunctionType(*fInfo);
+  cir::FuncType ty = cgm.getTypes().getFunctionType(*fInfo);
 
   assert(!cir::MissingFeatures::sanitizers());
   assert(!cir::MissingFeatures::emitTypeCheck());
 
+  // C++ [class.virtual]p12:
+  //   Explicit qualification with the scope operator (5.1) suppresses the
+  //   virtual call mechanism.
+  //
+  // We also don't emit a virtual call if the base expression has a record type
+  // because then we know what the type is.
+  bool useVirtualCall = canUseVirtualCall && !devirtualizedMethod;
+
   if (isa<CXXDestructorDecl>(calleeDecl)) {
     cgm.errorNYI(ce->getSourceRange(),
                  "emitCXXMemberOrOperatorMemberCallExpr: destructor call");
     return RValue::get(nullptr);
   }
 
-  assert(!cir::MissingFeatures::sanitizers());
-  if (getLangOpts().AppleKext) {
-    cgm.errorNYI(ce->getSourceRange(),
-                 "emitCXXMemberOrOperatorMemberCallExpr: AppleKext");
-    return RValue::get(nullptr);
+  CIRGenCallee callee;
+  if (useVirtualCall) {
+    callee = CIRGenCallee::forVirtual(ce, md, thisPtr.getAddress(), ty);
+  } else {
+    assert(!cir::MissingFeatures::sanitizers());
+    if (getLangOpts().AppleKext) {
+      cgm.errorNYI(ce->getSourceRange(),
+                   "emitCXXMemberOrOperatorMemberCallExpr: AppleKext");
+      return RValue::get(nullptr);
+    }
+
+    callee = CIRGenCallee::forDirect(cgm.getAddrOfFunction(calleeDecl, ty),
+                                     GlobalDecl(calleeDecl));
+  }
+
+  if (md->isVirtual()) {
+    Address newThisAddr =
+        cgm.getCXXABI().adjustThisArgumentForVirtualFunctionCall(
+            *this, calleeDecl, thisPtr.getAddress(), useVirtualCall);
+    thisPtr.setAddress(newThisAddr);
   }
-  CIRGenCallee callee =
-      CIRGenCallee::forDirect(cgm.getAddrOfFunction(md, ty), GlobalDecl(md));
 
   return emitCXXMemberOrOperatorCall(
       calleeDecl, callee, returnValue, thisPtr.getPointer(),
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.cpp 
b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
index e3fe3ca1c30c9..6d749940fa128 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.cpp
@@ -56,7 +56,12 @@ cir::FuncType CIRGenTypes::getFunctionType(const 
CIRGenFunctionInfo &info) {
 }
 
 CIRGenCallee CIRGenCallee::prepareConcreteCallee(CIRGenFunction &cgf) const {
-  assert(!cir::MissingFeatures::opCallVirtual());
+  if (isVirtual()) {
+    const CallExpr *ce = getVirtualCallExpr();
+    return cgf.cgm.getCXXABI().getVirtualFunctionPointer(
+        cgf, getVirtualMethodDecl(), getThisAddress(), 
getVirtualFunctionType(),
+        ce ? ce->getBeginLoc() : SourceLocation());
+  }
   return *this;
 }
 
diff --git a/clang/lib/CIR/CodeGen/CIRGenCall.h 
b/clang/lib/CIR/CodeGen/CIRGenCall.h
index 47d998ae25838..81cbb854f3b7d 100644
--- a/clang/lib/CIR/CodeGen/CIRGenCall.h
+++ b/clang/lib/CIR/CodeGen/CIRGenCall.h
@@ -47,8 +47,9 @@ class CIRGenCallee {
     Invalid,
     Builtin,
     PseudoDestructor,
+    Virtual,
 
-    Last = Builtin,
+    Last = Virtual
   };
 
   struct BuiltinInfoStorage {
@@ -58,6 +59,12 @@ class CIRGenCallee {
   struct PseudoDestructorInfoStorage {
     const clang::CXXPseudoDestructorExpr *expr;
   };
+  struct VirtualInfoStorage {
+    const clang::CallExpr *ce;
+    clang::GlobalDecl md;
+    Address addr;
+    cir::FuncType fTy;
+  };
 
   SpecialKind kindOrFunctionPtr;
 
@@ -65,6 +72,7 @@ class CIRGenCallee {
     CIRGenCalleeInfo abstractInfo;
     BuiltinInfoStorage builtinInfo;
     PseudoDestructorInfoStorage pseudoDestructorInfo;
+    VirtualInfoStorage virtualInfo;
   };
 
   explicit CIRGenCallee(SpecialKind kind) : kindOrFunctionPtr(kind) {}
@@ -128,7 +136,8 @@ class CIRGenCallee {
   CIRGenCallee prepareConcreteCallee(CIRGenFunction &cgf) const;
 
   CIRGenCalleeInfo getAbstractInfo() const {
-    assert(!cir::MissingFeatures::opCallVirtual());
+    if (isVirtual())
+      return virtualInfo.md;
     assert(isOrdinary());
     return abstractInfo;
   }
@@ -138,6 +147,39 @@ class CIRGenCallee {
     return reinterpret_cast<mlir::Operation *>(kindOrFunctionPtr);
   }
 
+  bool isVirtual() const { return kindOrFunctionPtr == SpecialKind::Virtual; }
+
+  static CIRGenCallee forVirtual(const clang::CallExpr *ce,
+                                 clang::GlobalDecl md, Address addr,
+                                 cir::FuncType fTy) {
+    CIRGenCallee result(SpecialKind::Virtual);
+    result.virtualInfo.ce = ce;
+    result.virtualInfo.md = md;
+    result.virtualInfo.addr = addr;
+    result.virtualInfo.fTy = fTy;
+    return result;
+  }
+
+  const clang::CallExpr *getVirtualCallExpr() const {
+    assert(isVirtual());
+    return virtualInfo.ce;
+  }
+
+  clang::GlobalDecl getVirtualMethodDecl() const {
+    assert(isVirtual());
+    return virtualInfo.md;
+  }
+
+  Address getThisAddress() const {
+    assert(isVirtual());
+    return virtualInfo.addr;
+  }
+
+  cir::FuncType getVirtualFunctionType() const {
+    assert(isVirtual());
+    return virtualInfo.fTy;
+  }
+
   void setFunctionPointer(mlir::Operation *functionPtr) {
     assert(isOrdinary());
     kindOrFunctionPtr = SpecialKind(reinterpret_cast<uintptr_t>(functionPtr));
diff --git a/clang/lib/CIR/CodeGen/CIRGenClass.cpp 
b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
index a3947047de079..3e5dc22426d8e 100644
--- a/clang/lib/CIR/CodeGen/CIRGenClass.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenClass.cpp
@@ -657,6 +657,20 @@ Address CIRGenFunction::getAddressOfBaseClass(
   return value;
 }
 
+// TODO(cir): this can be shared with LLVM codegen.
+bool CIRGenFunction::shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd) {
+  assert(!cir::MissingFeatures::hiddenVisibility());
+  if (!cgm.getCodeGenOpts().WholeProgramVTables)
+    return false;
+
+  if (cgm.getCodeGenOpts().VirtualFunctionElimination)
+    return true;
+
+  assert(!cir::MissingFeatures::sanitizers());
+
+  return false;
+}
+
 mlir::Value CIRGenFunction::getVTablePtr(mlir::Location loc, Address thisAddr,
                                          const CXXRecordDecl *rd) {
   auto vtablePtr = cir::VTableGetVPtrOp::create(
diff --git a/clang/lib/CIR/CodeGen/CIRGenFunction.h 
b/clang/lib/CIR/CodeGen/CIRGenFunction.h
index 9a887ec047f86..7ad3ad3559de8 100644
--- a/clang/lib/CIR/CodeGen/CIRGenFunction.h
+++ b/clang/lib/CIR/CodeGen/CIRGenFunction.h
@@ -552,6 +552,11 @@ class CIRGenFunction : public CIRGenTypeCache {
   mlir::Value getVTablePtr(mlir::Location loc, Address thisAddr,
                            const clang::CXXRecordDecl *vtableClass);
 
+  /// Returns whether we should perform a type checked load when loading a
+  /// virtual function for virtual calls to members of `rd`. For now this is
+  /// true only with whole-program-vtables and virtual function elimination.
+  bool shouldEmitVTableTypeCheckedLoad(const CXXRecordDecl *rd);
+
   /// A scope within which we are constructing the fields of an object which
   /// might use a CXXDefaultInitExpr. This stashes away a 'this' value to use 
if
   /// we need to evaluate the CXXDefaultInitExpr within the evaluation.
diff --git a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp 
b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
index 43dfd6468046d..347656b5f6488 100644
--- a/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
+++ b/clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp
@@ -69,6 +69,10 @@ class CIRGenItaniumCXXABI : public CIRGenCXXABI {
 
   cir::GlobalOp getAddrOfVTable(const CXXRecordDecl *rd,
                                 CharUnits vptrOffset) override;
+  CIRGenCallee getVirtualFunctionPointer(CIRGenFunction &cgf,
+                                         clang::GlobalDecl gd, Address 
thisAddr,
+                                         mlir::Type ty,
+                                         SourceLocation loc) override;
 
   mlir::Value getVTableAddressPoint(BaseSubobject base,
                                     const CXXRecordDecl *vtableClass) override;
@@ -349,6 +353,50 @@ cir::GlobalOp CIRGenItaniumCXXABI::getAddrOfVTable(const 
CXXRecordDecl *rd,
   return vtable;
 }
 
+CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
+    CIRGenFunction &cgf, clang::GlobalDecl gd, Address thisAddr, mlir::Type ty,
+    SourceLocation srcLoc) {
+  CIRGenBuilderTy &builder = cgm.getBuilder();
+  mlir::Location loc = cgf.getLoc(srcLoc);
+  cir::PointerType tyPtr = builder.getPointerTo(ty);
+  auto *methodDecl = cast<CXXMethodDecl>(gd.getDecl());
+  mlir::Value vtable = cgf.getVTablePtr(loc, thisAddr, 
methodDecl->getParent());
+
+  uint64_t vtableIndex = 
cgm.getItaniumVTableContext().getMethodVTableIndex(gd);
+  mlir::Value vfunc{};
+  if (cgf.shouldEmitVTableTypeCheckedLoad(methodDecl->getParent())) {
+    cgm.errorNYI(loc, "getVirtualFunctionPointer: emitVTableTypeCheckedLoad");
+  } else {
+    assert(!cir::MissingFeatures::emitTypeMetadataCodeForVCall());
+
+    mlir::Value vfuncLoad;
+    if (cgm.getItaniumVTableContext().isRelativeLayout()) {
+      assert(!cir::MissingFeatures::vtableRelativeLayout());
+      cgm.errorNYI(loc, "getVirtualFunctionPointer: isRelativeLayout");
+    } else {
+      auto vtableSlotPtr = cir::VTableGetVirtualFnAddrOp::create(
+          builder, loc, builder.getPointerTo(tyPtr), vtable, vtableIndex);
+      vfuncLoad = builder.createAlignedLoad(
+          loc, vtableSlotPtr, cgf.getPointerAlign().getQuantity());
+    }
+
+    // Adding !invariant.load metadata to the virtual function load would
+    // indicate that the function pointer didn't change inside the vtable.
+    // It's safe to add it without -fstrict-vtable-pointers, but it would only
+    // help devirtualization when there are two loads of the same virtual
+    // function from the same vtable load, which won't happen unless
+    // -fstrict-vtable-pointers enables devirtualization. Not yet done in CIR.
+    if (cgm.getCodeGenOpts().OptimizationLevel > 0 &&
+        cgm.getCodeGenOpts().StrictVTablePointers) {
+      cgm.errorNYI(loc, "getVirtualFunctionPointer: strictVTablePointers");
+    }
+    vfunc = vfuncLoad;
+  }
+
+  CIRGenCallee callee(gd, vfunc.getDefiningOp());
+  return callee;
+}
+
 mlir::Value
 CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject base,
                                            const CXXRecordDecl *vtableClass) {
diff --git a/clang/lib/CIR/CodeGen/CIRGenValue.h 
b/clang/lib/CIR/CodeGen/CIRGenValue.h
index 661cecf8416b6..ac7e1cc1a1db6 100644
--- a/clang/lib/CIR/CodeGen/CIRGenValue.h
+++ b/clang/lib/CIR/CodeGen/CIRGenValue.h
@@ -212,6 +212,14 @@ class LValue {
     return Address(getPointer(), elementType, getAlignment());
   }
 
+  void setAddress(Address address) {
+    assert(isSimple());
+    v = address.getPointer();
+    elementType = address.getElementType();
+    alignment = address.getAlignment().getQuantity();
+    assert(!cir::MissingFeatures::addressIsKnownNonNull());
+  }
+
   const clang::Qualifiers &getQuals() const { return quals; }
   clang::Qualifiers &getQuals() { return quals; }
 
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
index 9f7521db78bec..1111e5f1b5b2f 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
@@ -1085,8 +1085,7 @@ rewriteCallOrInvoke(mlir::Operation *op, mlir::ValueRange 
callOperands,
     auto calleeTy = op->getOperands().front().getType();
     auto calleePtrTy = cast<cir::PointerType>(calleeTy);
     auto calleeFuncTy = cast<cir::FuncType>(calleePtrTy.getPointee());
-    calleeFuncTy.dump();
-    converter->convertType(calleeFuncTy).dump();
+    llvm::append_range(adjustedCallOperands, callOperands);
     llvmFnTy = cast<mlir::LLVM::LLVMFunctionType>(
         converter->convertType(calleeFuncTy));
   }
@@ -2193,6 +2192,9 @@ static void prepareTypeConverter(mlir::LLVMTypeConverter 
&converter,
 
     return llvmStruct;
   });
+  converter.addConversion([&](cir::VoidType type) -> mlir::Type {
+    return mlir::LLVM::LLVMVoidType::get(type.getContext());
+  });
 }
 
 // The applyPartialConversion function traverses blocks in the dominance order,
@@ -2345,7 +2347,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
                CIRToLLVMVecSplatOpLowering,
                CIRToLLVMVecTernaryOpLowering,
                CIRToLLVMVTableAddrPointOpLowering,
-               CIRToLLVMVTableGetVPtrOpLowering
+               CIRToLLVMVTableGetVPtrOpLowering,
+               CIRToLLVMVTableGetVirtualFnAddrOpLowering
       // clang-format on
       >(converter, patterns.getContext());
 
@@ -2481,6 +2484,19 @@ mlir::LogicalResult 
CIRToLLVMVTableGetVPtrOpLowering::matchAndRewrite(
   return mlir::success();
 }
 
+mlir::LogicalResult CIRToLLVMVTableGetVirtualFnAddrOpLowering::matchAndRewrite(
+    cir::VTableGetVirtualFnAddrOp op, OpAdaptor adaptor,
+    mlir::ConversionPatternRewriter &rewriter) const {
+  mlir::Type targetType = getTypeConverter()->convertType(op.getType());
+  auto eltType = mlir::LLVM::LLVMPointerType::get(rewriter.getContext());
+  llvm::SmallVector<mlir::LLVM::GEPArg> offsets =
+      llvm::SmallVector<mlir::LLVM::GEPArg>{op.getIndex()};
+  rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
+      op, targetType, eltType, adaptor.getVptr(), offsets,
+      mlir::LLVM::GEPNoWrapFlags::inbounds);
+  return mlir::success();
+}
+
 mlir::LogicalResult CIRToLLVMStackSaveOpLowering::matchAndRewrite(
     cir::StackSaveOp op, OpAdaptor adaptor,
     mlir::ConversionPatternRewriter &rewriter) const {
diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h 
b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
index 91e8505233379..be1d1c44bf9db 100644
--- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
+++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
@@ -477,6 +477,17 @@ class CIRToLLVMVTableGetVPtrOpLowering
                   mlir::ConversionPatternRewriter &) const override;
 };
 
+class CIRToLLVMVTableGetVirtualFnAddrOpLowering
+    : public mlir::OpConversionPattern<cir::VTableGetVirtualFnAddrOp> {
+public:
+  using mlir::OpConversionPattern<
+      cir::VTableGetVirtualFnAddrOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(cir::VTableGetVirtualFnAddrOp op, OpAdaptor,
+                  mlir::ConversionPatternRewriter &) const override;
+};
+
 class CIRToLLVMStackSaveOpLowering
     : public mlir::OpConversionPattern<cir::StackSaveOp> {
 public:
diff --git a/clang/test/CIR/CodeGen/virtual-function-calls.cpp 
b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
index 4787d78aa0e35..e68b38fad3f84 100644
--- a/clang/test/CIR/CodeGen/virtual-function-calls.cpp
+++ b/clang/test/CIR/CodeGen/virtual-function-calls.cpp
@@ -46,3 +46,36 @@ A::A() {}
 // NOTE: The GEP in OGCG looks very different from the one generated with CIR,
 //       but it is equivalent. The OGCG GEP indexes by base pointer, then
 //       structure, then array, whereas the CIR GEP indexes by byte offset.
+
+void f1(A *a) {
+  a->f('c');
+}
+
+// CIR: cir.func{{.*}} @_Z2f1P1A(%arg0: !cir.ptr<!rec_A> {{.*}})
+// CIR:   %[[A_ADDR:.*]] = cir.alloca !cir.ptr<!rec_A>
+// CIR:   cir.store %arg0, %[[A_ADDR]]
+// CIR:   %[[A:.*]] = cir.load{{.*}} %[[A_ADDR]]
+// CIR:   %[[C_LITERAL:.*]] = cir.const #cir.int<99> : !s8i
+// CIR:   %[[VPTR_ADDR:.*]] = cir.vtable.get_vptr %[[A]] : !cir.ptr<!rec_A> -> 
!cir.ptr<!cir.vptr>
+// CIR:   %[[VPTR:.*]] = cir.load{{.*}} %[[VPTR_ADDR]] : !cir.ptr<!cir.vptr>, 
!cir.vptr
+// CIR:   %[[FN_PTR_PTR:.*]] = cir.vtable.get_virtual_fn_addr %[[VPTR]][0] : 
!cir.vptr -> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>>
+// CIR:   %[[FN_PTR:.*]] = cir.load{{.*}} %[[FN_PTR_PTR]] : 
!cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>>, 
!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>
+// CIR:   cir.call %[[FN_PTR]](%[[A]], %[[C_LITERAL]]) : 
(!cir.ptr<!cir.func<(!cir.ptr<!rec_A>, !s8i)>>, !cir.ptr<!rec_A>, !s8i) -> ()
+
+// LLVM: define{{.*}} void @_Z2f1P1A(ptr %[[ARG0:.*]])
+// LLVM:   %[[A_ADDR:.*]] = alloca ptr
+// LLVM:   store ptr %[[ARG0]], ptr %[[A_ADDR]]
+// LLVM:   %[[A:.*]] = load ptr, ptr %[[A_ADDR]]
+// LLVM:   %[[VPTR:.*]] = load ptr, ptr %[[A]]
+// LLVM:   %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i32 0
+// LLVM:   %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]]
+// LLVM:   call void %[[FN_PTR]](ptr %[[A]], i8 99)
+
+// OGCG: define{{.*}} void @_Z2f1P1A(ptr {{.*}} %[[ARG0:.*]])
+// OGCG:   %[[A_ADDR:.*]] = alloca ptr
+// OGCG:   store ptr %[[ARG0]], ptr %[[A_ADDR]]
+// OGCG:   %[[A:.*]] = load ptr, ptr %[[A_ADDR]]
+// OGCG:   %[[VPTR:.*]] = load ptr, ptr %[[A]]
+// OGCG:   %[[FN_PTR_PTR:.*]] = getelementptr inbounds ptr, ptr %[[VPTR]], i64 0
+// OGCG:   %[[FN_PTR:.*]] = load ptr, ptr %[[FN_PTR_PTR]]
+// OGCG:   call void %[[FN_PTR]](ptr {{.*}} %[[A]], i8 {{.*}} 99)

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to