capfredf updated this revision to Diff 547241.
capfredf marked 2 inline comments as done.
capfredf added a comment.

Changes per review discussions.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D154382/new/

https://reviews.llvm.org/D154382

Files:
  clang/include/clang/Interpreter/CodeCompletion.h
  clang/include/clang/Interpreter/Interpreter.h
  clang/include/clang/Sema/CodeCompleteConsumer.h
  clang/include/clang/Sema/Sema.h
  clang/lib/Frontend/ASTUnit.cpp
  clang/lib/Interpreter/CMakeLists.txt
  clang/lib/Interpreter/CodeCompletion.cpp
  clang/lib/Interpreter/DeviceOffload.cpp
  clang/lib/Interpreter/ExternalSource.cpp
  clang/lib/Interpreter/ExternalSource.h
  clang/lib/Interpreter/IncrementalParser.cpp
  clang/lib/Interpreter/IncrementalParser.h
  clang/lib/Interpreter/Interpreter.cpp
  clang/lib/Parse/ParseDecl.cpp
  clang/lib/Parse/Parser.cpp
  clang/lib/Sema/CodeCompleteConsumer.cpp
  clang/lib/Sema/SemaCodeComplete.cpp
  clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
  clang/test/CodeCompletion/incremental-top-level.cpp
  clang/tools/clang-repl/ClangRepl.cpp
  clang/tools/libclang/CIndexCodeCompletion.cpp
  clang/unittests/Interpreter/CMakeLists.txt
  clang/unittests/Interpreter/CodeCompletionTest.cpp
  clang/unittests/Interpreter/IncrementalProcessingTest.cpp
  clang/unittests/Interpreter/InterpreterTest.cpp

