tbaeder updated this revision to Diff 492338.
tbaeder marked an inline comment as done.

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D141472/new/

https://reviews.llvm.org/D141472

Files:
  clang/lib/AST/Interp/ByteCodeExprGen.cpp
  clang/lib/AST/Interp/Context.cpp
  clang/lib/AST/Interp/Descriptor.cpp
  clang/lib/AST/Interp/FunctionPointer.h
  clang/lib/AST/Interp/Interp.h
  clang/lib/AST/Interp/InterpStack.h
  clang/lib/AST/Interp/Opcodes.td
  clang/lib/AST/Interp/PrimType.cpp
  clang/lib/AST/Interp/PrimType.h
  clang/test/AST/Interp/functions.cpp

Index: clang/test/AST/Interp/functions.cpp
===================================================================
--- clang/test/AST/Interp/functions.cpp
+++ clang/test/AST/Interp/functions.cpp
@@ -99,3 +99,58 @@
   huh(); // expected-error {{use of undeclared identifier}} \
          // ref-error {{use of undeclared identifier}}
 }
+
+namespace FunctionPointers {
+  constexpr int add(int a, int b) {
+    return a + b;
+  }
+
+  struct S { int a; };
+  constexpr S getS() {
+    return S{12};
+  }
+
+  constexpr int applyBinOp(int a, int b, int (*op)(int, int)) {
+    return op(a, b);
+  }
+  static_assert(applyBinOp(1, 2, add) == 3, "");
+
+  constexpr int ignoreReturnValue() {
+    int (*foo)(int, int) = add;
+
+    foo(1, 2);
+    return 1;
+  }
+  static_assert(ignoreReturnValue() == 1, "");
+
+  constexpr int createS(S (*gimme)()) {
+    gimme(); // Ignored return value
+    return gimme().a;
+  }
+  static_assert(createS(getS) == 12, "");
+
+namespace FunctionReturnType {
+  typedef int (*ptr)(int*);
+  typedef ptr (*pm)();
+
+  constexpr int fun1(int* y) {
+    return *y + 10;
+  }
+  constexpr ptr fun() {
+    return &fun1;
+  }
+  static_assert(fun() == nullptr, ""); // expected-error {{static assertion failed}} \
+                                       // ref-error {{static assertion failed}}
+
+  constexpr int foo() {
+    int (*f)(int *) = fun();
+    int m = 0;
+
+    m = f(&m);
+
+    return m;
+  }
+  static_assert(foo() == 10, "");
+}
+
+}
Index: clang/lib/AST/Interp/PrimType.h
===================================================================
--- clang/lib/AST/Interp/PrimType.h
+++ clang/lib/AST/Interp/PrimType.h
@@ -24,6 +24,7 @@
 class Pointer;
 class Boolean;
 class Floating;
+class FunctionPointer;
 
 /// Enumeration of the primitive types of the VM.
 enum PrimType : unsigned {
@@ -38,6 +39,7 @@
   PT_Bool,
   PT_Float,
   PT_Ptr,
+  PT_FnPtr,
 };
 
 /// Mapping from primitive types to their representation.
@@ -53,6 +55,7 @@
 template <> struct PrimConv<PT_Float> { using T = Floating; };
 template <> struct PrimConv<PT_Bool> { using T = Boolean; };
 template <> struct PrimConv<PT_Ptr> { using T = Pointer; };
+template <> struct PrimConv<PT_FnPtr> { using T = FunctionPointer; };
 
 /// Returns the size of a primitive type in bytes.
 size_t primSize(PrimType Type);
@@ -90,6 +93,7 @@
       TYPE_SWITCH_CASE(PT_Float, B)                                            \
       TYPE_SWITCH_CASE(PT_Bool, B)                                             \
       TYPE_SWITCH_CASE(PT_Ptr, B)                                              \
+      TYPE_SWITCH_CASE(PT_FnPtr, B)                                            \
     }                                                                          \
   } while (0)
 #define COMPOSITE_TYPE_SWITCH(Expr, B, D)                                      \
