ilya-biryukov updated this revision to Diff 183303.
ilya-biryukov added a comment.

- Improve a comment


CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D56723/new/

https://reviews.llvm.org/D56723

Files:
  clang/include/clang/Parse/Parser.h
  clang/include/clang/Sema/CodeCompleteConsumer.h
  clang/include/clang/Sema/Sema.h
  clang/lib/Parse/ParseDecl.cpp
  clang/lib/Parse/ParseExpr.cpp
  clang/lib/Parse/ParseExprCXX.cpp
  clang/lib/Parse/ParseStmt.cpp
  clang/lib/Sema/SemaCodeComplete.cpp
  clang/unittests/Sema/CodeCompleteTest.cpp

Index: clang/unittests/Sema/CodeCompleteTest.cpp
===================================================================
--- clang/unittests/Sema/CodeCompleteTest.cpp
+++ clang/unittests/Sema/CodeCompleteTest.cpp
@@ -339,4 +339,103 @@
   EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
 }
 
+TEST(PreferredTypeTest, Members) {
+  // Every link of a member-access chain inherits the chain's expected type.
+  EXPECT_THAT(collectPreferredTypes(R"cpp(
+    struct vector {
+      int *begin();
+      vector clone();
+    };
+
+    void test(int *a) {
+      a = ^vector().^clone().^begin();
+    }
+  )cpp"), Each("int *"));
+}
+
+TEST(PreferredTypeTest, Conditions) {
+  // Conditions of if/while/for statements are expected to be bool.
+  EXPECT_THAT(collectPreferredTypes(R"cpp(
+    struct vector {
+      bool empty();
+    };
+
+    void test() {
+      if (^vector().^empty()) {}
+      while (^vector().^empty()) {}
+      for (; ^vector().^empty();) {}
+    }
+  )cpp"), Each("_Bool"));
+}
+
+TEST(PreferredTypeTest, InitAndAssignment) {
+  // Initializers and assignment right-hand sides expect the variable's type.
+  EXPECT_THAT(collectPreferredTypes(R"cpp(
+    struct vector {
+      int* begin();
+    };
+
+    void test() {
+      const int* x = ^vector().^begin();
+      x = ^vector().^begin();
+
+      if (const int* y = ^vector().^begin()) {}
+    }
+  )cpp"), Each("const int *"));
+}
+
+TEST(PreferredTypeTest, UnaryExprs) {
+  StringRef Code = R"cpp(
+    void test(long long a) {
+      a = +^a;
+      a = -^a;
+      a = ++^a;
+      a = --^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("long long"));
+  // '!' always expects a bool operand, regardless of the context.
+  Code = R"cpp(
+    void test(int a, int *ptr) {
+      !^a;
+      !^ptr;
+      !!!^a;
+
+      a = !^a;
+      a = !^ptr;
+      a = !!!^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("_Bool"));
+  // '&' expects the pointee of the pointer type expected for the result.
+  Code = R"cpp(
+    void test(int a) {
+      const int* x = &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("const int"));
+  // '*' expects a pointer to the non-reference type expected for the result.
+  Code = R"cpp(
+    void test(int *a) {
+      int x = *^a;
+      int &r = *^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("int *"));
+  // Without an outer expectation, '*' and '&' have no preferred type.
+  Code = R"cpp(
+    void test(int a) {
+      *^a;
+      &^a;
+    }
+  )cpp";
+  EXPECT_THAT(collectPreferredTypes(Code), Each("NULL TYPE"));
+}
+
+TEST(PreferredTypeTest, ParenExpr) {
+  // Parentheses don't change the expected type, however deeply nested.
+  EXPECT_THAT(collectPreferredTypes(R"cpp(
+    const int *i = ^(^(^(^10)));
+  )cpp"), Each("const int *"));
+}
 } // namespace
Index: clang/lib/Sema/SemaCodeComplete.cpp
===================================================================
--- clang/lib/Sema/SemaCodeComplete.cpp
+++ clang/lib/Sema/SemaCodeComplete.cpp
@@ -347,6 +347,210 @@
 };
 } // namespace
 
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterReturn(Sema &S, SourceLocation Tok) {
+  // The guard snapshots the current state and restores it when destroyed.
+  RestoreRAII Guard(*this);
+  DeclContext *Ctx = S.CurContext;
+  if (isa<BlockDecl>(Ctx)) {
+    // Block literals: the return type lives in the block scope info, if any.
+    sema::BlockScopeInfo *BSI = S.getCurBlock();
+    if (!BSI)
+      return Guard;
+    Type = BSI->ReturnType;
+  } else if (const auto *FD = dyn_cast<FunctionDecl>(Ctx)) {
+    Type = FD->getReturnType();
+  } else if (const auto *MD = dyn_cast<ObjCMethodDecl>(Ctx)) {
+    Type = MD->getReturnType();
+  } else {
+    // Not a function-like context: leave the previous expectation untouched.
+    return Guard;
+  }
+  ExpectedLoc = Tok;
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterVariableInit(SourceLocation Tok, Decl *D) {
+  RestoreRAII Guard(*this);
+  ExpectedLoc = Tok;
+  const auto *Var = llvm::dyn_cast_or_null<ValueDecl>(D);
+  Type = Var ? Var->getType() : QualType(); // No type for null/non-values.
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterParenExpr(SourceLocation Tok,
+                                     SourceLocation LParLoc) {
+  RestoreRAII Guard(*this);
+  // Parentheses don't change the expected type; forward it from '(' inward.
+  if (LParLoc == ExpectedLoc)
+    ExpectedLoc = Tok;
+  return Guard;
+}
+
+/// Guess the type the right operand of a binary operator \p Op should have,
+/// given the parsed left operand \p LHS. Returns a null QualType if unknown.
+static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
+                                            tok::TokenKind Op) {
+  if (!LHS)
+    return QualType();
+
+  QualType LHSType = LHS->getType();
+  if (LHSType->isPointerType()) {
+    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
+      return S.getASTContext().getPointerDiffType();
+    // Pointer difference is more common than subtracting an int from a pointer.
+    if (Op == tok::minus)
+      return LHSType;
+  }
+
+  switch (Op) {
+  // Prefer the type of the left operand for all of these.
+  // Arithmetic operations.
+  case tok::plus:
+  case tok::plusequal:
+  case tok::minus:
+  case tok::minusequal:
+  case tok::percent:
+  case tok::percentequal:
+  case tok::slash:
+  case tok::slashequal:
+  case tok::star:
+  case tok::starequal:
+  // Assignment.
+  case tok::equal:
+  // Comparison operators.
+  case tok::equalequal:
+  case tok::exclaimequal:
+  case tok::less:
+  case tok::lessequal:
+  case tok::greater:
+  case tok::greaterequal:
+  case tok::spaceship:
+    return LHSType;
+  // Binary shifts are often overloaded, so only guess 'int' for integral LHS.
+  case tok::greatergreater:
+  case tok::greatergreaterequal:
+  case tok::lessless:
+  case tok::lesslessequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return S.getASTContext().IntTy;
+    return QualType();
+  // Logical operators, assume we want bool.
+  case tok::ampamp:
+  case tok::pipepipe:
+  case tok::caretcaret:
+    return S.getASTContext().BoolTy;
+  // Operators often used for bit manipulation are typically used with the type
+  // of the left argument.
+  case tok::pipe:
+  case tok::pipeequal:
+  case tok::caret:
+  case tok::caretequal:
+  case tok::amp:
+  case tok::ampequal:
+    if (LHSType->isIntegralOrEnumerationType())
+      return LHSType;
+    return QualType();
+  // The comma operator's RHS is unrelated to its LHS. For '.*' and '->*' the
+  // RHS should be a pointer to a member of the LHS type, but we can't say which.
+  case tok::comma:
+  case tok::periodstar:
+  case tok::arrowstar:
+    return QualType();
+  default:
+    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
+    // assert(false && "unhandled binary op");
+    return QualType();
+  }
+}
+
+/// Get the preferred type for the operand of the unary operator \p Op, given
+/// \p ContextType, the preferred type of the whole unary expression.
+static QualType getPreferredTypeOfUnaryArg(Sema &S, QualType ContextType,
+                                           tok::TokenKind Op) {
+  switch (Op) {
+  case tok::exclaim:
+    return S.getASTContext().BoolTy;
+  case tok::amp:
+    if (!ContextType.isNull() && ContextType->isPointerType())
+      return ContextType->getPointeeType();
+    return QualType();
+  case tok::star:
+    if (ContextType.isNull())
+      return QualType();
+    return S.getASTContext().getPointerType(ContextType.getNonReferenceType());
+  case tok::plus:
+  case tok::minus:
+  case tok::tilde:
+  case tok::minusminus:
+  case tok::plusplus:
+    if (ContextType.isNull())
+      return S.getASTContext().IntTy;
+    // Leave as is, these operators typically return the same type.
+    return ContextType;
+  case tok::kw___real:
+  case tok::kw___imag:
+    return QualType();
+  default:
+    assert(false && "unhandled unary op");
+    return QualType();
+  }
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterBinary(Sema &S, SourceLocation Tok, Expr *LHS,
+                                  tok::TokenKind Op) {
+  RestoreRAII Guard(*this);
+  ExpectedLoc = Tok; // The right-hand operand starts at Tok.
+  Type = getPreferredTypeOfBinaryRHS(S, LHS, Op);
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterMemAccess(Sema &S, SourceLocation Tok, Expr *Base) {
+  RestoreRAII Guard(*this);
+  if (Base) {
+    Type = get(Base->getBeginLoc()); // Reuse the type expected at Base.
+    ExpectedLoc = Tok;
+  }
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterUnary(Sema &S, SourceLocation Tok,
+                                 tok::TokenKind OpKind, SourceLocation OpLoc) {
+  RestoreRAII Guard(*this);
+  Type = getPreferredTypeOfUnaryArg(S, /*ContextType=*/get(OpLoc), OpKind);
+  ExpectedLoc = Tok; // Tok is where the operand starts.
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterSubscript(Sema &S, SourceLocation Tok, Expr *LHS) {
+  RestoreRAII Guard(*this);
+  ExpectedLoc = Tok;
+  Type = S.getASTContext().IntTy; // Guess 'int' indices; LHS is not consulted.
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterTypeCast(SourceLocation Tok, QualType CastType) {
+  RestoreRAII Guard(*this);
+  ExpectedLoc = Tok; // The cast operand starts at Tok.
+  Type = CastType.isNull() ? QualType() : CastType.getCanonicalType();
+  return Guard;
+}
+
+PreferredTypeBuilder::RestoreRAII
+PreferredTypeBuilder::enterCondition(Sema &S, SourceLocation Tok) {
+  RestoreRAII Guard(*this);
+  ExpectedLoc = Tok;
+  Type = S.getASTContext().BoolTy; // Conditions are expected to be 'bool'.
+  return Guard;
+}
+
 class ResultBuilder::ShadowMapEntry::iterator {
   llvm::PointerUnion<const NamedDecl *, const DeclIndexPair *> DeclOrIterator;
   unsigned SingleDeclIndex;
@@ -3856,13 +4060,15 @@
 }
 
 struct Sema::CodeCompleteExpressionData {
-  CodeCompleteExpressionData(QualType PreferredType = QualType())
+  CodeCompleteExpressionData(QualType PreferredType = QualType(),
+                             bool IsParenthesized = false)
       : PreferredType(PreferredType), IntegralConstantExpression(false),
-        ObjCCollection(false) {}
+        ObjCCollection(false), IsParenthesized(IsParenthesized) {}
 
   QualType PreferredType;
   bool IntegralConstantExpression;
   bool ObjCCollection;
+  bool IsParenthesized;
   SmallVector<Decl *, 4> IgnoreDecls;
 };
 
@@ -3873,13 +4079,18 @@
   ResultBuilder Results(
       *this, CodeCompleter->getAllocator(),
       CodeCompleter->getCodeCompletionTUInfo(),
-      CodeCompletionContext(CodeCompletionContext::CCC_Expression,
-                            Data.PreferredType));
+      CodeCompletionContext(
+          Data.IsParenthesized
+              ? CodeCompletionContext::CCC_ParenthesizedExpression
+              : CodeCompletionContext::CCC_Expression,
+          Data.PreferredType));
+  auto PCC =
+      Data.IsParenthesized ? PCC_ParenthesizedExpression : PCC_Expression;
   if (Data.ObjCCollection)
     Results.setFilter(&ResultBuilder::IsObjCCollection);
   else if (Data.IntegralConstantExpression)
     Results.setFilter(&ResultBuilder::IsIntegralConstantValue);
-  else if (WantTypesInContext(PCC_Expression, getLangOpts()))
+  else if (WantTypesInContext(PCC, getLangOpts()))
     Results.setFilter(&ResultBuilder::IsOrdinaryName);
   else
     Results.setFilter(&ResultBuilder::IsOrdinaryNonTypeName);
@@ -3897,7 +4108,7 @@
                      CodeCompleter->loadExternal());
 
   Results.EnterNewScope();
-  AddOrdinaryNameResults(PCC_Expression, S, *this, Results);
+  AddOrdinaryNameResults(PCC, S, *this, Results);
   Results.ExitScope();
 
   bool PreferredTypeIsPointer = false;
@@ -3917,13 +4128,16 @@
                             Results.data(), Results.size());
 }
 
-void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType) {
-  return CodeCompleteExpression(S, CodeCompleteExpressionData(PreferredType));
+void Sema::CodeCompleteExpression(Scope *S, QualType PreferredType,
+                                  bool IsParenthesized) {
+  return CodeCompleteExpression(
+      S, CodeCompleteExpressionData(PreferredType, IsParenthesized));
 }
 
-void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E) {
+void Sema::CodeCompletePostfixExpression(Scope *S, ExprResult E,
+                                         QualType PreferredType) {
   if (E.isInvalid())
-    CodeCompleteOrdinaryName(S, PCC_RecoveryInFunction);
+    CodeCompleteExpression(S, PreferredType);
   else if (getLangOpts().ObjC)
     CodeCompleteObjCInstanceMessage(S, E.get(), None, false);
 }
@@ -4211,7 +4425,8 @@
 void Sema::CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base,
                                            Expr *OtherOpBase,
                                            SourceLocation OpLoc, bool IsArrow,
-                                           bool IsBaseExprStatement) {
+                                           bool IsBaseExprStatement,
+                                           QualType PreferredType) {
   if (!Base || !CodeCompleter)
     return;
 
@@ -4239,6 +4454,7 @@
   }
 
   CodeCompletionContext CCContext(contextKind, ConvertedBaseType);
+  CCContext.setPreferredType(PreferredType);
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(), CCContext,
                         &ResultBuilder::IsMember);