Index: clang/unittests/Interpreter/InterpreterTest.cpp
===================================================================
--- clang/unittests/Interpreter/InterpreterTest.cpp
+++ clang/unittests/Interpreter/InterpreterTest.cpp
@@ -41,6 +41,7 @@
 
 namespace {
 using Args = std::vector<const char *>;
+
 static std::unique_ptr<Interpreter>
 createInterpreter(const Args &ExtraArgs = {},
                   DiagnosticConsumer *Client = nullptr) {
Index: clang/unittests/Interpreter/IncrementalProcessingTest.cpp
===================================================================
--- clang/unittests/Interpreter/IncrementalProcessingTest.cpp
+++ clang/unittests/Interpreter/IncrementalProcessingTest.cpp
@@ -55,6 +55,7 @@
   auto CB = clang::IncrementalCompilerBuilder();
   CB.SetCompilerArgs(ClangArgv);
   auto CI = cantFail(CB.CreateCpp());
+
   auto Interp = llvm::cantFail(Interpreter::create(std::move(CI)));
 
   std::array<clang::PartialTranslationUnit *, 2> PTUs;
Index: clang/unittests/Interpreter/CodeCompletionTest.cpp
===================================================================
--- /dev/null
+++ clang/unittests/Interpreter/CodeCompletionTest.cpp
@@ -0,0 +1,114 @@
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Interpreter/Interpreter.h"
+
+#include "clang/Frontend/CompilerInstance.h"
+#include "llvm/LineEditor/LineEditor.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/raw_ostream.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+#include <algorithm>
+
+using namespace clang;
+namespace {
+auto CB = clang::IncrementalCompilerBuilder();
+
+/// Creates the interpreter whose declarations completion should offer.
+static std::unique_ptr<Interpreter> createInterpreter() {
+  auto CI = cantFail(CB.CreateCpp());
+  return cantFail(clang::Interpreter::create(std::move(CI)));
+}
+
+/// Completes \p Prefix against the declarations of \p MainInterp and returns,
+/// for each candidate starting with \p Prefix, the text that would complete
+/// it. Failures are reported through \p ErrR, which the caller must check.
+static std::vector<std::string> runComp(clang::Interpreter &MainInterp,
+                                        llvm::StringRef Prefix,
+                                        llvm::Error &ErrR) {
+  // Lets ErrR be assigned below without tripping the unchecked-Error check.
+  llvm::ErrorAsOutParameter EAO(&ErrR);
+  std::vector<clang::CodeCompletionResult> Results;
+  auto CI = CB.CreateCpp();
+  if (auto Err = CI.takeError()) {
+    ErrR = std::move(Err);
+    return {};
+  }
+
+  size_t Lines = std::count(Prefix.begin(), Prefix.end(), '\n') + 1;
+  auto CFG = clang::CodeCompletionCfg{
+      Prefix.size(), Lines,
+      const_cast<clang::CompilerInstance *>(MainInterp.getCompilerInstance()),
+      Results};
+
+  auto Interp = clang::Interpreter::create(std::move(*CI), CFG);
+  if (auto Err = Interp.takeError()) {
+    ErrR = std::move(Err);
+    return {};
+  }
+
+  if (auto PTU = (*Interp)->Parse(Prefix); !PTU) {
+    ErrR = PTU.takeError();
+    return {};
+  }
+
+  std::vector<std::string> Comps;
+  for (auto c : ConvertToCodeCompleteStrings(Results)) {
+    if (c.startswith(Prefix))
+      Comps.push_back(c.substr(Prefix.size()).str());
+  }
+
+  return Comps;
+}
+
+TEST(CodeCompletionTest, Sanity) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "f", Err);
+  EXPECT_FALSE((bool)Err) << llvm::toString(std::move(Err));
+  // Fatal on mismatch: comps[0] below must not be read from an empty vector.
+  ASSERT_EQ((size_t)2, comps.size()); // foo and float
+  EXPECT_EQ(comps[0], std::string("oo"));
+}
+
+TEST(CodeCompletionTest, SanityNoneValid) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int foo = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "babanana", Err);
+  EXPECT_FALSE((bool)Err) << llvm::toString(std::move(Err));
+  EXPECT_EQ((size_t)0, comps.size()); // nothing starts with "babanana"
+}
+
+TEST(CodeCompletionTest, TwoDecls) {
+  auto Interp = createInterpreter();
+  if (auto R = Interp->ParseAndExecute("int application = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  if (auto R = Interp->ParseAndExecute("int apple = 12;")) {
+    consumeError(std::move(R));
+    return;
+  }
+  auto Err = llvm::Error::success();
+  auto comps = runComp(*Interp, "app", Err);
+  EXPECT_FALSE((bool)Err) << llvm::toString(std::move(Err));
+  EXPECT_EQ((size_t)2, comps.size()); // application and apple
+}
+
+TEST(CodeCompletionTest, CompFunDeclsNoError) {
+  auto Interp = createInterpreter();
+  auto Err = llvm::Error::success();
+  runComp(*Interp, "void app(", Err);
+  EXPECT_FALSE((bool)Err) << llvm::toString(std::move(Err));
+}
+
+} // anonymous namespace
Index: clang/unittests/Interpreter/CMakeLists.txt
===================================================================
--- clang/unittests/Interpreter/CMakeLists.txt
+++ clang/unittests/Interpreter/CMakeLists.txt
@@ -9,6 +9,7 @@
 add_clang_unittest(ClangReplInterpreterTests
   IncrementalProcessingTest.cpp
   InterpreterTest.cpp
+  CodeCompletionTest.cpp
   )
 target_link_libraries(ClangReplInterpreterTests PUBLIC
   clangAST
Index: clang/tools/libclang/CIndexCodeCompletion.cpp
===================================================================
--- clang/tools/libclang/CIndexCodeCompletion.cpp
+++ clang/tools/libclang/CIndexCodeCompletion.cpp
@@ -543,6 +543,7 @@
     case CodeCompletionContext::CCC_PreprocessorExpression:
     case CodeCompletionContext::CCC_PreprocessorDirective:
     case CodeCompletionContext::CCC_Attribute:
+    case CodeCompletionContext::CCC_ReplTopLevel:
     case CodeCompletionContext::CCC_TypeQualifiers: {
       //Only Clang results should be accepted, so we'll set all of the other
       //context bits to 0 (i.e. the empty set)
Index: clang/tools/clang-repl/ClangRepl.cpp
===================================================================
--- clang/tools/clang-repl/ClangRepl.cpp
+++ clang/tools/clang-repl/ClangRepl.cpp
@@ -13,7 +13,9 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendDiagnostic.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
 
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
 #include "llvm/LineEditor/LineEditor.h"
@@ -70,6 +72,69 @@
   return (Errs || HasError) ? EXIT_FAILURE : EXIT_SUCCESS;
 }
 
+/// LineEditor list completer for clang-repl: completes the identifier being
+/// typed by parsing the input in a throw-away interpreter that has code
+/// completion enabled and sees the declarations of \c MainInterp.
+struct ReplListCompleter {
+  clang::IncrementalCompilerBuilder &CB;
+  clang::Interpreter &MainInterp;
+  ReplListCompleter(clang::IncrementalCompilerBuilder &CB,
+                    clang::Interpreter &Interp)
+      : CB(CB), MainInterp(Interp) {}
+
+  std::vector<llvm::LineEditor::Completion> operator()(llvm::StringRef Buffer,
+                                                       size_t Pos) const {
+    auto Err = llvm::Error::success();
+    auto res = (*this)(Buffer, Pos, Err);
+    if (Err)
+      llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
+    return res;
+  }
+
+  std::vector<llvm::LineEditor::Completion>
+  operator()(llvm::StringRef Buffer, size_t Pos, llvm::Error &ErrRes) const {
+    // Lets ErrRes be assigned below without tripping the unchecked-Error check.
+    llvm::ErrorAsOutParameter EAO(&ErrRes);
+    std::vector<clang::CodeCompletionResult> Results;
+    auto CI = CB.CreateCpp();
+    if (auto Err = CI.takeError()) {
+      ErrRes = std::move(Err);
+      return {};
+    }
+
+    // The completion point is the cursor, so only text before it matters.
+    llvm::StringRef Head = Buffer.take_front(Pos);
+    size_t Lines = std::count(Head.begin(), Head.end(), '\n') + 1;
+    size_t Col = Head.size() - (Head.rfind('\n') + 1) + 1; // 1-based
+    auto CFG = clang::CodeCompletionCfg{
+        Col, Lines,
+        const_cast<clang::CompilerInstance *>(MainInterp.getCompilerInstance()),
+        Results};
+    auto Interp = clang::Interpreter::create(std::move(*CI), CFG);
+    if (auto Err = Interp.takeError()) {
+      ErrRes = std::move(Err);
+      return {};
+    }
+
+    if (auto PTU = (*Interp)->Parse(Buffer); !PTU) {
+      ErrRes = PTU.takeError();
+      return {};
+    }
+
+    // Candidates are whole identifiers: keep those that extend the identifier
+    // fragment ending at the cursor and insert only the missing suffix.
+    size_t WordStart = Head.find_last_not_of(
+        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_");
+    llvm::StringRef Typed = Head.drop_front(WordStart + 1); // npos + 1 == 0
+    std::vector<llvm::LineEditor::Completion> Comps;
+    for (auto c : ConvertToCodeCompleteStrings(Results))
+      if (c.startswith(Typed))
+        Comps.push_back(llvm::LineEditor::Completion(
+            c.substr(Typed.size()).str(), c.str()));
+    return Comps;
+  }
+};
+
 llvm::ExitOnError ExitOnErr;
 int main(int argc, const char **argv) {
   ExitOnErr.setBanner("clang-repl: ");
@@ -133,6 +198,7 @@
     DeviceCI->LoadRequestedPlugins();
 
   std::unique_ptr<clang::Interpreter> Interp;
+
   if (CudaEnabled) {
     Interp = ExitOnErr(
         clang::Interpreter::createWithCUDA(std::move(CI), std::move(DeviceCI)));
@@ -155,8 +221,8 @@
 
   if (OptInputs.empty()) {
     llvm::LineEditor LE("clang-repl");
-    // FIXME: Add LE.setListCompleter
     std::string Input;
+    LE.setListCompleter(ReplListCompleter(CB, *Interp));
     while (std::optional<std::string> Line = LE.readLine()) {
       llvm::StringRef L = *Line;
       L = L.trim();
@@ -168,10 +234,10 @@
       }
 
       Input += L;
-
       if (Input == R"(%quit)") {
         break;
-      } else if (Input == R"(%undo)") {
+      }
+      if (Input == R"(%undo)") {
         if (auto Err = Interp->Undo()) {
           llvm::logAllUnhandledErrors(std::move(Err), llvm::errs(), "error: ");
           HasError = true;
Index: clang/test/CodeCompletion/incremental-top-level.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeCompletion/incremental-top-level.cpp
@@ -0,0 +1,5 @@
+// Top-level completion in incremental mode must offer earlier declarations.
+int foo = 10;
+f
+// RUN: %clang_cc1 -fincremental-extensions -fsyntax-only -code-completion-at=%s:%(line-1):1 %s | FileCheck %s
+// CHECK: COMPLETION: foo : [#int#]foo
Index: clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
===================================================================
--- /dev/null
+++ clang/test/CodeCompletion/incrememal-mode-completion-no-error.cpp
@@ -0,0 +1,4 @@
+// Completing right after `void foo(` in incremental mode prints no output.
+void foo(
+// RUN: %clang_cc1 -fincremental-extensions -fsyntax-only -code-completion-at=%s:%(line-1):9 %s | wc -c | FileCheck %s
+// CHECK: {{^ *0$}}
Index: clang/lib/Sema/SemaCodeComplete.cpp
===================================================================
--- clang/lib/Sema/SemaCodeComplete.cpp
+++ clang/lib/Sema/SemaCodeComplete.cpp
@@ -225,6 +225,7 @@
     case CodeCompletionContext::CCC_ObjCMessageReceiver:
     case CodeCompletionContext::CCC_ParenthesizedExpression:
     case CodeCompletionContext::CCC_Statement:
+    case CodeCompletionContext::CCC_ReplTopLevel:
     case CodeCompletionContext::CCC_Recovery:
       if (ObjCMethodDecl *Method = SemaRef.getCurMethodDecl())
         if (Method->isInstanceMethod())
@@ -1850,6 +1851,7 @@
   case Sema::PCC_ObjCInstanceVariableList:
   case Sema::PCC_Expression:
   case Sema::PCC_Statement:
+  case Sema::PCC_TopLevelStmtDecl:
   case Sema::PCC_ForInit:
   case Sema::PCC_Condition:
   case Sema::PCC_RecoveryInFunction:
@@ -1907,6 +1909,7 @@
   case Sema::PCC_Type:
   case Sema::PCC_ParenthesizedExpression:
   case Sema::PCC_LocalDeclarationSpecifiers:
+  case Sema::PCC_TopLevelStmtDecl:
     return true;
 
   case Sema::PCC_Expression:
@@ -2219,6 +2222,7 @@
     break;
 
   case Sema::PCC_RecoveryInFunction:
+  case Sema::PCC_TopLevelStmtDecl:
   case Sema::PCC_Statement: {
     if (SemaRef.getLangOpts().CPlusPlus11)
       AddUsingAliasResult(Builder, Results);
@@ -4208,6 +4212,8 @@
 
   case Sema::PCC_LocalDeclarationSpecifiers:
     return CodeCompletionContext::CCC_Type;
+  case Sema::PCC_TopLevelStmtDecl:
+    return CodeCompletionContext::CCC_ReplTopLevel;
   }
 
   llvm_unreachable("Invalid ParserCompletionContext!");
@@ -4348,6 +4354,7 @@
     break;
 
   case PCC_Statement:
+  case PCC_TopLevelStmtDecl:
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_ForInit:
@@ -4385,6 +4392,7 @@
   case PCC_ParenthesizedExpression:
   case PCC_Expression:
   case PCC_Statement:
+  case PCC_TopLevelStmtDecl:
   case PCC_RecoveryInFunction:
     if (S->getFnParent())
       AddPrettyFunctionResults(getLangOpts(), Results);
Index: clang/lib/Sema/CodeCompleteConsumer.cpp
===================================================================
--- clang/lib/Sema/CodeCompleteConsumer.cpp
+++ clang/lib/Sema/CodeCompleteConsumer.cpp
@@ -51,6 +51,7 @@
   case CCC_ParenthesizedExpression:
   case CCC_Symbol:
   case CCC_SymbolOrNewName:
+  case CCC_ReplTopLevel:
     return true;
 
   case CCC_TopLevel:
@@ -169,6 +170,8 @@
     return "Recovery";
   case CCKind::CCC_ObjCClassForwardDecl:
     return "ObjCClassForwardDecl";
+  case CCKind::CCC_ReplTopLevel:
+    return "ReplTopLevel";
   }
   llvm_unreachable("Invalid CodeCompletionContext::Kind!");
 }
Index: clang/lib/Parse/Parser.cpp
===================================================================
--- clang/lib/Parse/Parser.cpp
+++ clang/lib/Parse/Parser.cpp
@@ -923,9 +923,18 @@
                                          /*IsInstanceMethod=*/std::nullopt,
                                          /*ReturnType=*/nullptr);
     }
+
+    Sema::ParserCompletionContext PCC;
+    if (CurParsedObjCImpl) {
+      PCC = Sema::PCC_ObjCImplementation;
+    } else if (PP.isIncrementalProcessingEnabled()) {
+      PCC = Sema::PCC_TopLevelStmtDecl;
+    } else {
+      PCC = Sema::PCC_Namespace;
+    };
     Actions.CodeCompleteOrdinaryName(
         getCurScope(),
-        CurParsedObjCImpl ? Sema::PCC_ObjCImplementation : Sema::PCC_Namespace);
+        PCC);
     return nullptr;
   case tok::kw_import: {
     Sema::ModuleImportState IS = Sema::ModuleImportState::NotACXX20Module;
Index: clang/lib/Parse/ParseDecl.cpp
===================================================================
--- clang/lib/Parse/ParseDecl.cpp
+++ clang/lib/Parse/ParseDecl.cpp
@@ -18,6 +18,7 @@
 #include "clang/Basic/Attributes.h"
 #include "clang/Basic/CharInfo.h"
 #include "clang/Basic/TargetInfo.h"
+#include "clang/Basic/TokenKinds.h"
 #include "clang/Parse/ParseDiagnostic.h"
 #include "clang/Parse/Parser.h"
 #include "clang/Parse/RAIIObjectsForParser.h"
@@ -6640,6 +6641,16 @@
 
   while (true) {
     if (Tok.is(tok::l_paren)) {
+      if (PP.isIncrementalProcessingEnabled() && NextToken().is(tok::code_completion)) {
+        // In clang-repl, code completion for input like `void foo(<tab>` should not trigger a parsing error.
+        // So we make the declarator malformed and exits the loop.
+        ConsumeParen();
+        cutOffParsing();
+        D.SetIdentifier(nullptr, Tok.getLocation());
+        D.setInvalidType(true);
+        break;
+      }
+
       bool IsFunctionDeclaration = D.isFunctionDeclaratorAFunctionDeclaration();
       // Enter function-declaration scope, limiting any declarators to the
       // function prototype scope, including parameter declarators.
Index: clang/lib/Interpreter/Interpreter.cpp
===================================================================
--- clang/lib/Interpreter/Interpreter.cpp
+++ clang/lib/Interpreter/Interpreter.cpp
@@ -14,6 +14,7 @@
 #include "clang/Interpreter/Interpreter.h"
 
 #include "DeviceOffload.h"
+#include "ExternalSource.h"
 #include "IncrementalExecutor.h"
 #include "IncrementalParser.h"
 
@@ -33,8 +34,10 @@
 #include "clang/Driver/Tool.h"
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/TextDiagnosticBuffer.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Value.h"
 #include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
 #include "clang/Sema/Lookup.h"
 #include "llvm/ExecutionEngine/JITSymbol.h"
 #include "llvm/ExecutionEngine/Orc/LLJIT.h"
@@ -127,7 +130,6 @@
 
   Clang->getFrontendOpts().DisableFree = false;
   Clang->getCodeGenOpts().DisableFree = false;
-
   return std::move(Clang);
 }
 
@@ -228,13 +230,16 @@
   return IncrementalCompilerBuilder::createCuda(false);
 }
 
-Interpreter::Interpreter(std::unique_ptr<CompilerInstance> CI,
-                         llvm::Error &Err) {
+Interpreter::Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+                         std::vector<CodeCompletionResult> &CCResults,
+                         const CompilerInstance *ParentCI) {
   llvm::ErrorAsOutParameter EAO(&Err);
   auto LLVMCtx = std::make_unique<llvm::LLVMContext>();
   TSCtx = std::make_unique<llvm::orc::ThreadSafeContext>(std::move(LLVMCtx));
   IncrParser = std::make_unique<IncrementalParser>(*this, std::move(CI),
-                                                   *TSCtx->getContext(), Err);
+                                                   *TSCtx->getContext(), Err,
+                                                   ParentCI,
+                                                   CCResults);
 }
 
 Interpreter::~Interpreter() {
@@ -269,16 +274,31 @@
     }
 )";
 
+std::vector<CodeCompletionResult> DummyRes;
+
 llvm::Expected<std::unique_ptr<Interpreter>>
-Interpreter::create(std::unique_ptr<CompilerInstance> CI) {
+Interpreter::create(std::unique_ptr<CompilerInstance> CI, std::optional<CodeCompletionCfg> CCCfg) {
   llvm::Error Err = llvm::Error::success();
-  auto Interp =
-      std::unique_ptr<Interpreter>(new Interpreter(std::move(CI), Err));
+  std::unique_ptr<Interpreter> Interp;
+  if (CCCfg) {
+    auto& opts = CI->getFrontendOpts();
+    opts.CodeCompletionAt.FileName = CodeCompletionFileName;
+    opts.CodeCompletionAt.Line = CCCfg->Line;
+    opts.CodeCompletionAt.Column = CCCfg->Col;
+    Interp = std::unique_ptr<Interpreter>(
+        new Interpreter(std::move(CI), Err, CCCfg->CCResult, CCCfg->ParentCI));
+  } else {
+    Interp = std::unique_ptr<Interpreter>(
+        new Interpreter(std::move(CI), Err, DummyRes));
+  }
   if (Err)
     return std::move(Err);
-  auto PTU = Interp->Parse(Runtimes);
-  if (!PTU)
-    return PTU.takeError();
+
+  if (!CCCfg) {
+    auto PTU = Interp->Parse(Runtimes);
+    if (!PTU)
+      return PTU.takeError();
+  }
 
   Interp->ValuePrintingInfo.resize(3);
   // FIXME: This is a ugly hack. Undo command checks its availability by looking
@@ -288,6 +308,7 @@
   return std::move(Interp);
 }
 
+
 llvm::Expected<std::unique_ptr<Interpreter>>
 Interpreter::createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                             std::unique_ptr<CompilerInstance> DCI) {
Index: clang/lib/Interpreter/IncrementalParser.h
===================================================================
--- clang/lib/Interpreter/IncrementalParser.h
+++ clang/lib/Interpreter/IncrementalParser.h
@@ -24,10 +24,11 @@
 #include <memory>
 namespace llvm {
 class LLVMContext;
-}
+} // namespace llvm
 
 namespace clang {
 class ASTConsumer;
+class CodeCompletionResult;
 class CodeGenerator;
 class CompilerInstance;
 class IncrementalAction;
@@ -62,7 +63,9 @@
 public:
   IncrementalParser(Interpreter &Interp,
                     std::unique_ptr<CompilerInstance> Instance,
-                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err);
+                    llvm::LLVMContext &LLVMCtx, llvm::Error &Err,
+                    const CompilerInstance *ParentCI,
+                    std::vector<CodeCompletionResult>& CCResults);
   virtual ~IncrementalParser();
 
   CompilerInstance *getCI() { return CI.get(); }
@@ -72,7 +75,7 @@
   ///\returns a \c PartialTranslationUnit which holds information about the
   /// \c TranslationUnitDecl and \c llvm::Module corresponding to the input.
   virtual llvm::Expected<PartialTranslationUnit &> Parse(llvm::StringRef Input);
-
+  
   /// Uses the CodeGenModule mangled name cache and avoids recomputing.
   ///\returns the mangled name of a \c GD.
   llvm::StringRef GetMangledName(GlobalDecl GD) const;
@@ -84,7 +87,11 @@
   std::unique_ptr<llvm::Module> GenModule();
 
 private:
+  bool isCodeCompletionEnabled();
   llvm::Expected<PartialTranslationUnit &> ParseOrWrapTopLevelDecl();
+
+  std::pair<FileID, SourceLocation> createSourceFile(llvm::StringRef SourceName,
+                                                     llvm::StringRef Input);
 };
 } // end namespace clang
 