Index: clang/lib/AST/Interp/PrimType.cpp
===================================================================
--- clang/lib/AST/Interp/PrimType.cpp
+++ clang/lib/AST/Interp/PrimType.cpp
@@ -9,6 +9,7 @@
 #include "PrimType.h"
 #include "Boolean.h"
 #include "Floating.h"
+#include "FunctionPointer.h"
 #include "Pointer.h"
 
 using namespace clang;
Index: clang/lib/AST/Interp/Opcodes.td
===================================================================
--- clang/lib/AST/Interp/Opcodes.td
+++ clang/lib/AST/Interp/Opcodes.td
@@ -27,6 +27,7 @@
 def Uint64 : Type;
 def Float : Type;
 def Ptr : Type;
+def FnPtr : Type;
 
 //===----------------------------------------------------------------------===//
 // Types transferred to the interpreter.
@@ -77,7 +78,7 @@
 }
 
 def PtrTypeClass : TypeClass {
-  let Types = [Ptr];
+  let Types = [Ptr, FnPtr];
 }
 
 def BoolTypeClass : TypeClass {
@@ -187,6 +188,13 @@
   let ChangesPC = 1;
 }
 
+// Pops a FunctionPointer off the stack and calls the function it points to.
+def CallPtr : Opcode {
+  let Args = [];
+  let Types = [];
+  let ChangesPC = 1;
+}
+
 //===----------------------------------------------------------------------===//
 // Frame management
 //===----------------------------------------------------------------------===//
@@ -228,6 +236,7 @@
 // [] -> [Pointer]
 def Null : Opcode {
   let Types = [PtrTypeClass];
+  let HasGroup = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -447,6 +456,14 @@
   let HasGroup = 0;
 }
 
+//===----------------------------------------------------------------------===//
+// Function pointers.
+//===----------------------------------------------------------------------===//
+// [] -> [FunctionPointer]
+def GetFnPtr : Opcode {
+  let Args = [ArgFunction];
+}
+
 //===----------------------------------------------------------------------===//
 // Binary operators.
 //===----------------------------------------------------------------------===//
Index: clang/lib/AST/Interp/InterpStack.h
===================================================================
--- clang/lib/AST/Interp/InterpStack.h
+++ clang/lib/AST/Interp/InterpStack.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_CLANG_AST_INTERP_INTERPSTACK_H
 #define LLVM_CLANG_AST_INTERP_INTERPSTACK_H
 
+#include "FunctionPointer.h"
 #include "PrimType.h"
 #include <memory>
 #include <vector>
@@ -162,6 +163,8 @@
       return PT_Uint64;
     else if constexpr (std::is_same_v<T, Floating>)
       return PT_Float;
+    else if constexpr (std::is_same_v<T, FunctionPointer>)
+      return PT_FnPtr;
 
     llvm_unreachable("unknown type push()'ed into InterpStack");
   }
Index: clang/lib/AST/Interp/Interp.h
===================================================================
--- clang/lib/AST/Interp/Interp.h
+++ clang/lib/AST/Interp/Interp.h
@@ -16,6 +16,7 @@
 #include "Boolean.h"
 #include "Floating.h"
 #include "Function.h"
+#include "FunctionPointer.h"
 #include "InterpFrame.h"
 #include "InterpStack.h"
 #include "InterpState.h"
@@ -1545,6 +1546,25 @@
   return false;
 }
 
+/// Pops a FunctionPointer off the stack and calls the function it refers to.
+/// Returns false if the pointer is null or the function is not constexpr.
+inline bool CallPtr(InterpState &S, CodePtr &PC) {
+  const FunctionPointer Callee = S.Stk.pop<FunctionPointer>();
+  const Function *F = Callee.getFunction();
+
+  // Only non-null pointers to constexpr functions are dispatched to Call().
+  if (F && F->isConstexpr())
+    return Call(S, PC, F);
+  return false;
+}
+
+/// Pushes a FunctionPointer referring to \p Func, which must not be null.
+inline bool GetFnPtr(InterpState &S, CodePtr &PC, const Function *Func) {
+  assert(Func);
+  S.Stk.push<FunctionPointer>(Func);
+  return true;
+}
+
 //===----------------------------------------------------------------------===//
 // Read opcode arguments
 //===----------------------------------------------------------------------===//
