modocache created this revision.
modocache added reviewers: rsmith, GorNishanov, eric_niebler.
Herald added a subscriber: EricWF.

Depends on https://reviews.llvm.org/D42605.

An implementation of the behavior described in `[dcl.fct.def.coroutine]/7`:
when a promise type declares a placement `operator new` whose parameters,
after the leading size argument, match the coroutine function's parameter
types, that overload is used when allocating the coroutine frame.

Simply passing references to the coroutine function parameters directly
to `operator new` results in invariant violations in LLVM's coroutine
splitting pass, so this implementation modifies Clang codegen to
emit a dedicated alloca/store/load sequence for each parameter that is
forwarded to the allocator.

Test Plan: `check-clang`


Repository:
  rC Clang

https://reviews.llvm.org/D42606

Files:
  lib/CodeGen/CGCoroutine.cpp
  lib/Sema/SemaCoroutine.cpp
  test/CodeGenCoroutines/coro-alloc.cpp

Index: test/CodeGenCoroutines/coro-alloc.cpp
===================================================================
--- test/CodeGenCoroutines/coro-alloc.cpp
+++ test/CodeGenCoroutines/coro-alloc.cpp
@@ -106,6 +106,34 @@
   co_return;
 }
 
+struct promise_matching_placement_new_tag {};
+
+template<>
+struct std::experimental::coroutine_traits<void, promise_matching_placement_new_tag, int, float, double> {
+  struct promise_type {
+    void *operator new(unsigned long, promise_matching_placement_new_tag,
+                       int, float, double);
+    void get_return_object() {}
+    suspend_always initial_suspend() { return {}; }
+    suspend_always final_suspend() { return {}; }
+    void return_void() {}
+  };
+};
+
+// CHECK-LABEL: f1a(
+extern "C" void f1a(promise_matching_placement_new_tag, int x, float y , double z) {
+  // CHECK: %[[ID:.+]] = call token @llvm.coro.id(i32 16
+  // CHECK: %[[SIZE:.+]] = call i64 @llvm.coro.size.i64()
+  // CHECK: store i32 %x, i32* %coro.allocate.x.addr, align 4
+  // CHECK: %[[INT:.+]] = load i32, i32* %coro.allocate.x.addr, align 4
+  // CHECK: store float %y, float* %coro.allocate.y.addr, align 4
+  // CHECK: %[[FLOAT:.+]] = load float, float* %coro.allocate.y.addr, align 4
+  // CHECK: store double %z, double* %coro.allocate.z.addr, align 8
+  // CHECK: %[[DOUBLE:.+]] = load double, double* %coro.allocate.z.addr, align 8
+  // CHECK: call i8* @_ZNSt12experimental16coroutine_traitsIJv34promise_matching_placement_new_tagifdEE12promise_typenwEmS1_ifd(i64 %[[SIZE]], i32 %[[INT]], float %[[FLOAT]], double %[[DOUBLE]])
+  co_return;
+}
+
 struct promise_delete_tag {};
 
 template<>
Index: lib/Sema/SemaCoroutine.cpp
===================================================================
--- lib/Sema/SemaCoroutine.cpp
+++ lib/Sema/SemaCoroutine.cpp
@@ -1050,18 +1050,54 @@
 
   const bool RequiresNoThrowAlloc = ReturnStmtOnAllocFailure != nullptr;
 
-  // FIXME: Add support for stateful allocators.
+  // [dcl.fct.def.coroutine]/7
+  // Look up allocation functions using a parameter list composed of the
+  // requested size of the coroutine state being allocated, followed by
+  // the coroutine function's arguments. If a matching allocation function
+  // exists, use it. Otherwise, use an allocation function that just takes
+  // the requested size.
 
   FunctionDecl *OperatorNew = nullptr;
   FunctionDecl *OperatorDelete = nullptr;
   FunctionDecl *UnusedResult = nullptr;
   bool PassAlignment = false;
   SmallVector<Expr *, 1> PlacementArgs;
 