Index: clang/lib/Interpreter/IncrementalParser.cpp
===================================================================
--- clang/lib/Interpreter/IncrementalParser.cpp
+++ clang/lib/Interpreter/IncrementalParser.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "IncrementalParser.h"
+#include "ExternalSource.h"
 #include "clang/AST/DeclContextInternals.h"
 #include "clang/CodeGen/BackendUtil.h"
 #include "clang/CodeGen/CodeGenAction.h"
@@ -18,7 +19,9 @@
 #include "clang/Frontend/CompilerInstance.h"
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/FrontendTool/Utils.h"
+#include "clang/Interpreter/CodeCompletion.h"
 #include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
 #include "clang/Parse/Parser.h"
 #include "clang/Sema/Sema.h"
 #include "llvm/Option/ArgList.h"
@@ -115,10 +118,14 @@
 class IncrementalAction : public WrapperFrontendAction {
 private:
   bool IsTerminating = false;
+  std::vector<CodeCompletionResult>& CCResults;
+  const CompilerInstance *ParentCI;
 
 public:
   IncrementalAction(CompilerInstance &CI, llvm::LLVMContext &LLVMCtx,
-                    llvm::Error &Err)
+                    llvm::Error &Err,
+                    const CompilerInstance *ParentCI,
+                    std::vector<CodeCompletionResult>& CCResults)
       : WrapperFrontendAction([&]() {
           llvm::ErrorAsOutParameter EAO(&Err);
           std::unique_ptr<FrontendAction> Act;
@@ -152,21 +159,25 @@
             break;
           }
           return Act;
-        }()) {}
+        }()),
+        CCResults(CCResults), ParentCI(ParentCI){}
   FrontendAction *getWrapped() const { return WrappedAction.get(); }
   TranslationUnitKind getTranslationUnitKind() override {
     return TU_Incremental;
   }