Index: clang/lib/AST/Interp/FunctionPointer.h
===================================================================
--- /dev/null
+++ clang/lib/AST/Interp/FunctionPointer.h
@@ -0,0 +1,73 @@
+//===--- FunctionPointer.h - Types for the constexpr VM ---------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+#define LLVM_CLANG_AST_INTERP_FUNCTION_POINTER_H
+
+#include "Function.h"
+#include "Primitives.h"
+#include "clang/AST/APValue.h"
+
+namespace clang {
+namespace interp {
+
+/// A function pointer value in the constexpr interpreter, wrapping a
+/// (possibly null) `const Function *`. Default-constructed values are null.
+class FunctionPointer final {
+private:
+  const Function *Func;
+
+public:
+  /// Creates a null function pointer.
+  FunctionPointer() : Func(nullptr) {}
+  /// Creates a pointer to \p Func, which must not be null.
+  FunctionPointer(const Function *Func) : Func(Func) { assert(Func); }
+
+  /// Returns the pointee, or null for a null function pointer.
+  const Function *getFunction() const { return Func; }
+
+  /// Converts to an lvalue APValue based on the pointee's declaration, or to
+  /// a null lvalue APValue if this is a null function pointer.
+  APValue toAPValue() const {
+    if (!Func)
+      return APValue(static_cast<Expr *>(nullptr), CharUnits::Zero(), {},
+                     /*OnePastTheEnd=*/false, /*IsNull=*/true);
+
+    return APValue(Func->getDecl(), CharUnits::Zero(), {},
+                   /*OnePastTheEnd=*/false, /*IsNull=*/false);
+  }
+
+  /// Prints "FnPtr(<name>)", or "FnPtr(nullptr)" for a null pointer.
+  void print(llvm::raw_ostream &OS) const {
+    OS << "FnPtr(";
+    if (Func)
+      OS << Func->getName();
+    else
+      OS << "nullptr";
+    OS << ")";
+  }
+
+  /// Returns Equal if both pointers refer to the same Function (or are both
+  /// null), and Unordered otherwise.
+  ComparisonCategoryResult compare(const FunctionPointer &RHS) const {
+    if (Func == RHS.Func)
+      return ComparisonCategoryResult::Equal;
+    return ComparisonCategoryResult::Unordered;
+  }
+};
+
+inline llvm::raw_ostream &operator<<(llvm::raw_ostream &OS,
+                                     FunctionPointer FP) {
+  FP.print(OS);
+  return OS;
+}
+
+} // namespace interp
+} // namespace clang
+
+#endif
Index: clang/lib/AST/Interp/Descriptor.cpp
===================================================================
--- clang/lib/AST/Interp/Descriptor.cpp
+++ clang/lib/AST/Interp/Descriptor.cpp
@@ -9,6 +9,7 @@
 #include "Descriptor.h"
 #include "Boolean.h"
 #include "Floating.h"
+#include "FunctionPointer.h"
 #include "Pointer.h"
 #include "PrimType.h"
 #include "Record.h"
Index: clang/lib/AST/Interp/Context.cpp
===================================================================
--- clang/lib/AST/Interp/Context.cpp
+++ clang/lib/AST/Interp/Context.cpp
@@ -78,9 +78,11 @@
 const LangOptions &Context::getLangOpts() const { return Ctx.getLangOpts(); }
 
 std::optional<PrimType> Context::classify(QualType T) const {
-  if (T->isReferenceType() || T->isPointerType()) {
+  if (T->isFunctionPointerType() || T->isFunctionReferenceType())
+    return PT_FnPtr;
+
+  if (T->isReferenceType() || T->isPointerType())
     return PT_Ptr;
-  }
 
   if (T->isBooleanType())
     return PT_Bool;
Index: clang/lib/AST/Interp/ByteCodeExprGen.cpp
===================================================================
--- clang/lib/AST/Interp/ByteCodeExprGen.cpp
+++ clang/lib/AST/Interp/ByteCodeExprGen.cpp
@@ -134,10 +134,11 @@
   }
 
   case CK_NullToPointer:
