tbaeder updated this revision to Diff 497276.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D142630/new/

https://reviews.llvm.org/D142630

Files:
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/Context.cpp
  clang/lib/AST/Interp/Context.h
  clang/lib/AST/Interp/Descriptor.cpp
  clang/lib/AST/Interp/Function.h
  clang/lib/AST/Interp/Interp.h
  clang/lib/AST/Interp/InterpState.h
  clang/lib/AST/Interp/Opcodes.td
  clang/lib/AST/Interp/Pointer.h
  clang/test/AST/Interp/records.cpp

Index: clang/test/AST/Interp/records.cpp
===================================================================
--- clang/test/AST/Interp/records.cpp
+++ clang/test/AST/Interp/records.cpp
@@ -1,8 +1,10 @@
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++14 -verify %s
+// RUN: %clang_cc1 -fexperimental-new-constant-interpreter -std=c++20 -verify %s
 // RUN: %clang_cc1 -fexperimental-new-constant-interpreter -triple i686 -verify %s
 // RUN: %clang_cc1 -verify=ref %s
 // RUN: %clang_cc1 -verify=ref -std=c++14 %s
+// RUN: %clang_cc1 -verify=ref -std=c++20 %s
 // RUN: %clang_cc1 -verify=ref -triple i686 %s
 
 struct BoolPair {
@@ -286,6 +288,7 @@
 };
 
 namespace DeriveFailures {
+#if __cplusplus < 202002L
   struct Base { // ref-note 2{{declared here}}
     int Val;
   };
@@ -301,10 +304,12 @@
                            // ref-note {{in call to 'Derived(12)'}} \
                            // ref-note {{declared here}} \
                            // expected-error {{must be initialized by a constant expression}}
+
   static_assert(D.Val == 0, ""); // ref-error {{not an integral constant expression}} \
                                  // ref-note {{initializer of 'D' is not a constant expression}} \
                                  // expected-error {{not an integral constant expression}} \
                                  // expected-note {{read of object outside its lifetime}}
+#endif
 
   struct AnotherBase {
     int Val;
@@ -354,3 +359,127 @@
   static_assert(getS(true).a == 12, "");
   static_assert(getS(false).a == 13, "");
 };
+
+#if __cplusplus >= 202002L
+namespace VirtualCalls {
+namespace Obvious {
+
+  class A {
+  public:
+    constexpr A(){}
+    constexpr virtual int foo() {
+      return 3;
+    }
+  };
+  class B : public A {
+  public:
+    constexpr int foo() override {
+      return 6;
+    }
+  };
+
+  constexpr int getFooB(bool b) {
+    A *a;
+    A myA;
+    B myB;
+
+    if (b)
+      a = &myA;
+    else
+      a = &myB;
+
+    return a->foo();
+  }
+  static_assert(getFooB(true) == 3, "");
+  static_assert(getFooB(false) == 6, "");
+}
+
+namespace MultipleBases {
+  class A {
+  public:
+    constexpr virtual int getInt() const { return 10; }
+  };
+  class B {
+  public:
+  };
+  class C : public A, public B {
+  public:
+    constexpr int getInt() const override { return 20; }
+  };
+
+  constexpr int callGetInt(const A& a) { return a.getInt(); }
+  static_assert(callGetInt(C()) == 20, "");
+  static_assert(callGetInt(A()) == 10, "");
+}
+
+namespace Destructors {
+  // The destructors modify the caller's counter through a pointer, so the
+  // test can observe that both of them actually run.
+  class Base {
+  public:
+    int *Counter;
+    constexpr Base(int *P) : Counter(P) { ++*Counter; }
+    constexpr virtual ~Base() { --*Counter; }
+  };
+
+  class Derived : public Base {
+  public:
+    constexpr Derived(int *P) : Base(P) {}
+    constexpr virtual ~Derived() { --*Counter; }
+  };
+
+  // D is destroyed at the end of the inner scope, before I is read: the
+  // constructor adds 1, and ~Derived and ~Base subtract 1 each.
+  constexpr int test() {
+    int I = 0;
+    {
+      Derived D(&I);
+    }
+    return I;
+  }
+  static_assert(test() == -1);
+}
+
+
+namespace QualifiedCalls {
+  class A {
+      public:
+      constexpr virtual int foo() const {
+          return 5;
+      }
+  };
+  class B : public A {};
+  class C : public B {
+      public:
+      constexpr int foo() const override {
+          return B::foo(); // B doesn't have a foo(), so this should call A::foo().
+      }
+      constexpr int foo2() const {
+        return this->A::foo();
+      }
+  };
+  constexpr C c;
+  static_assert(c.foo() == 5);
+  static_assert(c.foo2() == 5);
+
+
+  struct S {
+    int _c = 0;
+    virtual constexpr int foo() const { return 1; }
+  };
+
+  struct SS : S {
+    int a;
+    constexpr SS() {
+      a = S::foo();
+    }
+    constexpr int foo() const override {
+      return S::foo();
+    }
+  };
+
+  constexpr SS ss;
+  static_assert(ss.a == 1);
+}
+}
+#endif
Index: clang/lib/AST/Interp/Pointer.h
===================================================================
--- clang/lib/AST/Interp/Pointer.h
+++ clang/lib/AST/Interp/Pointer.h
@@ -206,6 +206,8 @@
   /// Returns the type of the innermost field.
   QualType getType() const { return getFieldDesc()->getType(); }
 
+  /// Returns a Pointer to Pointee itself, i.e. to the declaration as a whole.
+  Pointer getDeclPtr() const { return Pointer(Pointee); }
   /// Returns the element size of the innermost field.
   size_t elemSize() const {
     if (Base == RootPtrMark)
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -182,6 +182,13 @@
   let ChangesPC = 1;
 }
 
+def CallVirt : Opcode {
+  let Args = [ArgFunction];
+  let Types = [];
+  // CallVirt() forwards the PC it receives to Call(), so hand it the live PC.
+  let ChangesPC = 1;
+}
+
 def CallBI : Opcode {
   let Args = [ArgFunction];
   let Types = [];
Index: clang/lib/AST/Interp/InterpState.h
===================================================================
--- clang/lib/AST/Interp/InterpState.h
+++ clang/lib/AST/Interp/InterpState.h
@@ -86,6 +86,8 @@
     return M ? M->getSource(F, PC) : F->getSource(PC);
   }
 
+  /// Returns the Context held by this state.
+  Context &getContext() const { return Ctx; }
 private:
   /// AST Walker state.
   State &Parent;
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -1534,6 +1534,57 @@
   return false;
 }
 