+
   void ExecuteAction() override {
     CompilerInstance &CI = getCompilerInstance();
     assert(CI.hasPreprocessor() && "No PP!");
 
-    // FIXME: Move the truncation aspect of this into Sema, we delayed this till
-    // here so the source manager would be initialized.
-    if (hasCodeCompletionSupport() &&
-        !CI.getFrontendOpts().CodeCompletionAt.FileName.empty())
-      CI.createCodeCompletionConsumer();
+    if (ParentCI) {
+      // in code completion mode,
+      CI.getPreprocessorOpts().SingleFileParseMode = true;
 
+      CI.getLangOpts().SpellChecking = false;
+      CI.getLangOpts().DelayedTemplateParsing = false;
+      CI.setCodeCompletionConsumer(new ReplCompletionConsumer(CCResults));
+    }
     // Use a code completion consumer?
     CodeCompleteConsumer *CompletionConsumer = nullptr;
     if (CI.hasCodeCompletionConsumer())
@@ -175,6 +186,17 @@
     Preprocessor &PP = CI.getPreprocessor();
     PP.EnterMainSourceFile();
 
+    if (ParentCI) {
+      ExternalSource *myExternalSource = new ExternalSource(
+          CI.getASTContext(), CI.getFileManager(), ParentCI->getASTContext(),
+          ParentCI->getFileManager());
+      llvm::IntrusiveRefCntPtr<ExternalASTSource> astContextExternalSource(
+          myExternalSource);
+      CI.getASTContext().setExternalSource(astContextExternalSource);
+      CI.getASTContext().getTranslationUnitDecl()->setHasExternalVisibleStorage(
+          true);
+    }
+
     if (!CI.hasSema())
       CI.createSema(getTranslationUnitKind(), CompletionConsumer);
   }