@@ -4800,22 +5016,6 @@
   CodeCompleteExpression(S, Data);
 }
 
-void Sema::CodeCompleteReturn(Scope *S) {
-  QualType ResultType;
-  if (isa<BlockDecl>(CurContext)) {
-    if (BlockScopeInfo *BSI = getCurBlock())
-      ResultType = BSI->ReturnType;
-  } else if (const auto *Function = dyn_cast<FunctionDecl>(CurContext))
-    ResultType = Function->getReturnType();
-  else if (const auto *Method = dyn_cast<ObjCMethodDecl>(CurContext))
-    ResultType = Method->getReturnType();
-
-  if (ResultType.isNull())
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-  else
-    CodeCompleteExpression(S, ResultType);
-}
-
 void Sema::CodeCompleteAfterIf(Scope *S) {
   ResultBuilder Results(*this, CodeCompleter->getAllocator(),
                         CodeCompleter->getCodeCompletionTUInfo(),
@@ -4877,91 +5077,6 @@
                             Results.data(), Results.size());
 }
 
-static QualType getPreferredTypeOfBinaryRHS(Sema &S, Expr *LHS,
-                                            tok::TokenKind Op) {
-  if (!LHS)
-    return QualType();
-
-  QualType LHSType = LHS->getType();
-  if (LHSType->isPointerType()) {
-    if (Op == tok::plus || Op == tok::plusequal || Op == tok::minusequal)
-      return S.getASTContext().getPointerDiffType();
-    // Pointer difference is more common than subtracting an int from a pointer.
-    if (Op == tok::minus)
-      return LHSType;
-  }
-
-  switch (Op) {
-  // No way to infer the type of RHS from LHS.
-  case tok::comma:
-    return QualType();
-  // Prefer the type of the left operand for all of these.
-  // Arithmetic operations.
-  case tok::plus:
-  case tok::plusequal:
-  case tok::minus:
-  case tok::minusequal:
-  case tok::percent:
-  case tok::percentequal:
-  case tok::slash:
-  case tok::slashequal:
-  case tok::star:
-  case tok::starequal:
-  // Assignment.
-  case tok::equal:
-  // Comparison operators.
-  case tok::equalequal:
-  case tok::exclaimequal:
-  case tok::less:
-  case tok::lessequal:
-  case tok::greater:
-  case tok::greaterequal:
-  case tok::spaceship:
-    return LHS->getType();
-  // Binary shifts are often overloaded, so don't try to guess those.
-  case tok::greatergreater:
-  case tok::greatergreaterequal:
-  case tok::lessless:
-  case tok::lesslessequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return S.getASTContext().IntTy;
-    return QualType();
-  // Logical operators, assume we want bool.
-  case tok::ampamp:
-  case tok::pipepipe:
-  case tok::caretcaret:
-    return S.getASTContext().BoolTy;
-  // Operators often used for bit manipulation are typically used with the type
-  // of the left argument.
-  case tok::pipe:
-  case tok::pipeequal:
-  case tok::caret:
-  case tok::caretequal:
-  case tok::amp:
-  case tok::ampequal:
-    if (LHSType->isIntegralOrEnumerationType())
-      return LHSType;
-    return QualType();
-  // RHS should be a pointer to a member of the 'LHS' type, but we can't give
-  // any particular type here.
-  case tok::periodstar:
-  case tok::arrowstar:
-    return QualType();
-  default:
-    // FIXME(ibiryukov): handle the missing op, re-add the assertion.
-    // assert(false && "unhandled binary op");
-    return QualType();
-  }
-}
-
-void Sema::CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op) {
-  auto PreferredType = getPreferredTypeOfBinaryRHS(*this, LHS, Op);
-  if (!PreferredType.isNull())
-    CodeCompleteExpression(S, PreferredType);
-  else
-    CodeCompleteOrdinaryName(S, PCC_Expression);
-}
-
 void Sema::CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                    bool EnteringContext, QualType BaseType) {
   if (SS.isEmpty() || !CodeCompleter)
Index: clang/lib/Parse/ParseStmt.cpp
===================================================================
--- clang/lib/Parse/ParseStmt.cpp
+++ clang/lib/Parse/ParseStmt.cpp
@@ -1970,9 +1970,13 @@
 
   ExprResult R;
   if (Tok.isNot(tok::semi)) {
+    llvm::Optional<PreferredTypeBuilder::RestoreRAII> TypeRAII;
+    if (!IsCoreturn)
+      TypeRAII.emplace(PreferredType.enterReturn(Actions, Tok.getLocation()));
     // FIXME: Code completion for co_return.
     if (Tok.is(tok::code_completion) && !IsCoreturn) {
-      Actions.CodeCompleteReturn(getCurScope());
+      Actions.CodeCompleteExpression(getCurScope(),
+                                     PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       return StmtError();
     }
Index: clang/lib/Parse/ParseExprCXX.cpp
===================================================================
--- clang/lib/Parse/ParseExprCXX.cpp
+++ clang/lib/Parse/ParseExprCXX.cpp
@@ -1672,6 +1672,9 @@
     BalancedDelimiterTracker T(*this, tok::l_paren);
     T.consumeOpen();
 
+    auto TypeRAII =
+        PreferredType.enterTypeCast(Tok.getLocation(), TypeRep.get());
+
     ExprVector Exprs;
     CommaLocsTy CommaLocs;
 
@@ -1739,6 +1742,7 @@
                                                 Sema::ConditionKind CK,
                                                 ForRangeInfo *FRI) {
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
+  auto TypeRAII = PreferredType.enterCondition(Actions, Tok.getLocation());
 
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Condition);
@@ -1858,6 +1862,7 @@
          diag::warn_cxx98_compat_generalized_initializer_lists);
     InitExpr = ParseBraceInitializer();
   } else if (CopyInitialization) {
+    auto TypeRAII = PreferredType.enterVariableInit(Tok.getLocation(), DeclOut);
     InitExpr = ParseAssignmentExpression();
   } else if (Tok.is(tok::l_paren)) {
     // This was probably an attempt to initialize the variable.
Index: clang/lib/Parse/ParseExpr.cpp
===================================================================
--- clang/lib/Parse/ParseExpr.cpp
+++ clang/lib/Parse/ParseExpr.cpp
@@ -158,7 +158,8 @@
 /// Parse an expr that doesn't include (top-level) commas.
 ExprResult Parser::ParseAssignmentExpression(TypeCastState isTypeCast) {
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(),
+                                   PreferredType.get(Tok.getLocation()));
     cutOffParsing();
     return ExprError();
   }