+/// Calls the function that a virtual call to Func dispatches to.
+///
+/// The dynamic type is taken from the declaration the `this` pointer on the
+/// stack points into. If the resolved overrider is declared further down the
+/// hierarchy than the pointer's field, `this` is widened to that declaration.
+/// Returns false when no suitable overrider can be determined.
+inline bool CallVirt(InterpState &S, CodePtr OpPC, const Function *Func) {
+  assert(Func->hasThisPointer());
+  assert(Func->isVirtual());
+  // `this` sits below the arguments and, if present, the RVO pointer.
+  size_t ThisOffset =
+      Func->getArgSize() + (Func->hasRVO() ? primSize(PT_Ptr) : 0);
+  Pointer &ThisPtr = S.Stk.peek<Pointer>(ThisOffset);
+
+  const CXXRecordDecl *DynamicDecl =
+      ThisPtr.getDeclDesc()->getType()->getAsCXXRecordDecl();
+  const CXXRecordDecl *StaticDecl = cast<CXXRecordDecl>(Func->getParentDecl());
+  const CXXMethodDecl *InitialFunction = cast<CXXMethodDecl>(Func->getDecl());
+
+  // FIXME: The type of the whole declaration only approximates the dynamic
+  // type. For arrays or member subobjects (`S s; s.Member.foo()`) it may not
+  // be a record derived from StaticDecl at all, in which case no overrider
+  // can be found in its hierarchy; fail the evaluation instead.
+  if (!DynamicDecl ||
+      (DynamicDecl->getCanonicalDecl() != StaticDecl->getCanonicalDecl() &&
+       !DynamicDecl->isDerivedFrom(StaticDecl)))
+    return false;
+
+  const CXXMethodDecl *Overrider = S.getContext().getOverridingFunction(
+      DynamicDecl, StaticDecl, InitialFunction);
+
+  if (Overrider != InitialFunction) {
+    // The overrider may not have been compiled; fail the evaluation rather
+    // than dereferencing a null Function below.
+    Func = S.P.getFunction(Overrider);
+    if (!Func)
+      return false;
+
+    const CXXRecordDecl *ThisFieldDecl =
+        ThisPtr.getFieldDesc()->getType()->getAsCXXRecordDecl();
+    if (ThisFieldDecl && Func->getParentDecl()->isDerivedFrom(ThisFieldDecl)) {
+      // If the function we call is further DOWN the hierarchy than the
+      // FieldDesc of our pointer, just get the DeclDesc instead, which
+      // is the furthest we might go up in the hierarchy.
+      ThisPtr = ThisPtr.getDeclPtr();
+    }
+  }
+
+  return Call(S, OpPC, Func);
+}
+
 inline bool CallBI(InterpState &S, CodePtr &PC, const Function *Func) {
   auto NewFrame = std::make_unique<InterpFrame>(S, Func, PC);
 
Index: clang/lib/AST/Interp/Function.h
===================================================================
--- clang/lib/AST/Interp/Function.h
+++ clang/lib/AST/Interp/Function.h
@@ -132,6 +132,13 @@
   /// Checks if the function is a destructor.
   bool isDestructor() const { return isa<CXXDestructorDecl>(F); }
 
+  /// Returns the class this function is a member of, or nullptr if it is
+  /// not a C++ method.
+  const CXXRecordDecl *getParentDecl() const {
+    const auto *Method = dyn_cast<CXXMethodDecl>(F);
+    return Method ? Method->getParent() : nullptr;
+  }
+
   /// Checks if the function is fully done compiling.
   bool isFullyCompiled() const { return IsFullyCompiled; }
 
Index: clang/lib/AST/Interp/Descriptor.cpp
===================================================================
--- clang/lib/AST/Interp/Descriptor.cpp
+++ clang/lib/AST/Interp/Descriptor.cpp
@@ -269,6 +269,8 @@
     return E->getType();
   if (auto *D = asValueDecl())
     return D->getType();
+  if (auto *T = dyn_cast<TypeDecl>(asDecl()))
+    return QualType(T->getTypeForDecl(), 0);
   llvm_unreachable("Invalid descriptor type");
 }
 