@@ -206,10 +228,12 @@
 IncrementalParser::IncrementalParser(Interpreter &Interp,
                                      std::unique_ptr<CompilerInstance> Instance,
                                      llvm::LLVMContext &LLVMCtx,
-                                     llvm::Error &Err)
+                                     llvm::Error &Err,
+                                     const CompilerInstance *ParentCI,
+                                     std::vector<CodeCompletionResult>& CCResults)
     : CI(std::move(Instance)) {
   llvm::ErrorAsOutParameter EAO(&Err);
-  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err);
+  Act = std::make_unique<IncrementalAction>(*CI, LLVMCtx, Err, ParentCI, CCResults);
   if (Err)
     return;
   CI->ExecuteAction(*Act);
@@ -305,22 +329,17 @@
   return LastPTU;
 }
 
-llvm::Expected<PartialTranslationUnit &>
-IncrementalParser::Parse(llvm::StringRef input) {
-  Preprocessor &PP = CI->getPreprocessor();
-  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
-
-  std::ostringstream SourceName;
-  SourceName << "input_line_" << InputCount++;
-
+std::pair<FileID, SourceLocation>
+IncrementalParser::createSourceFile(llvm::StringRef SourceName,
+                                    llvm::StringRef Input) {
   // Create an uninitialized memory buffer, copy code in and append "\n"
-  size_t InputSize = input.size(); // don't include trailing 0
+  size_t InputSize = Input.size(); // don't include trailing 0
   // MemBuffer size should *not* include terminating zero
   std::unique_ptr<llvm::MemoryBuffer> MB(
       llvm::WritableMemoryBuffer::getNewUninitMemBuffer(InputSize + 1,
                                                         SourceName.str()));
   char *MBStart = const_cast<char *>(MB->getBufferStart());
-  memcpy(MBStart, input.data(), InputSize);
+  memcpy(MBStart, Input.data(), InputSize);
   MBStart[InputSize] = '\n';
 
   SourceManager &SM = CI->getSourceManager();
@@ -329,19 +348,56 @@
   // candidates for example
   SourceLocation NewLoc = SM.getLocForStartOfFile(SM.getMainFileID());
 
+
+  const clang::FileEntry *FE = SM.getFileManager().getVirtualFile(
+      SourceName.str(), InputSize, 0 /* mod time*/);
+  SM.overrideFileContents(FE, std::move(MB));
+
   // Create FileID for the current buffer.
-  FileID FID = SM.createFileID(std::move(MB), SrcMgr::C_User, /*LoadedID=*/0,
-                               /*LoadedOffset=*/0, NewLoc);
+  FileID FID = SM.createFileID(FE, NewLoc, SrcMgr::C_User);
+  return {FID, NewLoc};
+}
+
+bool IncrementalParser::isCodeCompletionEnabled() {
+  // Completion mode is on when a CodeCompletionAt file name has been set.
+  return !CI->getFrontendOpts().CodeCompletionAt.FileName.empty();
+}
+llvm::Expected<PartialTranslationUnit &>
+IncrementalParser::Parse(llvm::StringRef input) {
+  Preprocessor &PP = CI->getPreprocessor();
+  assert(PP.isIncrementalProcessingEnabled() && "Not in incremental mode!?");
+  std::ostringstream SourceName;
+
+  if (isCodeCompletionEnabled()) {
+    SourceName << CI->getFrontendOpts().CodeCompletionAt.FileName;
+  } else {
+    SourceName << "input_line_" << InputCount++;
+  }
+
+  auto [FID, SrcLoc] = createSourceFile(SourceName.str(), input);
+
+  if (isCodeCompletionEnabled()) {
+    // createCodeCompletionConsumer enables the code completion point, which
+    // must happen after the source file is created.
+    CI->createCodeCompletionConsumer();
+  }
 
   // NewLoc only used for diags.
-  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, NewLoc))
+  if (PP.EnterSourceFile(FID, /*DirLookup=*/nullptr, SrcLoc))
     return llvm::make_error<llvm::StringError>("Parsing failed. "
                                                "Cannot enter source file.",
                                                std::error_code());
 
   auto PTU = ParseOrWrapTopLevelDecl();