-  case CK_IntegralToPointer: {
-    if (isa<CXXNullPtrLiteralExpr>(SubExpr))
-      return this->visit(SubExpr);
+    if (DiscardResult)
+      return true;
+    return this->emitNull(classifyPrim(CE->getType()), CE);
 
+  case CK_IntegralToPointer: {
     if (!this->visit(SubExpr))
       return false;
 
@@ -968,6 +969,7 @@
     return this->emitZeroUint64(E);
   case PT_Ptr:
     return this->emitNullPtr(E);
+  case PT_FnPtr:
   case PT_Float:
     assert(false);
   }
@@ -1134,6 +1136,7 @@
   case PT_Bool:
     return this->emitConstBool(Value, E);
   case PT_Ptr:
+  case PT_FnPtr:
   case PT_Float:
     llvm_unreachable("Invalid integral type");
     break;
@@ -1667,8 +1670,27 @@
   if (E->getBuiltinCallee())
     return VisitBuiltinCallExpr(E);
 
-  const Decl *Callee = E->getCalleeDecl();
-  if (const auto *FuncDecl = dyn_cast_if_present<FunctionDecl>(Callee)) {
+  QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
+  std::optional<PrimType> T = classify(ReturnType);
+  bool HasRVO = !ReturnType->isVoidType() && !T;
+
+  if (HasRVO && DiscardResult) {
+    // If we need to discard the return value but the function returns its
+    // value via an RVO pointer, we need to create one such pointer just
+    // for this call.
+    if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
+      if (!this->emitGetPtrLocal(*LocalIndex, E))
+        return false;
+    }
+  }
+
+  // Put arguments on the stack.
+  for (const auto *Arg : E->arguments()) {
+    if (!this->visit(Arg))
+      return false;
+  }
+
+  if (const FunctionDecl *FuncDecl = E->getDirectCallee()) {
     const Function *Func = getFunction(FuncDecl);
     if (!Func)
       return false;
@@ -1680,24 +1702,7 @@
     if (Func->isFullyCompiled() && !Func->isConstexpr())
       return false;
 
-    QualType ReturnType = E->getCallReturnType(Ctx.getASTContext());
-    std::optional<PrimType> T = classify(ReturnType);
-
-    if (Func->hasRVO() && DiscardResult) {
-      // If we need to discard the return value but the function returns its
-      // value via an RVO pointer, we need to create one such pointer just
-      // for this call.
-      if (std::optional<unsigned> LocalIndex = allocateLocal(E)) {
-        if (!this->emitGetPtrLocal(*LocalIndex, E))
-          return false;
-      }
-    }
-
-    // Put arguments on the stack.
-    for (const auto *Arg : E->arguments()) {
-      if (!this->visit(Arg))
-        return false;
-    }
+    assert(HasRVO == Func->hasRVO());
 
     // In any case call the function. The return value will end up on the stack
     // and if the function has RVO, we already have the pointer on the stack to
@@ -1705,15 +1710,22 @@
     if (!this->emitCall(Func, E))
       return false;
 
-    if (DiscardResult && !ReturnType->isVoidType() && T)
-      return this->emitPop(*T, E);
-
-    return true;
   } else {
-    assert(false && "We don't support non-FunctionDecl callees right now.");
+    // Indirect call. Visit the callee, which leaves a FunctionPointer on top
+    // of the already-pushed arguments; CallPtr pops it and performs the call.
+    // Cleanup of a discarded return value happens after the call, below.
+    if (!this->visit(E->getCallee()))
+      return false;
+
+    if (!this->emitCallPtr(E))
+      return false;
   }
 
-  return false;
+  // Cleanup for discarded return values.
+  if (DiscardResult && !ReturnType->isVoidType() && T)
+    return this->emitPop(*T, E);
+
+  return true;
 }
 
 template <class Emitter>
@@ -1912,6 +1924,9 @@
     }
   } else if (const auto *ECD = dyn_cast<EnumConstantDecl>(Decl)) {
     return this->emitConst(ECD->getInitVal(), E);
+  } else if (const auto *FuncDecl = dyn_cast<FunctionDecl>(Decl)) {
+    const Function *F = getFunction(FuncDecl);
+    return F && this->emitGetFnPtr(F, E);
   }
 
   return false;
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to