bruno updated this revision to Diff 505588.
bruno marked 2 inline comments as done.
bruno added a comment.

Update after reviewer comments


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D145641/new/

https://reviews.llvm.org/D145641

Files:
  clang/include/clang/AST/StmtCXX.h
  clang/lib/CodeGen/CGCoroutine.cpp
  clang/lib/Sema/SemaCoroutine.cpp
  clang/lib/Sema/TreeTransform.h
  clang/test/SemaCXX/coroutine-no-move-ctor.cpp

Index: clang/test/SemaCXX/coroutine-no-move-ctor.cpp
===================================================================
--- clang/test/SemaCXX/coroutine-no-move-ctor.cpp
+++ clang/test/SemaCXX/coroutine-no-move-ctor.cpp
@@ -15,13 +15,10 @@
   };
   using promise_type = invoker_promise;
   invoker() {}
-  // TODO: implement RVO for get_return_object type matching
-  // function return type.
-  //
-  // invoker(const invoker &) = delete;
-  // invoker &operator=(const invoker &) = delete;
-  // invoker(invoker &&) = delete;
-  // invoker &operator=(invoker &&) = delete;
+  invoker(const invoker &) = delete;
+  invoker &operator=(const invoker &) = delete;
+  invoker(invoker &&) = delete;
+  invoker &operator=(invoker &&) = delete;
 };
 
 invoker f() {
Index: clang/lib/Sema/TreeTransform.h
===================================================================
--- clang/lib/Sema/TreeTransform.h
+++ clang/lib/Sema/TreeTransform.h
@@ -8066,11 +8066,12 @@
       return StmtError();
     Builder.Deallocate = DeallocRes.get();
 
-    assert(S->getResultDecl() && "ResultDecl must already be built");
-    StmtResult ResultDecl = getDerived().TransformStmt(S->getResultDecl());
-    if (ResultDecl.isInvalid())
-      return StmtError();
-    Builder.ResultDecl = ResultDecl.get();
+    if (auto *ResultDecl = S->getResultDecl()) {
+      StmtResult Res = getDerived().TransformStmt(ResultDecl);
+      if (Res.isInvalid())
+        return StmtError();
+      Builder.ResultDecl = Res.get();
+    }
 
     if (auto *ReturnStmt = S->getReturnStmt()) {
       StmtResult Res = getDerived().TransformStmt(ReturnStmt);
Index: clang/lib/Sema/SemaCoroutine.cpp
===================================================================
--- clang/lib/Sema/SemaCoroutine.cpp
+++ clang/lib/Sema/SemaCoroutine.cpp
@@ -1730,13 +1730,22 @@
   assert(!FnRetType->isDependentType() &&
          "get_return_object type must no longer be dependent");
 
+  // The call to get_return_object is sequenced before the call to
+  // initial_suspend and is invoked at most once, but there are caveats
+  // regarding whether the prvalue result object may be initialized
+  // directly/eagerly or delayed, depending on the types involved.
+  //
+  // More info at https://github.com/cplusplus/papers/issues/1414
+  bool GroMatchesRetType = S.getASTContext().hasSameType(GroType, FnRetType);
+
   if (FnRetType->isVoidType()) {
     ExprResult Res =
         S.ActOnFinishFullExpr(this->ReturnValue, Loc, /*DiscardedValue*/ false);
     if (Res.isInvalid())
       return false;
 
-    this->ResultDecl = Res.get();
+    if (!GroMatchesRetType)
+      this->ResultDecl = Res.get();
     return true;
   }
 
@@ -1749,53 +1758,59 @@
     return false;
   }
 
-  // StmtResult ReturnStmt = S.BuildReturnStmt(Loc, ReturnValue);
-  auto *GroDecl = VarDecl::Create(
-      S.Context, &FD, FD.getLocation(), FD.getLocation(),
-      &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
-      S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
-  GroDecl->setImplicit();
-
-  S.CheckVariableDeclarationType(GroDecl);
-  if (GroDecl->isInvalidDecl())
-    return false;
+  StmtResult ReturnStmt;
+  clang::VarDecl *GroDecl = nullptr;
+  if (GroMatchesRetType) {
+    ReturnStmt = S.BuildReturnStmt(Loc, ReturnValue);
+  } else {
+    GroDecl = VarDecl::Create(
+        S.Context, &FD, FD.getLocation(), FD.getLocation(),
+        &S.PP.getIdentifierTable().get("__coro_gro"), GroType,
+        S.Context.getTrivialTypeSourceInfo(GroType, Loc), SC_None);
+    GroDecl->setImplicit();
+
+    S.CheckVariableDeclarationType(GroDecl);
+    if (GroDecl->isInvalidDecl())
+      return false;
 
-  InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
-  ExprResult Res =
-      S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue);
-  if (Res.isInvalid())
-    return false;
+    InitializedEntity Entity = InitializedEntity::InitializeVariable(GroDecl);
+    ExprResult Res =
+        S.PerformCopyInitialization(Entity, SourceLocation(), ReturnValue);
+    if (Res.isInvalid())
+      return false;
 
-  Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
-  if (Res.isInvalid())
-    return false;
+    Res = S.ActOnFinishFullExpr(Res.get(), /*DiscardedValue*/ false);
+    if (Res.isInvalid())
+      return false;
 
-  S.AddInitializerToDecl(GroDecl, Res.get(),
-                         /*DirectInit=*/false);
+    S.AddInitializerToDecl(GroDecl, Res.get(),
+                           /*DirectInit=*/false);
 
-  S.FinalizeDeclaration(GroDecl);
+    S.FinalizeDeclaration(GroDecl);
 
-  // Form a declaration statement for the return declaration, so that AST
-  // visitors can more easily find it.
-  StmtResult GroDeclStmt =
-      S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
-  if (GroDeclStmt.isInvalid())
-    return false;
+    // Form a declaration statement for the return declaration, so that AST
+    // visitors can more easily find it.
+    StmtResult GroDeclStmt =
+        S.ActOnDeclStmt(S.ConvertDeclToDeclGroup(GroDecl), Loc, Loc);
+    if (GroDeclStmt.isInvalid())
+      return false;
 
-  this->ResultDecl = GroDeclStmt.get();
+    this->ResultDecl = GroDeclStmt.get();
 
-  ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
-  if (declRef.isInvalid())
-    return false;
+    ExprResult declRef = S.BuildDeclRefExpr(GroDecl, GroType, VK_LValue, Loc);
+    if (declRef.isInvalid())
+      return false;
 
-  StmtResult ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
+    ReturnStmt = S.BuildReturnStmt(Loc, declRef.get());
+  }
 
   if (ReturnStmt.isInvalid()) {
     noteMemberDeclaredHere(S, ReturnValue, Fn);
     return false;
   }
 
-  if (cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
+  if (!GroMatchesRetType &&
+      cast<clang::ReturnStmt>(ReturnStmt.get())->getNRVOCandidate() == GroDecl)
     GroDecl->setNRVOVariable(true);
 
   this->ReturnStmt = ReturnStmt.get();
Index: clang/lib/CodeGen/CGCoroutine.cpp
===================================================================
--- clang/lib/CodeGen/CGCoroutine.cpp
+++ clang/lib/CodeGen/CGCoroutine.cpp
@@ -472,13 +472,44 @@
   CodeGenFunction &CGF;
   CGBuilderTy &Builder;
   const CoroutineBodyStmt &S;
+  // When true, performs RVO for the return object.
+  bool DirectEmit = false;
 
   Address GroActiveFlag;
   CodeGenFunction::AutoVarEmission GroEmission;
 
   GetReturnObjectManager(CodeGenFunction &CGF, const CoroutineBodyStmt &S)
       : CGF(CGF), Builder(CGF.Builder), S(S), GroActiveFlag(Address::invalid()),
-        GroEmission(CodeGenFunction::AutoVarEmission::invalid()) {}
+        GroEmission(CodeGenFunction::AutoVarEmission::invalid()) {
+    // The call to get_return_object is sequenced before the call to
+    // initial_suspend and is invoked at most once, but there are caveats
+    // regarding whether the prvalue result object may be initialized
+    // directly/eagerly or delayed, depending on the types involved.
+    //
+    // More info at https://github.com/cplusplus/papers/issues/1414
+    //
+    // The general cases:
+    // 1. Same type of get_return_object and coroutine return type (direct
+    // emission):
+    //  - Constructed in the return slot.
+    // 2. Different types (delayed emission):
+    //  - A temporary object is constructed prior to initial_suspend and is
+    //  initialized with a call to get_return_object().
+    //  - When the coroutine returns to the caller and must construct its
+    //  return value, that value is initialized with an expiring value of
+    //  the temporary obtained above.
+    //
+    // Use direct emission for void-returning coroutines or void GROs.
+    DirectEmit = [&]() {
+      auto *RVI = S.getReturnValueInit();
+      if (!RVI || CGF.FnRetTy->isVoidType())
+        return true;
+      auto GroType = RVI->getType();
+      if (GroType->isVoidType())
+        return true;
+      return CGF.getContext().hasSameType(GroType, CGF.FnRetTy);
+    }();
+  }
 
   // The gro variable has to outlive coroutine frame and coroutine promise, but,
   // it can only be initialized after coroutine promise was created, thus, we
@@ -486,7 +517,10 @@
   // cleanups. Later when coroutine promise is available we initialize the gro
   // and sets the flag that the cleanup is now active.
   void EmitGroAlloca() {
-    auto *GroDeclStmt = dyn_cast<DeclStmt>(S.getResultDecl());
+    if (DirectEmit)
+      return;
+
+    auto *GroDeclStmt = dyn_cast_or_null<DeclStmt>(S.getResultDecl());
     if (!GroDeclStmt) {
       // If get_return_object returns void, no need to do an alloca.
       return;
@@ -519,6 +553,27 @@
   }
 
   void EmitGroInit() {
+    if (DirectEmit) {
+      // ReturnValue should be valid as long as the coroutine's return type
+      // is not void. The assertion could help us to reduce the check later.
+      assert(CGF.ReturnValue.isValid() == (bool)S.getReturnStmt());
+      // Now we have the promise, initialize the GRO.
+      // We need to emit `get_return_object` first. According to
+      // [dcl.fct.def.coroutine]p7:
+      // The call to get_return_object is sequenced before the call to
+      // initial_suspend and is invoked at most once.
+      //
+      // So we can't emit the return value when we emit the return statement;
+      // otherwise the call to get_return_object wouldn't come before
+      // initial_suspend.
+      if (CGF.ReturnValue.isValid()) {
+        CGF.EmitAnyExprToMem(S.getReturnValue(), CGF.ReturnValue,
+                             S.getReturnValue()->getType().getQualifiers(),
+                             /*IsInit*/ true);
+      }
+      return;
+    }
+
     if (!GroActiveFlag.isValid()) {
       // No Gro variable was allocated. Simply emit the call to
       // get_return_object.
@@ -598,10 +653,6 @@
       CGM.getIntrinsic(llvm::Intrinsic::coro_begin), {CoroId, Phi});
   CurCoro.Data->CoroBegin = CoroBegin;
 
-  // We need to emit `get_­return_­object` first. According to:
-  // [dcl.fct.def.coroutine]p7
-  // The call to get_­return_­object is sequenced before the call to
-  // initial_­suspend and is invoked at most once.
   GetReturnObjectManager GroManager(*this, S);
   GroManager.EmitGroAlloca();
 
@@ -706,8 +757,13 @@
   llvm::Function *CoroEnd = CGM.getIntrinsic(llvm::Intrinsic::coro_end);
   Builder.CreateCall(CoroEnd, {NullPtr, Builder.getFalse()});
 
-  if (Stmt *Ret = S.getReturnStmt())
+  if (Stmt *Ret = S.getReturnStmt()) {
+    // The return value was already emitted above, so it must not be
+    // emitted again here.
+    if (GroManager.DirectEmit)
+      cast<ReturnStmt>(Ret)->setRetValue(nullptr);
     EmitStmt(Ret);
+  }
 
   // LLVM require the frontend to mark the coroutine.
   CurFn->setPresplitCoroutine();
Index: clang/include/clang/AST/StmtCXX.h
===================================================================
--- clang/include/clang/AST/StmtCXX.h
+++ clang/include/clang/AST/StmtCXX.h
@@ -411,9 +411,8 @@
     return cast<Expr>(getStoredStmts()[SubStmt::ReturnValue]);
   }
   Expr *getReturnValue() const {
-    assert(getReturnStmt());
-    auto *RS = cast<clang::ReturnStmt>(getReturnStmt());
-    return RS->getRetValue();
+    auto *RS = dyn_cast_or_null<clang::ReturnStmt>(getReturnStmt());
+    return RS ? RS->getRetValue() : nullptr;
   }
   Stmt *getReturnStmt() const { return getStoredStmts()[SubStmt::ReturnStmt]; }
   Stmt *getReturnStmtOnAllocFailure() const {
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to