-  if (!PTU)
-    return PTU.takeError();
+
+  if (!PTU) {
+    return std::move(PTU.takeError());
+  }
+
+  if (isCodeCompletionEnabled()) {
+    // there is no need to do extra lexing for code completion
+    return PTU;
+  }
 
   if (PP.getLangOpts().DelayedTemplateParsing) {
     // Microsoft-specific:
Index: clang/lib/Interpreter/ExternalSource.h
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.h
@@ -0,0 +1,49 @@
+//===---- ExternalSource.h - External AST Source for Code Completion ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines components that make declarations parsed and executed by
+// the interpreter visible to the context where code completion is being
+// triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+#define LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
+
+#include "clang/AST/ExternalASTSource.h"
+
+#include <memory>
+
+namespace clang {
+class ASTContext;
+class FileManager;
+class ASTImporter;
+
+/// External AST source that makes the declarations of a parent ASTContext
+/// (in clang-repl, the main interpreter's) visible to a child ASTContext used
+/// for code completion. Lookups consult the parent's translation unit, and an
+/// ASTImporter maps parent declarations into the child context.
+class ExternalSource : public clang::ExternalASTSource {
+  ASTContext &ChildASTCtxt;
+  TranslationUnitDecl *ChildTUDeclCtxt;
+  ASTContext &ParentASTCtxt;
+  TranslationUnitDecl *ParentTUDeclCtxt;
+
+  std::unique_ptr<ASTImporter> Importer;
+
+public:
+  ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                 ASTContext &ParentASTCtxt, FileManager &ParentFM);
+  bool FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                      DeclarationName Name) override;
+  void
+  completeVisibleDeclsMap(const clang::DeclContext *childDeclContext) override;
+};
+} // namespace clang
+
+#endif // LLVM_CLANG_LIB_INTERPRETER_EXTERNALSOURCE_H
Index: clang/lib/Interpreter/ExternalSource.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/ExternalSource.cpp
@@ -0,0 +1,80 @@
+//===--- ExternalSource.cpp - External AST Source for Code Completion ----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// The file implements classes that make declarations parsed and executed by the
+// interpreter visible to the context where code completion is being triggered.
+//
+//===----------------------------------------------------------------------===//
+
+#include "ExternalSource.h"
+#include "clang/AST/ASTImporter.h"
+#include "clang/AST/DeclarationName.h"
+#include "clang/Basic/IdentifierTable.h"
+
+#include <memory>
+
+namespace clang {
+// The child context is the ASTImporter's destination ("To") side and the parent
+// context its source ("From") side. Minimal import asks the importer to import
+// as little as it can, e.g. forward declarations that are completed later.
+ExternalSource::ExternalSource(ASTContext &ChildASTCtxt, FileManager &ChildFM,
+                               ASTContext &ParentASTCtxt, FileManager &ParentFM)
+    : ChildASTCtxt(ChildASTCtxt),
+      ChildTUDeclCtxt(ChildASTCtxt.getTranslationUnitDecl()),
+      ParentASTCtxt(ParentASTCtxt),
+      ParentTUDeclCtxt(ParentASTCtxt.getTranslationUnitDecl()),
+      Importer(std::make_unique<ASTImporter>(ChildASTCtxt, ChildFM,
+                                             ParentASTCtxt, ParentFM,
+                                             /*MinimalImport=*/true)) {}
+
+bool ExternalSource::FindExternalVisibleDeclsByName(const DeclContext *DC,
+                                                    DeclarationName Name) {
+  // Each ASTContext has its own IdentifierTable, so Name's spelling has to be
+  // re-interned in the parent's table before it can be looked up there. DC is
+  // ignored: the lookup always targets the parent translation unit.
+  IdentifierTable &ParentIdTable = ParentASTCtxt.Idents;
+  DeclarationName ParentDeclName(&ParentIdTable.get(Name.getAsString()));
+  return !ParentTUDeclCtxt->lookup(ParentDeclName).empty();
+}
+
+void ExternalSource::completeVisibleDeclsMap(
+    const DeclContext *ChildDeclContext) {
+  assert(ChildDeclContext && ChildDeclContext == ChildTUDeclCtxt &&
+         "No child decl context!");
+
+  if (!ChildDeclContext->hasExternalVisibleStorage())
+    return;
+
+  // Visit the parent translation unit and every earlier TranslationUnitDecl in
+  // its redeclaration chain.
+  for (TranslationUnitDecl *ParentTU = ParentTUDeclCtxt; ParentTU;
+       ParentTU = ParentTU->getPreviousDecl()) {
+    for (Decl *D : ParentTU->decls()) {
+      auto *ND = llvm::dyn_cast<NamedDecl>(D);
+      if (!ND)
+        continue;
+
+      auto ImportedOrErr = Importer->Import(ND);
+      if (!ImportedOrErr) {
+        // A declaration that cannot be imported is just not made visible in
+        // the child; the error is deliberately dropped.
+        llvm::consumeError(ImportedOrErr.takeError());
+        continue;
+      }
+
+      if (auto *Imported = llvm::dyn_cast<NamedDecl>(*ImportedOrErr))
+        SetExternalVisibleDeclsForName(ChildDeclContext,
+                                       Imported->getDeclName(), Imported);
+    }
+    // NOTE(review): this clears the *lexical* external-storage flag, whereas
+    // the early return above tests the *visible* one; confirm this is intended.
+    ChildDeclContext->setHasExternalLexicalStorage(false);
+  }
+}
+
+} // namespace clang
Index: clang/lib/Interpreter/DeviceOffload.cpp
===================================================================
--- clang/lib/Interpreter/DeviceOffload.cpp
+++ clang/lib/Interpreter/DeviceOffload.cpp
@@ -15,19 +15,21 @@
 #include "clang/Basic/TargetOptions.h"
 #include "clang/CodeGen/ModuleBuilder.h"
 #include "clang/Frontend/CompilerInstance.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
 
 #include "llvm/IR/LegacyPassManager.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Target/TargetMachine.h"
 
 namespace clang {
+static std::vector<CodeCompletionResult> DummyResult; // Placeholder sink.
 
 IncrementalCUDADeviceParser::IncrementalCUDADeviceParser(
     Interpreter &Interp, std::unique_ptr<CompilerInstance> Instance,
     IncrementalParser &HostParser, llvm::LLVMContext &LLVMCtx,
     llvm::IntrusiveRefCntPtr<llvm::vfs::InMemoryFileSystem> FS,
     llvm::Error &Err)
-    : IncrementalParser(Interp, std::move(Instance), LLVMCtx, Err),
+  : IncrementalParser(Interp, std::move(Instance), LLVMCtx, Err, nullptr, DummyResult),
       HostParser(HostParser), VFS(FS) {
   if (Err)
     return;
Index: clang/lib/Interpreter/CodeCompletion.cpp
===================================================================
--- /dev/null
+++ clang/lib/Interpreter/CodeCompletion.cpp
@@ -0,0 +1,87 @@
+//===------ CodeCompletion.cpp - Code Completion for ClangRepl ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the classes that perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#include "clang/Interpreter/CodeCompletion.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Interpreter/Interpreter.h"
+#include "clang/Lex/PreprocessorOptions.h"
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "clang/Sema/CodeCompleteOptions.h"
+#include "clang/Sema/Sema.h"
+
+namespace clang {
+
+/// Completion options shared by every ReplCompletionConsumer.
+/// NOTE(review): macros and code patterns are requested from Sema, but
+/// ReplCompletionConsumer::ProcessCodeCompleteResults keeps only declaration
+/// and keyword results, so those extra results are currently discarded.
+static clang::CodeCompleteOptions getClangCompleteOpts() {
+  clang::CodeCompleteOptions Opts;
+  Opts.IncludeCodePatterns = true;
+  Opts.IncludeMacros = true;
+  Opts.IncludeGlobals = true;
+  Opts.IncludeBriefComments = true;
+  return Opts;
+}
+
+ReplCompletionConsumer::ReplCompletionConsumer(
+    std::vector<CodeCompletionResult> &Results)
+    : CodeCompleteConsumer(getClangCompleteOpts()),
+      CCAllocator(std::make_shared<GlobalCodeCompletionAllocator>()),
+      CCTUInfo(CCAllocator), Results(Results) {}
+
+/// Appends to Results the candidates the REPL can display: declarations whose
+/// name is a plain identifier, and keywords. All other result kinds (macros,
+/// patterns, ...) and declarations with non-identifier names are dropped.
+void ReplCompletionConsumer::ProcessCodeCompleteResults(
+    class Sema &S, CodeCompletionContext Context,
+    CodeCompletionResult *InResults, unsigned NumResults) {
+  for (unsigned I = 0; I < NumResults; ++I) {
+    CodeCompletionResult &Result = InResults[I];
+    switch (Result.Kind) {
+    case CodeCompletionResult::RK_Declaration:
+      if (Result.Declaration->getIdentifier())
+        Results.push_back(Result);
+      break;
+    case CodeCompletionResult::RK_Keyword:
+      Results.push_back(Result);
+      break;
+    default:
+      break;
+    }
+  }
+}
+
+/// Maps each declaration result to its identifier and each keyword result to
+/// its spelling; other kinds are skipped. The returned StringRefs do not own
+/// their characters.
+std::vector<llvm::StringRef> ConvertToCodeCompleteStrings(
+    const std::vector<clang::CodeCompletionResult> &Results) {
+  std::vector<llvm::StringRef> CompletionStrings;
+  CompletionStrings.reserve(Results.size());
+  for (const CodeCompletionResult &Res : Results) {
+    switch (Res.Kind) {
+    case clang::CodeCompletionResult::RK_Declaration:
+      if (const auto *ID = Res.Declaration->getIdentifier())
+        CompletionStrings.push_back(ID->getName());
+      break;
+    case clang::CodeCompletionResult::RK_Keyword:
+      CompletionStrings.push_back(Res.Keyword);
+      break;
+    default:
+      break;
+    }
+  }
+  return CompletionStrings;
+}
+
+} // namespace clang
Index: clang/lib/Interpreter/CMakeLists.txt
===================================================================
--- clang/lib/Interpreter/CMakeLists.txt
+++ clang/lib/Interpreter/CMakeLists.txt
@@ -12,7 +12,9 @@
   )
 
 add_clang_library(clangInterpreter
+  CodeCompletion.cpp
   DeviceOffload.cpp
+  ExternalSource.cpp
   IncrementalExecutor.cpp
   IncrementalParser.cpp
   Interpreter.cpp
Index: clang/lib/Frontend/ASTUnit.cpp
===================================================================
--- clang/lib/Frontend/ASTUnit.cpp
+++ clang/lib/Frontend/ASTUnit.cpp
@@ -2005,6 +2005,7 @@
   case CodeCompletionContext::CCC_SymbolOrNewName:
   case CodeCompletionContext::CCC_ParenthesizedExpression:
   case CodeCompletionContext::CCC_ObjCInterfaceName:
+  case CodeCompletionContext::CCC_ReplTopLevel:
     break;
 
   case CodeCompletionContext::CCC_EnumTag:
Index: clang/include/clang/Sema/Sema.h
===================================================================
--- clang/include/clang/Sema/Sema.h
+++ clang/include/clang/Sema/Sema.h
@@ -13319,7 +13319,9 @@
     PCC_ParenthesizedExpression,
     /// Code completion occurs within a sequence of declaration
     /// specifiers within a function, method, or block.
-    PCC_LocalDeclarationSpecifiers
+    PCC_LocalDeclarationSpecifiers,
+    /// Code completion occurs at top-level in a REPL session
+    PCC_TopLevelStmtDecl,
   };
 
   void CodeCompleteModuleImport(SourceLocation ImportLoc, ModuleIdPath Path);
Index: clang/include/clang/Sema/CodeCompleteConsumer.h
===================================================================
--- clang/include/clang/Sema/CodeCompleteConsumer.h
+++ clang/include/clang/Sema/CodeCompleteConsumer.h
@@ -336,7 +336,10 @@
     CCC_Recovery,
 
     /// Code completion in a @class forward declaration.
-    CCC_ObjCClassForwardDecl
+    CCC_ObjCClassForwardDecl,
+
+    /// Code completion at a top level in a REPL session.
+    CCC_ReplTopLevel,
   };
 
   using VisitedContextSet = llvm::SmallPtrSet<DeclContext *, 8>;