Index: clang/lib/AST/Interp/Context.h
===================================================================
--- clang/lib/AST/Interp/Context.h
+++ clang/lib/AST/Interp/Context.h
@@ -61,6 +61,11 @@
   /// Classifies an expression.
   std::optional<PrimType> classify(QualType T) const;
 
+  /// Looks up the overrider of InitialFunction for dynamic type DynamicDecl.
+  const CXXMethodDecl *
+  getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                        const CXXRecordDecl *StaticDecl,
+                        const CXXMethodDecl *InitialFunction) const;
 private:
   /// Runs a function.
   bool Run(State &Parent, Function *Func, APValue &Result);
Index: clang/lib/AST/Interp/Context.cpp
===================================================================
--- clang/lib/AST/Interp/Context.cpp
+++ clang/lib/AST/Interp/Context.cpp
@@ -152,3 +152,46 @@
   });
   return false;
 }
+
+// Walks the class hierarchy from DynamicDecl towards StaticDecl and returns
+// the first method found that corresponds to InitialFunction, i.e. the
+// overrider a virtual call to InitialFunction dispatches to. DynamicDecl is
+// expected to be StaticDecl or derived from it; otherwise the walk runs out
+// of classes and reports that via llvm_unreachable instead of looping forever.
+// TODO: Virtual bases?
+const CXXMethodDecl *
+Context::getOverridingFunction(const CXXRecordDecl *DynamicDecl,
+                               const CXXRecordDecl *StaticDecl,
+                               const CXXMethodDecl *InitialFunction) const {
+  assert(DynamicDecl && StaticDecl && InitialFunction);
+
+  const CXXRecordDecl *CurRecord = DynamicDecl;
+  while (CurRecord) {
+    if (const CXXMethodDecl *Overrider =
+            InitialFunction->getCorrespondingMethodDeclaredInClass(
+                CurRecord, /*MayBeBase=*/false))
+      return Overrider;
+
+    // Common case of only one base class.
+    if (CurRecord->getNumBases() == 1) {
+      CurRecord = CurRecord->bases_begin()->getType()->getAsCXXRecordDecl();
+      continue;
+    }
+
+    // Otherwise, go to the base class that will lead to the StaticDecl. If
+    // there is none (e.g. no bases at all), CurRecord becomes null and the
+    // loop ends.
+    const CXXRecordDecl *Next = nullptr;
+    for (const CXXBaseSpecifier &Spec : CurRecord->bases()) {
+      const CXXRecordDecl *Base = Spec.getType()->getAsCXXRecordDecl();
+      if (Base && (Base == StaticDecl || Base->isDerivedFrom(StaticDecl))) {
+        Next = Base;
+        break;
+      }
+    }
+    CurRecord = Next;
+  }
+
+  llvm_unreachable(
+      "Couldn't find an overriding function in the class hierarchy?");
+}
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -1688,12 +1688,26 @@
 
     assert(HasRVO == Func->hasRVO());
 
+    bool HasQualifier = false;
+    if (const auto *ME = dyn_cast<MemberExpr>(E->getCallee());
+        ME && ME->hasQualifier())
+      HasQualifier = true;
+
+    bool IsVirtual = false;
+    if (const auto *MD = dyn_cast<CXXMethodDecl>(FuncDecl);
+        MD && MD->isVirtual())
+      IsVirtual = true;
+
     // In any case call the function. The return value will end up on the stack
     // and if the function has RVO, we already have the pointer on the stack to
     // write the result into.
-    if (!this->emitCall(Func, E))
-      return false;
-
+    if (IsVirtual && !HasQualifier) {
+      if (!this->emitCallVirt(Func, E))
+        return false;
+    } else {
+      if (!this->emitCall(Func, E))
+        return false;
+    }
   } else {
     // Indirect call. Visit the callee, which will leave a FunctionPointer on
     // the stack. Cleanup of the returned value if necessary will be done after
@@ -1977,6 +1991,8 @@
       if (!this->emitCall(DtorFunc, SourceInfo{}))
         return false;
     }
+    if (Dtor->isVirtual())
+      return this->emitPopPtr(SourceInfo{});
   }
 
   for (const Record::Base &Base : llvm::reverse(R->bases())) {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to