@@ -392,15 +393,8 @@
       }
     }
 
-    // Code completion for the right-hand side of a binary expression goes
-    // through a special hook that takes the left-hand side into account.
-    if (Tok.is(tok::code_completion)) {
-      Actions.CodeCompleteBinaryRHS(getCurScope(), LHS.get(),
-                                    OpToken.getKind());
-      cutOffParsing();
-      return ExprError();
-    }
-
+    auto TypeRAII = PreferredType.enterBinary(Actions, Tok.getLocation(),
+                                              LHS.get(), OpToken.getKind());
     // Parse another leaf here for the RHS of the operator.
     // ParseCastExpression works here because all RHS expressions in C have it
     // as a prefix, at least. However, in C++, an assignment-expression could
@@ -1114,6 +1108,9 @@
     //     -- cast-expression
     Token SavedTok = Tok;
     ConsumeToken();
+
+    auto TypeRAII = PreferredType.enterUnary(
+        Actions, Tok.getLocation(), SavedTok.getKind(), SavedTok.getLocation());
     // One special case is implicitly handled here: if the preceding tokens are
     // an ambiguous cast expression, such as "(T())++", then we recurse to
     // determine whether the '++' is prefix or postfix.
@@ -1135,6 +1132,8 @@
   case tok::amp: {         // unary-expression: '&' cast-expression
     // Special treatment because of member pointers
     SourceLocation SavedLoc = ConsumeToken();
+    auto TypeRAII = PreferredType.enterUnary(Actions, Tok.getLocation(),
+                                             tok::amp, SavedLoc);
     Res = ParseCastExpression(false, true);
     if (!Res.isInvalid())
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
@@ -1149,6 +1148,8 @@
   case tok::kw___real:     // unary-expression: '__real' cast-expression [GNU]
   case tok::kw___imag: {   // unary-expression: '__imag' cast-expression [GNU]
     SourceLocation SavedLoc = ConsumeToken();
+    auto TypeRAII = PreferredType.enterUnary(Actions, Tok.getLocation(),
+                                             SavedKind, SavedLoc);
     Res = ParseCastExpression(false);
     if (!Res.isInvalid())
       Res = Actions.ActOnUnaryOp(getCurScope(), SavedLoc, SavedKind, Res.get());
@@ -1423,7 +1424,8 @@
     Res = ParseBlockLiteralExpression();
     break;
   case tok::code_completion: {
-    Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(getCurScope(),
+                                   PreferredType.get(Tok.getLocation()));
     cutOffParsing();
     return ExprError();
   }
@@ -1503,7 +1505,8 @@
       if (InMessageExpression)
         return LHS;
 
-      Actions.CodeCompletePostfixExpression(getCurScope(), LHS);
+      Actions.CodeCompletePostfixExpression(
+          getCurScope(), LHS, PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       return ExprError();
 
@@ -1545,6 +1548,8 @@
       Loc = T.getOpenLocation();
       ExprResult Idx, Length;
       SourceLocation ColonLoc;
+      auto TypeRAII =
+          PreferredType.enterSubscript(Actions, Tok.getLocation(), LHS.get());
       if (getLangOpts().CPlusPlus11 && Tok.is(tok::l_brace)) {
         Diag(Tok, diag::warn_cxx98_compat_generalized_initializer_lists);
         Idx = ParseBraceInitializer();
@@ -1726,6 +1731,9 @@
       bool MayBePseudoDestructor = false;
       Expr* OrigLHS = !LHS.isInvalid() ? LHS.get() : nullptr;
 
+      auto TypeRAII =
+          PreferredType.enterMemAccess(Actions, Tok.getLocation(), OrigLHS);
+
       if (getLangOpts().CPlusPlus && !LHS.isInvalid()) {
         Expr *Base = OrigLHS;
         const Type* BaseType = Base->getType().getTypePtrOrNull();
@@ -1772,7 +1780,8 @@
         // Code completion for a member access expression.
         Actions.CodeCompleteMemberReferenceExpr(
             getCurScope(), Base, CorrectedBase, OpLoc, OpKind == tok::arrow,
-            Base && ExprStatementTokLoc == Base->getBeginLoc());
+            Base && ExprStatementTokLoc == Base->getBeginLoc(),
+            PreferredType.get(Tok.getLocation()));
 
         cutOffParsing();
         return ExprError();
@@ -2326,14 +2335,16 @@
     return ExprError();
   SourceLocation OpenLoc = T.getOpenLocation();
 
+  auto TypeRAII = PreferredType.enterParenExpr(Tok.getLocation(), OpenLoc);
+
   ExprResult Result(true);
   bool isAmbiguousTypeId;
   CastTy = nullptr;
 
   if (Tok.is(tok::code_completion)) {
-    Actions.CodeCompleteOrdinaryName(getCurScope(),
-                 ExprType >= CompoundLiteral? Sema::PCC_ParenthesizedExpression
-                                            : Sema::PCC_Expression);
+    Actions.CodeCompleteExpression(
+        getCurScope(), PreferredType.get(Tok.getLocation()),
+        /*IsParenthesized=*/ExprType >= CompoundLiteral);
     cutOffParsing();
     return ExprError();
   }
@@ -2414,6 +2425,9 @@
     T.consumeClose();
     ColonProtection.restore();
     RParenLoc = T.getCloseLocation();
+
+    auto TypeRAII =
+        PreferredType.enterTypeCast(Tok.getLocation(), Ty.get().get());
     ExprResult SubExpr = ParseCastExpression(/*isUnaryExpression=*/false);
 
     if (Ty.isInvalid() || SubExpr.isInvalid())
@@ -2544,6 +2558,8 @@
           return ExprError();
         }
 
+        auto TypeRAII =
+            PreferredType.enterTypeCast(Tok.getLocation(), CastTy.get());
         // Parse the cast-expression that follows it next.
         // TODO: For cast expression with CastTy.
         Result = ParseCastExpression(/*isUnaryExpression=*/false,
@@ -2845,7 +2861,8 @@
       if (Completer)
         Completer();
       else
-        Actions.CodeCompleteOrdinaryName(getCurScope(), Sema::PCC_Expression);
+        Actions.CodeCompleteExpression(getCurScope(),
+                                       PreferredType.get(Tok.getLocation()));
       cutOffParsing();
       return true;
     }
Index: clang/lib/Parse/ParseDecl.cpp
===================================================================
--- clang/lib/Parse/ParseDecl.cpp
+++ clang/lib/Parse/ParseDecl.cpp
@@ -2293,7 +2293,12 @@
         return nullptr;
       }
 
-      ExprResult Init(ParseInitializer());
+      ExprResult Init;
+      {
+        auto TypeRAII =
+            PreferredType.enterVariableInit(Tok.getLocation(), ThisDecl);
+        Init = ParseInitializer();
+      }
 
       // If this is the only decl in (possibly) range based for statement,
       // our best guess is that the user meant ':' instead of '='.
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -273,6 +273,78 @@
   }
 };
 