Index: clang/include/clang/Interpreter/Interpreter.h
===================================================================
--- clang/include/clang/Interpreter/Interpreter.h
+++ clang/include/clang/Interpreter/Interpreter.h
@@ -35,6 +35,8 @@
 
 namespace clang {
 
+class CodeCompleteConsumer;
+class CodeCompletionResult;
 class CompilerInstance;
 class IncrementalExecutor;
 class IncrementalParser;
@@ -72,6 +74,14 @@
   llvm::StringRef CudaSDKPath;
 };
 
+const std::string CodeCompletionFileName = "input_line_[Completion]";
+struct CodeCompletionCfg { // Passed to Interpreter::create for completion.
+  size_t Col;      ///< Column of the completion point; has no default value.
+  size_t Line = 1; ///< Line of the completion point.
+  CompilerInstance *ParentCI = nullptr; ///< Presumably the main REPL's CI.
+  std::vector<CodeCompletionResult> &CCResult; ///< Presumably the result sink.
+};
+
 /// Provides top-level interfaces for incremental compilation and execution.
 class Interpreter {
   std::unique_ptr<llvm::orc::ThreadSafeContext> TSCtx;
@@ -81,7 +91,9 @@
   // An optional parser for CUDA offloading
   std::unique_ptr<IncrementalParser> DeviceParser;
 
-  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err);
+  Interpreter(std::unique_ptr<CompilerInstance> CI, llvm::Error &Err,
+              std::vector<CodeCompletionResult> &CCResult,
+              const CompilerInstance *ParentCI = nullptr);
 
   llvm::Error CreateExecutor();
   unsigned InitPTUSize = 0;