+  // [dcl.fct.def.coroutine]/7
+  // "The allocation function’s name is looked up in the scope of P.
+  // [...] If the lookup finds an allocation function in the scope of P,
+  // overload resolution is performed on a function call created by assembling
+  // an argument list."
+  for (auto *PD : FD.parameters()) {
+    if (PD->getType()->isDependentType())
+      continue;
+
+    // Build a reference to the parameter.
+    auto PDLoc = PD->getLocation();
+    ExprResult PDRefExpr =
+        S.BuildDeclRefExpr(PD, PD->getOriginalType().getNonReferenceType(),
+                           ExprValueKind::VK_LValue, PDLoc);
+    if (PDRefExpr.isInvalid())
+      return false;
+
+    PlacementArgs.push_back(PDRefExpr.get());
+  }
   S.FindAllocationFunctions(Loc, SourceRange(),
                             /*UseGlobal*/ false, PromiseType,
                             /*isArray*/ false, PassAlignment, PlacementArgs,
-                            OperatorNew, UnusedResult);
+                            OperatorNew, UnusedResult, /*Diagnose*/ false);
+
+  // [dcl.fct.def.coroutine]/7
+  // "If no matching function is found, overload resolution is performed again
+  // on a function call created by passing just the amount of space required as
+  // an argument of type std::size_t."
+  if (!OperatorNew && !PlacementArgs.empty()) {
+    PlacementArgs.clear();
+    S.FindAllocationFunctions(Loc, SourceRange(),
+                              /*UseGlobal*/ false, PromiseType,
+                              /*isArray*/ false, PassAlignment,
+                              PlacementArgs, OperatorNew, UnusedResult);
+  }
 
   bool IsGlobalOverload =
       OperatorNew && !isa<CXXRecordDecl>(OperatorNew->getDeclContext());
@@ -1080,7 +1116,8 @@
                               OperatorNew, UnusedResult);
   }
 
-  assert(OperatorNew && "expected definition of operator new to be found");
+  if (!OperatorNew)
+    return false;
 
   if (RequiresNoThrowAlloc) {
     const auto *FT = OperatorNew->getType()->getAs<FunctionProtoType>();
Index: lib/CodeGen/CGCoroutine.cpp
===================================================================
--- lib/CodeGen/CGCoroutine.cpp
+++ lib/CodeGen/CGCoroutine.cpp
@@ -523,7 +523,59 @@
   Builder.CreateCondBr(CoroAlloc, AllocBB, InitBB);
 
   EmitBlock(AllocBB);
-  auto *AllocateCall = EmitScalarExpr(S.getAllocate());
+  // Emit the call to the coroutine frame allocation function.
+  auto *AllocateCall = cast<llvm::CallInst>(EmitScalarExpr(S.getAllocate()));
+
+  // The backend coroutine split transform will move stores and loads for the
+  // coroutine function's arguments down past the first suspend point.
+  //
+  // However, in the case that the coroutine frame is, as per
+  // [dcl.fct.def.coroutine]/7, being allocated with an allocation function
+  // matching the coroutine function's arguments, we need to ensure that the
+  // allocation function is passed arguments that have values stored in them.
+  //
+  // Here, we generate instructions to store the coroutine function's arguments
+  // separately, and then pass them into the allocation function. First, we
+  // search for the coroutine function argument allocas that correspond to the
+  // arguments passed into the allocation function.
+  for (unsigned OpIdx = 0, OpEnd = AllocateCall->getNumArgOperands();
+       OpIdx != OpEnd; ++OpIdx) {
+    if (auto *AllocateOp =
+            dyn_cast<llvm::Instruction>(AllocateCall->getArgOperand(OpIdx))) {
+      for (auto &AllocateOpOp : AllocateOp->operands()) {
+        if (auto *Alloca = dyn_cast<llvm::AllocaInst>(AllocateOpOp)) {
+          // We've found the alloca instruction. Now we search for the store
+          // instruction that stores the coroutine function argument into that
+          // alloca's address.
+          for (auto I = Alloca->user_begin(), E = Alloca->user_end(); I != E;
+               ++I) {
+            if (auto *Store = dyn_cast<llvm::StoreInst>(*I)) {
+              // Now we generate an alloca, store, and a load, to replace the
+              // allocation function call instruction operands.
+              llvm::BasicBlock::iterator InsertPt = Builder.GetInsertPoint();
+              Builder.SetInsertPoint(AllocateCall);
+
+              llvm::AllocaInst *NewAlloca = Builder.CreateAlloca(
+                  Alloca->getAllocatedType(), Alloca->getArraySize(),
+                  "coro.allocate." + Alloca->getName());
+              NewAlloca->setAlignment(Alloca->getAlignment());
+
+              Address NewAllocaAddress = {
+                  NewAlloca,
+                  CharUnits::fromQuantity(NewAlloca->getAlignment())};
+              Builder.CreateStore(Store->getOperand(0), NewAllocaAddress);
+
+              llvm::LoadInst *NewLoad = Builder.CreateLoad(NewAllocaAddress);
+              AllocateCall->setOperand(OpIdx, NewLoad);
+
+              Builder.SetInsertPoint(AllocBB, InsertPt);
+            }
+          }
+        }
+      }
+    }
+  }
+
   auto *AllocOrInvokeContBB = Builder.GetInsertBlock();
 
   // Handle allocation failure if 'ReturnStmtOnAllocFailure' was provided.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to