Author: Timm Baeder
Date: 2025-05-20T11:19:24+02:00
New Revision: d01355645b1fece147163e1cfe9f71d9c704860e

URL: https://github.com/llvm/llvm-project/commit/d01355645b1fece147163e1cfe9f71d9c704860e
DIFF: https://github.com/llvm/llvm-project/commit/d01355645b1fece147163e1cfe9f71d9c704860e.diff

LOG: [clang][bytecode] Check downcasts for the correct type (#140689)

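In multiple-inheritance (e.g. diamond) scenarios, applying the
base-to-derived offset can land on a subobject whose type differs from
the cast's target type. Diagnose the downcast as invalid in that case.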

Added: 
    

Modified: 
    clang/lib/AST/ByteCode/Compiler.cpp
    clang/lib/AST/ByteCode/Interp.h
    clang/lib/AST/ByteCode/Opcodes.td
    clang/test/AST/ByteCode/records.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/AST/ByteCode/Compiler.cpp b/clang/lib/AST/ByteCode/Compiler.cpp
index 36380543e5991..54a4647a604fb 100644
--- a/clang/lib/AST/ByteCode/Compiler.cpp
+++ b/clang/lib/AST/ByteCode/Compiler.cpp
@@ -296,12 +296,15 @@ bool Compiler<Emitter>::VisitCastExpr(const CastExpr *CE) {
   case CK_BaseToDerived: {
     if (!this->delegate(SubExpr))
       return false;
-
     unsigned DerivedOffset =
         collectBaseOffset(SubExpr->getType(), CE->getType());
 
-    return this->emitGetPtrDerivedPop(
-        DerivedOffset, /*NullOK=*/CE->getType()->isPointerType(), CE);
+    const Type *TargetType = CE->getType().getTypePtr();
+    if (TargetType->isPointerOrReferenceType())
+      TargetType = TargetType->getPointeeType().getTypePtr();
+    return this->emitGetPtrDerivedPop(DerivedOffset,
+                                      
/*NullOK=*/CE->getType()->isPointerType(),
+                                      TargetType, CE);
   }
 
   case CK_FloatingCast: {

diff  --git a/clang/lib/AST/ByteCode/Interp.h b/clang/lib/AST/ByteCode/Interp.h
index bfc6797d13412..70bbfc576925e 100644
--- a/clang/lib/AST/ByteCode/Interp.h
+++ b/clang/lib/AST/ByteCode/Interp.h
@@ -1643,7 +1643,7 @@ inline bool GetPtrActiveThisField(InterpState &S, CodePtr OpPC, uint32_t Off) {
 }
 
 inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off,
-                             bool NullOK) {
+                             bool NullOK, const Type *TargetType) {
   const Pointer &Ptr = S.Stk.pop<Pointer>();
   if (!NullOK && !CheckNull(S, OpPC, Ptr, CSK_Derived))
     return false;
@@ -1661,6 +1661,20 @@ inline bool GetPtrDerivedPop(InterpState &S, CodePtr OpPC, uint32_t Off,
   if (!CheckDowncast(S, OpPC, Ptr, Off))
     return false;
 
+  const Record *TargetRecord = Ptr.atFieldSub(Off).getRecord();
+  assert(TargetRecord);
+
+  if (TargetRecord->getDecl()
+          ->getTypeForDecl()
+          ->getAsCXXRecordDecl()
+          ->getCanonicalDecl() !=
+      TargetType->getAsCXXRecordDecl()->getCanonicalDecl()) {
+    QualType MostDerivedType = Ptr.getDeclDesc()->getType();
+    S.CCEDiag(S.Current->getSource(OpPC), 
diag::note_constexpr_invalid_downcast)
+        << MostDerivedType << QualType(TargetType, 0);
+    return false;
+  }
+
   S.Stk.push<Pointer>(Ptr.atFieldSub(Off));
   return true;
 }

diff  --git a/clang/lib/AST/ByteCode/Opcodes.td b/clang/lib/AST/ByteCode/Opcodes.td
index 9dddcced8ca38..c8db8da113bd4 100644
--- a/clang/lib/AST/ByteCode/Opcodes.td
+++ b/clang/lib/AST/ByteCode/Opcodes.td
@@ -325,7 +325,7 @@ def GetMemberPtrBasePop : Opcode {
 def FinishInitPop : Opcode;
 def FinishInit    : Opcode;
 
-def GetPtrDerivedPop : Opcode { let Args = [ArgUint32, ArgBool]; }
+def GetPtrDerivedPop : Opcode { let Args = [ArgUint32, ArgBool, ArgTypePtr]; }
 
 // [Pointer] -> [Pointer]
 def GetPtrVirtBasePop : Opcode {

diff  --git a/clang/test/AST/ByteCode/records.cpp b/clang/test/AST/ByteCode/records.cpp
index c2fe3d9007480..9361d6ddeda70 100644
--- a/clang/test/AST/ByteCode/records.cpp
+++ b/clang/test/AST/ByteCode/records.cpp
@@ -1830,3 +1830,15 @@ namespace NullDtor {
   static_assert(foo() == 10, ""); // both-error {{not an integral constant 
expression}} \
                                   // both-note {{in call to}}
 }
+
+namespace DiamondDowncast {
+  struct Top {};
+  struct Middle1 : Top {};
+  struct Middle2 : Top {};
+  struct Bottom : Middle1, Middle2 {};
+
+  constexpr Bottom bottom;
+  constexpr Top &top1 = (Middle1&)bottom;
+  constexpr Middle2 &fail = (Middle2&)top1; // both-error {{must be 
initialized by a constant expression}} \
+                                            // both-note {{cannot cast object 
of dynamic type 'const Bottom' to type 'Middle2'}}
+}


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to