@@ -93,11 +105,15 @@
 
 public:
   ~Interpreter();
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
-  create(std::unique_ptr<CompilerInstance> CI);
+  create(std::unique_ptr<CompilerInstance> CI,
+         std::optional<CodeCompletionCfg> CCCfg = std::nullopt);
+
   static llvm::Expected<std::unique_ptr<Interpreter>>
   createWithCUDA(std::unique_ptr<CompilerInstance> CI,
                  std::unique_ptr<CompilerInstance> DCI);
+
   const ASTContext &getASTContext() const;
   ASTContext &getASTContext();
   const CompilerInstance *getCompilerInstance() const;
Index: clang/include/clang/Interpreter/CodeCompletion.h
===================================================================
--- /dev/null
+++ clang/include/clang/Interpreter/CodeCompletion.h
@@ -0,0 +1,54 @@
+//===----- CodeCompletion.h - Code Completion for ClangRepl -----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the classes that perform code completion at the REPL.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+#define LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
+
+#include "clang/Sema/CodeCompleteConsumer.h"
+#include "llvm/ADT/StringRef.h"
+
+#include <memory>
+#include <vector>
+
+namespace clang {
+
+/// A CodeCompleteConsumer that collects the completion results the REPL can
+/// display into a caller-owned vector: declaration results whose name is a
+/// plain identifier, and keyword results. Other results are dropped.
+class ReplCompletionConsumer : public CodeCompleteConsumer {
+public:
+  /// \param Results receives the kept results; it is only ever appended to.
+  /// Only a reference is stored, so \p Results must outlive this consumer.
+  explicit ReplCompletionConsumer(std::vector<CodeCompletionResult> &Results);
+
+  void ProcessCodeCompleteResults(class Sema &S, CodeCompletionContext Context,
+                                  CodeCompletionResult *InResults,
+                                  unsigned NumResults) final;
+
+  CodeCompletionAllocator &getAllocator() override { return *CCAllocator; }
+
+  CodeCompletionTUInfo &getCodeCompletionTUInfo() override { return CCTUInfo; }
+
+private:
+  std::shared_ptr<GlobalCodeCompletionAllocator> CCAllocator;
+  CodeCompletionTUInfo CCTUInfo;
+  std::vector<CodeCompletionResult> &Results;
+};
+
+/// Returns the text to offer for each entry of \p Results: the identifier of a
+/// declaration result or the spelling of a keyword result; other kinds are
+/// skipped. The returned StringRefs do not own their characters.
+std::vector<llvm::StringRef> ConvertToCodeCompleteStrings(
+    const std::vector<clang::CodeCompletionResult> &Results);
+} // namespace clang
+
+#endif // LLVM_CLANG_INTERPRETER_CODE_COMPLETION_H
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to