+/// Tracks the expected type of a token during expression parsing, for code
+/// completion. Everything that updates or reads the type takes the start
+/// location of the token it is looking at, and get() only answers for the
+/// tracked token, so hot parser paths never need to reset stale state.
+class PreferredTypeBuilder {
+public:
+  class RestoreRAII; // Undoes an enter* call when destroyed.
+
+  PreferredTypeBuilder() = default;
+  explicit PreferredTypeBuilder(QualType Type) : Type(Type) {}
+
+  LLVM_NODISCARD RestoreRAII enterCondition(Sema &S, SourceLocation Tok);
+  LLVM_NODISCARD RestoreRAII enterReturn(Sema &S, SourceLocation Tok);
+  LLVM_NODISCARD RestoreRAII enterVariableInit(SourceLocation Tok, Decl *D);
+
+  LLVM_NODISCARD RestoreRAII enterParenExpr(SourceLocation Tok,
+                                            SourceLocation LParLoc);
+  LLVM_NODISCARD RestoreRAII enterUnary(Sema &S, SourceLocation Tok,
+                                        tok::TokenKind OpKind,
+                                        SourceLocation OpLoc);
+  LLVM_NODISCARD RestoreRAII enterBinary(Sema &S, SourceLocation Tok, Expr *LHS,
+                                         tok::TokenKind Op);
+  LLVM_NODISCARD RestoreRAII enterMemAccess(Sema &S, SourceLocation Tok,
+                                            Expr *Base);
+  LLVM_NODISCARD RestoreRAII enterSubscript(Sema &S, SourceLocation Tok,
+                                            Expr *LHS);
+  /// Handles all type casts, including C-style cast, C++ casts, etc.
+  LLVM_NODISCARD RestoreRAII enterTypeCast(SourceLocation Tok,
+                                           QualType CastType);
+  /// Expected type for the token starting at \p Tok, or null if not tracked.
+  QualType get(SourceLocation Tok) const {
+    if (Tok == ExpectedLoc)
+      return Type;
+    return QualType();
+  }
+
+private:
+  /// Start position of a token for which we store expected type.
+  SourceLocation ExpectedLoc;
+  /// Expected type for a token starting at ExpectedLoc.
+  QualType Type;
+};
+
+/// Move-only guard returned by PreferredTypeBuilder::enter*: on destruction it
+/// restores the builder's expected type and location captured at construction.
+class PreferredTypeBuilder::RestoreRAII {
+public:
+  RestoreRAII(RestoreRAII const &) = delete;
+  RestoreRAII &operator=(RestoreRAII const &) = delete;
+
+  explicit RestoreRAII(PreferredTypeBuilder &Builder)
+      : OldType(Builder.Type), OldLoc(Builder.ExpectedLoc), Builder(&Builder) {}
+
+  RestoreRAII(RestoreRAII &&Other)
+      : OldType(Other.OldType), OldLoc(Other.OldLoc), Builder(Other.Builder) {
+    Other.Builder = nullptr;
+  }
+
+  ~RestoreRAII() {
+    // A null Builder marks a moved-from guard; it must not restore anything.
+    if (!Builder)
+      return;
+    Builder->Type = OldType;
+    Builder->ExpectedLoc = OldLoc;
+  }
+
+private:
+  QualType OldType;
+  SourceLocation OldLoc;
+  PreferredTypeBuilder *Builder;
+};
+
 /// Sema - This implements semantic analysis and AST building for C.
 class Sema {
   Sema(const Sema &) = delete;
@@ -10351,11 +10423,14 @@
   struct CodeCompleteExpressionData;
   void CodeCompleteExpression(Scope *S,
                               const CodeCompleteExpressionData &Data);
-  void CodeCompleteExpression(Scope *S, QualType PreferredType);
+  void CodeCompleteExpression(Scope *S, QualType PreferredType,
+                              bool IsParenthesized = false);
   void CodeCompleteMemberReferenceExpr(Scope *S, Expr *Base, Expr *OtherOpBase,
                                        SourceLocation OpLoc, bool IsArrow,
-                                       bool IsBaseExprStatement);
-  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS);
+                                       bool IsBaseExprStatement,
+                                       QualType PreferredType);
+  void CodeCompletePostfixExpression(Scope *S, ExprResult LHS,
+                                     QualType PreferredType);
   void CodeCompleteTag(Scope *S, unsigned TagSpec);
   void CodeCompleteTypeQualifiers(DeclSpec &DS);
   void CodeCompleteFunctionQualifiers(DeclSpec &DS, Declarator &D,
@@ -10377,9 +10452,7 @@
                                               IdentifierInfo *II,
                                               SourceLocation OpenParLoc);
   void CodeCompleteInitializer(Scope *S, Decl *D);
-  void CodeCompleteReturn(Scope *S);
   void CodeCompleteAfterIf(Scope *S);
-  void CodeCompleteBinaryRHS(Scope *S, Expr *LHS, tok::TokenKind Op);
 
   void CodeCompleteQualifiedId(Scope *S, CXXScopeSpec &SS,
                                bool EnteringContext, QualType BaseType);
Index: clang/include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- clang/include/clang/Sema/CodeCompleteConsumer.h
+++ clang/include/clang/Sema/CodeCompleteConsumer.h
@@ -380,6 +380,7 @@
   /// if the expression is a variable initializer or a function argument, the
   /// type of the corresponding variable or function parameter.
   QualType getPreferredType() const { return PreferredType; }
+  void setPreferredType(QualType T) { PreferredType = T; }
 
   /// Retrieve the type of the base object in a member-access
   /// expression.
Index: clang/include/clang/Parse/Parser.h
===================================================================
--- clang/include/clang/Parse/Parser.h
+++ clang/include/clang/Parse/Parser.h
@@ -74,6 +74,10 @@
   // a statement).
   SourceLocation PrevTokLocation;
 
+  /// Tracks an expected type for the current token when parsing an expression.
+  /// Used by code completion for ranking.
+  PreferredTypeBuilder PreferredType;
+
   unsigned short ParenCount = 0, BracketCount = 0, BraceCount = 0;
   unsigned short MisplacedModuleBeginCount = 0;
 
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to