ioeric updated this revision to Diff 182799.
ioeric added a comment.

- revert unintended change


Repository:
  rCTE Clang Tools Extra

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D57021/new/

https://reviews.llvm.org/D57021

Files:
  clangd/CMakeLists.txt
  clangd/ClangdServer.cpp
  clangd/ClangdUnit.cpp
  clangd/ClangdUnit.h
  clangd/CodeComplete.cpp
  clangd/Diagnostics.cpp
  clangd/Diagnostics.h
  clangd/Headers.cpp
  clangd/Headers.h
  clangd/IncludeFixer.cpp
  clangd/IncludeFixer.h
  clangd/SourceCode.cpp
  clangd/SourceCode.h
  unittests/clangd/CMakeLists.txt
  unittests/clangd/FileIndexTests.cpp
  unittests/clangd/IncludeFixerTests.cpp
  unittests/clangd/TUSchedulerTests.cpp
  unittests/clangd/TestTU.cpp
  unittests/clangd/TestTU.h

Index: unittests/clangd/TestTU.h
===================================================================
--- unittests/clangd/TestTU.h
+++ unittests/clangd/TestTU.h
@@ -48,6 +48,9 @@
   // Extra arguments for the compiler invocation.
   std::vector<const char *> ExtraArgs;
 
+  // Index to use when building AST.
+  const SymbolIndex *ExternalIndex = nullptr;
+
   ParsedAST build() const;
   SymbolSlab headerSymbols() const;
   std::unique_ptr<SymbolIndex> index() const;
Index: unittests/clangd/TestTU.cpp
===================================================================
--- unittests/clangd/TestTU.cpp
+++ unittests/clangd/TestTU.cpp
@@ -35,6 +35,7 @@
   Inputs.CompileCommand.Directory = testRoot();
   Inputs.Contents = Code;
   Inputs.FS = buildTestFS({{FullFilename, Code}, {FullHeaderName, HeaderCode}});
+  Inputs.Index = ExternalIndex;
   auto PCHs = std::make_shared<PCHContainerOperations>();
   auto CI = buildCompilerInvocation(Inputs);
   assert(CI && "Failed to build compilation invocation.");
Index: unittests/clangd/TUSchedulerTests.cpp
===================================================================
--- unittests/clangd/TUSchedulerTests.cpp
+++ unittests/clangd/TUSchedulerTests.cpp
@@ -37,8 +37,9 @@
 class TUSchedulerTests : public ::testing::Test {
 protected:
   ParseInputs getInputs(PathRef File, std::string Contents) {
-    return ParseInputs{*CDB.getCompileCommand(File),
-                       buildTestFS(Files, Timestamps), std::move(Contents)};
+    return ParseInputs(*CDB.getCompileCommand(File),
+                       buildTestFS(Files, Timestamps), std::move(Contents),
+                       /*Index=*/nullptr);
   }
 
   void updateWithCallback(TUScheduler &S, PathRef File,
Index: unittests/clangd/IncludeFixerTests.cpp
===================================================================
--- /dev/null
+++ unittests/clangd/IncludeFixerTests.cpp
@@ -0,0 +1,156 @@
+//===-- IncludeFixerTests.cpp - IncludeFixer tests --------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Annotations.h"
+#include "ClangdUnit.h"
+#include "IncludeFixer.h"
+#include "TestIndex.h"
+#include "TestTU.h"
+#include "index/MemIndex.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/ScopedPrinter.h"
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace clang {
+namespace clangd {
+namespace {
+
+using testing::UnorderedElementsAre;
+
+testing::Matcher<const Diag &> WithFix(testing::Matcher<Fix> FixMatcher) {
+  return Field(&Diag::Fixes, UnorderedElementsAre(FixMatcher));
+}
+
+testing::Matcher<const Diag &> WithFix(testing::Matcher<Fix> FixMatcher1,
+                                       testing::Matcher<Fix> FixMatcher2) {
+  return Field(&Diag::Fixes, UnorderedElementsAre(FixMatcher1, FixMatcher2));
+}
+
+MATCHER_P2(Diag, Range, Message,
+           "Diag at " + llvm::to_string(Range) + " = [" + Message + "]") {
+  return arg.Range == Range && arg.Message == Message;
+}
+
+MATCHER_P3(Fix, Range, Replacement, Message,
+           "Fix " + llvm::to_string(Range) + " => " +
+               testing::PrintToString(Replacement) + " = [" + Message + "]") {
+  return arg.Message == Message && arg.Edits.size() == 1 &&
+         arg.Edits[0].range == Range && arg.Edits[0].newText == Replacement;
+}
+
+struct SymbolWithHeader {
+  std::string QName;
+  std::string DeclaringFile;
+  std::string IncludeHeader;
+};
+
+std::unique_ptr<SymbolIndex> buildIndexWithSymbol(llvm::ArrayRef<SymbolWithHeader> Syms) {
+  SymbolSlab::Builder Slab;
+  for (const auto &S : Syms) {
+    Symbol Sym = symbol(S.QName);
+    Sym.Flags |= Symbol::IndexedForCodeCompletion;
+    Sym.CanonicalDeclaration.FileURI = S.DeclaringFile.c_str();
+    Sym.IncludeHeaders.emplace_back(S.IncludeHeader, 1);
+    Slab.insert(Sym);
+  }
+  return MemIndex::build(std::move(Slab).build(), RefSlab());
+}
+
+TEST(IncludeFixerTest, IncompleteType) {
+  Annotations Test(R"cpp(
+$insert[[]]namespace ns {
+  class X;
+}
+class Y : $base[[public ns::X]] {};
+int main() {
+  ns::X *x;
+  x$access[[->]]f();
+}
+  )cpp");
+  auto TU = TestTU::withCode(Test.code());
+  auto Index = buildIndexWithSymbol(
+      {SymbolWithHeader{"ns::X", "unittest:///x.h", "\"x.h\""}});
+  TU.ExternalIndex = Index.get();
+
+  EXPECT_THAT(
+      TU.build().getDiagnostics(),
+      UnorderedElementsAre(
+          AllOf(Diag(Test.range("base"), "base class has incomplete type"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X"))),
+          AllOf(Diag(Test.range("access"),
+                     "member access into incomplete type 'ns::X'"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X")))));
+}
+
+TEST(IncludeFixerTest, Typo) {
+  Annotations Test(R"cpp(
+$insert[[]]namespace ns {
+void foo() {
+  $unqualified[[X]] x;
+}
+}
+void bar() {
+  ns::$qualified[[X]] x; // ns:: is valid.
+  ::$global[[Global]] glob;
+}
+  )cpp");
+  auto TU = TestTU::withCode(Test.code());
+  auto Index = buildIndexWithSymbol(
+      {SymbolWithHeader{"ns::X", "unittest:///x.h", "\"x.h\""},
+       SymbolWithHeader{"Global", "unittest:///global.h", "\"global.h\""}});
+  TU.ExternalIndex = Index.get();
+
+  EXPECT_THAT(
+      TU.build().getDiagnostics(),
+      UnorderedElementsAre(
+          AllOf(Diag(Test.range("unqualified"), "unknown type name 'X'"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X"))),
+          AllOf(Diag(Test.range("qualified"),
+                     "no type named 'X' in namespace 'ns'"),
+                WithFix(Fix(Test.range("insert"), "#include \"x.h\"\n",
+                            "Add include \"x.h\" for symbol ns::X"))),
+          AllOf(Diag(Test.range("global"),
+                     "no type named 'Global' in the global namespace"),
+                WithFix(Fix(Test.range("insert"), "#include \"global.h\"\n",
+                            "Add include \"global.h\" for symbol Global")))));
+}
+
+TEST(IncludeFixerTest, MultipleMatchedSymbols) {
+  Annotations Test(R"cpp(
+$insert[[]]namespace na {
+namespace nb {
+void foo() {
+  $unqualified[[X]] x;
+}
+}
+}
+  )cpp");
+  auto TU = TestTU::withCode(Test.code());
+  auto Index = buildIndexWithSymbol(
+      {SymbolWithHeader{"na::X", "unittest:///a.h", "\"a.h\""},
+       SymbolWithHeader{"na::nb::X", "unittest:///b.h", "\"b.h\""}});
+  TU.ExternalIndex = Index.get();
+
+  EXPECT_THAT(TU.build().getDiagnostics(),
+              UnorderedElementsAre(AllOf(
+                  Diag(Test.range("unqualified"), "unknown type name 'X'"),
+                  WithFix(Fix(Test.range("insert"), "#include \"a.h\"\n",
+                              "Add include \"a.h\" for symbol na::X"),
+                          Fix(Test.range("insert"), "#include \"b.h\"\n",
+                              "Add include \"b.h\" for symbol na::nb::X")))));
+}
+
+} // namespace
+} // namespace clangd
+} // namespace clang
Index: unittests/clangd/FileIndexTests.cpp
===================================================================
--- unittests/clangd/FileIndexTests.cpp
+++ unittests/clangd/FileIndexTests.cpp
@@ -360,10 +360,10 @@
       /*StoreInMemory=*/true,
       [&](ASTContext &Ctx, std::shared_ptr<Preprocessor> PP) {});
   // Build AST for main file with preamble.
-  auto AST =
-      ParsedAST::build(createInvocationFromCommandLine(Cmd), PreambleData,
-                       llvm::MemoryBuffer::getMemBufferCopy(Main.code()),
-                       std::make_shared<PCHContainerOperations>(), PI.FS);
+  auto AST = ParsedAST::build(
+      createInvocationFromCommandLine(Cmd), PreambleData,
+      llvm::MemoryBuffer::getMemBufferCopy(Main.code()),
+      std::make_shared<PCHContainerOperations>(), PI.FS, /*Index=*/nullptr);
   ASSERT_TRUE(AST);
   FileIndex Index;
   Index.updateMain(MainFile, *AST);
Index: unittests/clangd/CMakeLists.txt
===================================================================
--- unittests/clangd/CMakeLists.txt
+++ unittests/clangd/CMakeLists.txt
@@ -28,6 +28,7 @@
   FuzzyMatchTests.cpp
   GlobalCompilationDatabaseTests.cpp
   HeadersTests.cpp
+  IncludeFixerTests.cpp
   IndexActionTests.cpp
   IndexTests.cpp
   JSONTransportTests.cpp
Index: clangd/SourceCode.h
===================================================================
--- clangd/SourceCode.h
+++ clangd/SourceCode.h
@@ -16,7 +16,9 @@
 #include "clang/Basic/Diagnostic.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Basic/SourceManager.h"
+#include "clang/Format/Format.h"
 #include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/Support/SHA1.h"
 
 namespace clang {
@@ -91,6 +93,11 @@
                                              const SourceManager &SourceMgr);
 
 bool IsRangeConsecutive(const Range &Left, const Range &Right);
+
+format::FormatStyle getFormatStyleForFile(llvm::StringRef File,
+                                          llvm::StringRef Content,
+                                          llvm::vfs::FileSystem *FS);
+
 } // namespace clangd
 } // namespace clang
 #endif
Index: clangd/SourceCode.cpp
===================================================================
--- clangd/SourceCode.cpp
+++ clangd/SourceCode.cpp
@@ -248,5 +248,18 @@
   return digest(Content);
 }
 
+format::FormatStyle getFormatStyleForFile(llvm::StringRef File,
+                                          llvm::StringRef Content,
+                                          llvm::vfs::FileSystem *FS) {
+  auto Found = format::getStyle(format::DefaultFormatStyle, File,
+                                format::DefaultFallbackStyle, Content, FS);
+  if (Found)
+    return *Found;
+  // Lookup failure is not fatal: log it and hand back the LLVM style.
+  log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.", File,
+      Found.takeError());
+  return format::getLLVMStyle();
+}
+
 } // namespace clangd
 } // namespace clang
Index: clangd/IncludeFixer.h
===================================================================
--- /dev/null
+++ clangd/IncludeFixer.h
@@ -0,0 +1,102 @@
+//===- IncludeFixer.h - Add missing includes --------------------*- C++ -*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
+#define LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
+
+#include "Diagnostics.h"
+#include "Headers.h"
+#include "index/Index.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Frontend/CompilerInstance.h"
+#include "clang/Sema/ExternalSemaSource.h"
+#include "clang/Sema/Sema.h"
+#include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/IntrusiveRefCntPtr.h"
+#include "llvm/ADT/Optional.h"
+#include "llvm/ADT/StringRef.h"
+#include <memory>
+
+namespace clang {
+namespace clangd {
+
+/// Attempts to recover from error diagnostics by suggesting include insertion
+/// fixes. For example, member access into incomplete type can be fixed by
+/// including headers with the definition.
+class IncludeFixer {
+public:
+  IncludeFixer(CompilerInstance &Compiler, llvm::StringRef File,
+               std::unique_ptr<IncludeInserter> Inserter,
+               const SymbolIndex &Index)
+      : File(File), Inserter(std::move(Inserter)), Index(Index),
+        Compiler(Compiler), RecordTypo(new TypoRecorder(Compiler)) {}
+
+  /// Returns include insertions that can potentially recover the diagnostic.
+  std::vector<Fix> fix(DiagnosticsEngine::Level DiagLevel,
+                       const clang::Diagnostic &Info) const;
+
+  /// Returns an ExternalSemaSource that records typos seen in Sema. It must be
+  /// used in the same Sema run as the IncludeFixer.
+  llvm::IntrusiveRefCntPtr<ExternalSemaSource> typoRecorder() {
+    return RecordTypo;
+  }
+
+private:
+  std::vector<Fix> fixInCompleteType(const Type &T) const;
+
+  std::vector<Fix> fixesForSymbol(const Symbol &Sym) const;
+
+  struct TypoRecord {
+    std::string Typo;   // The typo identifier e.g. "X" in ns::X.
+    SourceLocation Loc; // Location of the typo.
+    Scope *S = nullptr; // Scope in which the typo is found.
+    llvm::Optional<std::string> SS; // The scope qualifier before the typo.
+    Sema::LookupNameKind LookupKind; // LookupKind of the typo.
+  };
+
+  /// Records the last typo seen by Sema.
+  class TypoRecorder : public ExternalSemaSource {
+  public:
+    explicit TypoRecorder(CompilerInstance &Compiler) : Compiler(Compiler) {}
+
+    // Captures the latest typo.
+    TypoCorrection CorrectTypo(const DeclarationNameInfo &Typo, int LookupKind,
+                               Scope *S, CXXScopeSpec *SS,
+                               CorrectionCandidateCallback &CCC,
+                               DeclContext *MemberContext, bool EnteringContext,
+                               const ObjCObjectPointerType *OPT) override;
+
+    llvm::Optional<TypoRecord> lastTypo() const { return LastTypo; }
+
+  private:
+    CompilerInstance &Compiler;
+
+    llvm::Optional<TypoRecord> LastTypo;
+  };
+
+  /// Attempts to fix the typo associated with the current diagnostic. We assume
+  /// a diagnostic is caused by a typo when they have the same source location
+  /// and the typo is the last typo we've seen during the Sema run.
+  std::vector<Fix> fixTypo(const TypoRecord &Typo) const;
+
+  std::string File;
+  std::unique_ptr<IncludeInserter> Inserter;
+  const SymbolIndex &Index;
+  CompilerInstance &Compiler;
+  // This collects the last typo so that we can associate it with the
+  // diagnostic.
+  llvm::IntrusiveRefCntPtr<TypoRecorder> RecordTypo;
+};
+
+} // namespace clangd
+} // namespace clang
+
+#endif // LLVM_CLANG_TOOLS_EXTRA_CLANGD_INCLUDE_FIXER_H
Index: clangd/IncludeFixer.cpp
===================================================================
--- /dev/null
+++ clangd/IncludeFixer.cpp
@@ -0,0 +1,250 @@
+//===--- IncludeFixer.cpp ----------------------------------------*- C++-*-===//
+//
+//                     The LLVM Compiler Infrastructure
+//
+// This file is distributed under the University of Illinois Open Source
+// License. See LICENSE.TXT for details.
+//
+//===----------------------------------------------------------------------===//
+
+#include "IncludeFixer.h"
+#include "AST.h"
+#include "Diagnostics.h"
+#include "Logger.h"
+#include "SourceCode.h"
+#include "index/Index.h"
+#include "clang/AST/Decl.h"
+#include "clang/AST/DeclBase.h"
+#include "clang/AST/NestedNameSpecifier.h"
+#include "clang/AST/Type.h"
+#include "clang/Basic/Diagnostic.h"
+#include "clang/Basic/DiagnosticSema.h"
+#include "clang/Sema/DeclSpec.h"
+#include "clang/Sema/Lookup.h"
+#include "clang/Sema/Scope.h"
+#include "clang/Sema/Sema.h"
+#include "clang/Sema/TypoCorrection.h"
+#include "llvm/ADT/None.h"
+#include "llvm/Support/Error.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <vector>
+
+namespace clang {
+namespace clangd {
+
+namespace {
+
+bool isIncompleteTypeDiag(unsigned int DiagID) {
+  return DiagID == diag::err_incomplete_type ||
+         DiagID == diag::err_incomplete_member_access ||
+         DiagID == diag::err_incomplete_base_class;
+}
+
+// Collects contexts visited during a Sema name lookup.
+class VisitedContextCollector : public VisibleDeclConsumer {
+public:
+  void EnteredContext(DeclContext *Ctx) override { Visited.push_back(Ctx); }
+
+  void FoundDecl(NamedDecl *ND, NamedDecl *Hiding, DeclContext *Ctx,
+                 bool InBaseClass) override {}
+
+  std::vector<DeclContext *> takeVisitedContexts() {
+    return std::move(Visited);
+  }
+
+private:
+  std::vector<DeclContext *> Visited;
+};
+
+} // namespace
+
+std::vector<Fix> IncludeFixer::fix(DiagnosticsEngine::Level DiagLevel,
+                                   const clang::Diagnostic &Info) const {
+  if (isIncompleteTypeDiag(Info.getID())) {
+    // Incomplete type diagnostics carry the type as a QualType argument;
+    // suggest includes for the first argument that is still incomplete.
+    for (unsigned i = 0; i < Info.getNumArgs(); ++i) {
+      if (Info.getArgKind(i) == DiagnosticsEngine::ak_qualtype) {
+        auto QT = QualType::getFromOpaquePtr((void *)Info.getRawArg(i));
+        if (const Type *T = QT.getTypePtrOrNull())
+          if (T->isIncompleteType())
+            return fixInCompleteType(*T);
+      }
+    }
+  } else if (auto LastTypo = RecordTypo->lastTypo()) {
+    // Try to fix typos caused by missing declaration.
+    // E.g.
+    //   clang::SourceManager SM;
+    //          ~~~~~~~~~~~~~
+    //          Typo
+    //   or
+    //   namespace clang {  SourceManager SM; }
+    //                      ~~~~~~~~~~~~~
+    //                      Typo
+    // We only attempt to recover a diagnostic if it has the same location as
+    // the last seen typo.
+    if (DiagLevel >= DiagnosticsEngine::Error &&
+        LastTypo->Loc == Info.getLocation())
+      return fixTypo(*LastTypo);
+  }
+  return {};
+}
+
+std::vector<Fix> IncludeFixer::fixInCompleteType(const Type &T) const {
+  // Only handle incomplete TagDecl type.
+  const TagDecl *TD = T.getAsTagDecl();
+  if (!TD)
+    return {};
+  std::string IncompleteType = printQualifiedName(*TD);
+
+  if (IncompleteType.empty()) {
+    vlog("No incomplete type name is found in diagnostic. Ignore.");
+    return {};
+  }
+  vlog("Trying to fix include for incomplete type {0}", IncompleteType);
+  FuzzyFindRequest Req;
+  Req.AnyScope = false;
+  auto ScopeAndName = splitQualifiedName(IncompleteType);
+  Req.Scopes.push_back(ScopeAndName.first);
+  Req.Query = ScopeAndName.second;
+  // Only code completion symbols insert includes.
+  Req.RestrictForCodeCompletion = true;
+  llvm::Optional<Symbol> Matched;
+  Index.fuzzyFind(Req, [&](const Symbol &Sym) {
+    // FIXME: support multiple matched symbols. Skip ones without headers.
+    if (Matched || Sym.Name != Req.Query || Sym.IncludeHeaders.empty())
+      return;
+    Matched = Sym;
+  });
+
+  if (!Matched)
+    return {};
+  return fixesForSymbol(*Matched);
+}
+
+std::vector<Fix> IncludeFixer::fixesForSymbol(const Symbol &Sym) const {
+  auto Inserted = [&](llvm::StringRef Header)
+      -> llvm::Expected<std::pair<std::string, bool>> {
+    auto ResolvedDeclaring =
+        toHeaderFile(Sym.CanonicalDeclaration.FileURI, File);
+    if (!ResolvedDeclaring)
+      return ResolvedDeclaring.takeError();
+    auto ResolvedInserted = toHeaderFile(Header, File);
+    if (!ResolvedInserted)
+      return ResolvedInserted.takeError();
+    return std::make_pair(
+        Inserter->calculateIncludePath(*ResolvedDeclaring, *ResolvedInserted),
+        Inserter->shouldInsertInclude(*ResolvedDeclaring, *ResolvedInserted));
+  };
+  // One fix per ranked header that should be inserted and yields an edit.
+  std::vector<Fix> Fixes;
+  for (const auto &Inc : getRankedIncludes(Sym)) {
+    if (auto ToInclude = Inserted(Inc)) {
+      if (ToInclude->second)
+        if (auto Edit = Inserter->insert(ToInclude->first))
+          Fixes.push_back(
+              Fix{llvm::formatv("Add include {0} for symbol {1}{2}",
+                                ToInclude->first, Sym.Scope, Sym.Name),
+                  {std::move(*Edit)}});
+    } else {
+      vlog("Failed to calculate include insertion for {0} into {1}: {2}", Inc,
+           File, llvm::toString(ToInclude.takeError()));
+    }
+  }
+  return Fixes;
+}
+
+TypoCorrection IncludeFixer::TypoRecorder::CorrectTypo(
+    const DeclarationNameInfo &Typo, int LookupKind, Scope *S, CXXScopeSpec *SS,
+    CorrectionCandidateCallback &CCC, DeclContext *MemberContext,
+    bool EnteringContext, const ObjCObjectPointerType *OPT) {
+  if (Compiler.getSema().isSFINAEContext())
+    return TypoCorrection();
+  if (!Compiler.getSourceManager().isWrittenInMainFile(Typo.getLoc()))
+    return TypoCorrection();
+
+  TypoRecord Record;
+  Record.Typo = Typo.getAsString();
+  Record.Loc = Typo.getBeginLoc();
+  assert(S);
+  Record.S = S;
+  Record.LookupKind = static_cast<Sema::LookupNameKind>(LookupKind);
+
+  // FIXME: support invalid scope before a type name. In the following example,
+  // namespace "clang::tidy::" hasn't been declared/imported.
+  //    namespace clang {
+  //    void f() {
+  //      tidy::Check c;
+  //      ~~~~
+  //      // or
+  //      clang::tidy::Check c;
+  //             ~~~~
+  //    }
+  //    }
+  // For both cases, the typo and the diagnostic are both on "tidy", and no
+  // diagnostic is generated for "Check". However, what we want to fix is
+  // "clang::tidy::Check".
+  if (SS && SS->isNotEmpty()) { // "::" or "ns::"
+    if (auto *Nested = SS->getScopeRep()) {
+      if (Nested->getKind() == NestedNameSpecifier::Global)
+        Record.SS = "";
+      else if (const auto *NS = Nested->getAsNamespace())
+        Record.SS = printNamespaceScope(*NS);
+      else if (const auto *Alias = Nested->getAsNamespaceAlias())
+        Record.SS = printNamespaceScope(*Alias->getNamespace());
+      else
+        // We don't fix symbols in scopes that are not top-level e.g. class
+        // members, as we don't collect includes for them.
+        return TypoCorrection();
+    }
+  }
+  LastTypo = std::move(Record);
+  return TypoCorrection();
+}
+
+std::vector<Fix> IncludeFixer::fixTypo(const TypoRecord &Typo) const {
+  std::vector<std::string> Scopes;
+  if (Typo.SS) {
+    Scopes.push_back(*Typo.SS);
+  } else {
+    // No scope qualifier is specified. Collect all accessible scopes in the
+    // context.
+    VisitedContextCollector Collector;
+    Compiler.getSema().LookupVisibleDecls(Typo.S, Typo.LookupKind, Collector,
+                                          /*IncludeGlobalScope=*/false,
+                                          /*LoadExternal=*/false);
+
+    Scopes.push_back("");
+    for (const auto *Ctx : Collector.takeVisitedContexts())
+      if (isa<NamespaceDecl>(Ctx))
+        Scopes.push_back(printNamespaceScope(*Ctx));
+  }
+  vlog("Trying to fix typo \"{0}\" in scopes: [{1}]", Typo.Typo,
+       llvm::join(Scopes.begin(), Scopes.end(), ", "));
+
+  FuzzyFindRequest Req;
+  Req.AnyScope = false;
+  Req.Query = Typo.Typo;
+  Req.Scopes = std::move(Scopes); // Last use of Scopes; no copy needed.
+  Req.RestrictForCodeCompletion = true;
+
+  SymbolSlab::Builder Matches;
+  Index.fuzzyFind(Req, [&](const Symbol &Sym) {
+    if (Sym.Name != Req.Query)
+      return;
+    if (!Sym.IncludeHeaders.empty())
+      Matches.insert(Sym);
+  });
+  auto Syms = std::move(Matches).build();
+  if (Syms.empty())
+    return {};
+  std::vector<Fix> Results;
+  for (const auto &Sym : Syms) {
+    auto Fixes = fixesForSymbol(Sym);
+    Results.insert(Results.end(), Fixes.begin(), Fixes.end());
+  }
+  return Results;
+}
+
+} // namespace clangd
+} // namespace clang
Index: clangd/Headers.h
===================================================================
--- clangd/Headers.h
+++ clangd/Headers.h
@@ -12,10 +12,12 @@
 #include "Path.h"
 #include "Protocol.h"
 #include "SourceCode.h"
+#include "index/Index.h"
 #include "clang/Format/Format.h"
 #include "clang/Lex/HeaderSearch.h"
 #include "clang/Lex/PPCallbacks.h"
 #include "clang/Tooling/Inclusions/HeaderIncludes.h"
+#include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/Support/Error.h"
@@ -37,6 +39,15 @@
   bool valid() const;
 };
 
+/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal
+/// include.
+llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
+                                        llvm::StringRef HintPath);
+
+// Returns include headers for \p Sym sorted by popularity. If two headers are
+// equally popular, prefer the shorter one.
+llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym);
+
 // An #include directive that we found in the main file.
 struct Inclusion {
   Range R;             // Inclusion range.
Index: clangd/Headers.cpp
===================================================================
--- clangd/Headers.cpp
+++ clangd/Headers.cpp
@@ -73,6 +73,41 @@
          (!Verbatim && llvm::sys::path::is_absolute(File));
 }
 
+llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
+                                        llvm::StringRef HintPath) {
+  if (isLiteralInclude(Header))
+    return HeaderFile{Header.str(), /*Verbatim=*/true};
+  auto U = URI::parse(Header);
+  if (!U)
+    return U.takeError();
+
+  auto IncludePath = URI::includeSpelling(*U);
+  if (!IncludePath)
+    return IncludePath.takeError();
+  if (!IncludePath->empty())
+    return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true};
+
+  auto Resolved = URI::resolve(*U, HintPath);
+  if (!Resolved)
+    return Resolved.takeError();
+  return HeaderFile{std::move(*Resolved), /*Verbatim=*/false};
+}
+
+llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
+  auto Includes = Sym.IncludeHeaders;
+  // Sort by reference count (descending), then by header length (ascending).
+  llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS,
+                          const Symbol::IncludeHeaderWithReferences &RHS) {
+    if (LHS.References == RHS.References)
+      return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
+    return LHS.References > RHS.References;
+  });
+  llvm::SmallVector<llvm::StringRef, 1> Headers;
+  for (const auto &Include : Includes)
+    Headers.push_back(Include.IncludeHeader);
+  return Headers;
+}
+
 std::unique_ptr<PPCallbacks>
 collectIncludeStructureCallback(const SourceManager &SM,
                                 IncludeStructure *Out) {
Index: clangd/Diagnostics.h
===================================================================
--- clangd/Diagnostics.h
+++ clangd/Diagnostics.h
@@ -86,6 +86,8 @@
 /// Convert from clang diagnostic level to LSP severity.
 int getSeverity(DiagnosticsEngine::Level L);
 
+class IncludeFixer;
+
 /// StoreDiags collects the diagnostics that can later be reported by
 /// clangd. It groups all notes for a diagnostic into a single Diag
 /// and filters out diagnostics that don't mention the main file (i.e. neither
@@ -99,9 +101,14 @@
   void HandleDiagnostic(DiagnosticsEngine::Level DiagLevel,
                         const clang::Diagnostic &Info) override;
 
+  /// If set, possibly adds fixes for diagnostics using \p Fixer.
+  void setIncludeFixer(const IncludeFixer &Fixer) { FixIncludes = &Fixer; }
+
 private:
   void flushLastDiag();
 
+  const IncludeFixer *FixIncludes = nullptr;
+
   std::vector<Diag> Output;
   llvm::Optional<LangOptions> LangOpts;
   llvm::Optional<Diag> LastDiag;
Index: clangd/Diagnostics.cpp
===================================================================
--- clangd/Diagnostics.cpp
+++ clangd/Diagnostics.cpp
@@ -8,6 +8,7 @@
 
 #include "Diagnostics.h"
 #include "Compiler.h"
+#include "IncludeFixer.h"
 #include "Logger.h"
 #include "SourceCode.h"
 #include "clang/Basic/SourceManager.h"
@@ -374,6 +375,11 @@
 
     if (!Info.getFixItHints().empty())
       AddFix(true /* try to invent a message instead of repeating the diag */);
+    if (FixIncludes) {
+      auto ExtraFixes = FixIncludes->fix(DiagLevel, Info);
+      LastDiag->Fixes.insert(LastDiag->Fixes.end(), ExtraFixes.begin(),
+                             ExtraFixes.end());
+    }
   } else {
     // Handle a note to an existing diagnostic.
     if (!LastDiag) {
Index: clangd/CodeComplete.cpp
===================================================================
--- clangd/CodeComplete.cpp
+++ clangd/CodeComplete.cpp
@@ -177,28 +177,6 @@
   return Result;
 }
 
-/// Creates a `HeaderFile` from \p Header which can be either a URI or a literal
-/// include.
-static llvm::Expected<HeaderFile> toHeaderFile(llvm::StringRef Header,
-                                               llvm::StringRef HintPath) {
-  if (isLiteralInclude(Header))
-    return HeaderFile{Header.str(), /*Verbatim=*/true};
-  auto U = URI::parse(Header);
-  if (!U)
-    return U.takeError();
-
-  auto IncludePath = URI::includeSpelling(*U);
-  if (!IncludePath)
-    return IncludePath.takeError();
-  if (!IncludePath->empty())
-    return HeaderFile{std::move(*IncludePath), /*Verbatim=*/true};
-
-  auto Resolved = URI::resolve(*U, HintPath);
-  if (!Resolved)
-    return Resolved.takeError();
-  return HeaderFile{std::move(*Resolved), /*Verbatim=*/false};
-}
-
 /// A code completion result, in clang-native form.
 /// It may be promoted to a CompletionItem if it's among the top-ranked results.
 struct CompletionCandidate {
@@ -1155,24 +1133,6 @@
   return CachedReq;
 }
 
-// Returns the most popular include header for \p Sym. If two headers are
-// equally popular, prefer the shorter one. Returns empty string if \p Sym has
-// no include header.
-llvm::SmallVector<llvm::StringRef, 1> getRankedIncludes(const Symbol &Sym) {
-  auto Includes = Sym.IncludeHeaders;
-  // Sort in descending order by reference count and header length.
-  llvm::sort(Includes, [](const Symbol::IncludeHeaderWithReferences &LHS,
-                          const Symbol::IncludeHeaderWithReferences &RHS) {
-    if (LHS.References == RHS.References)
-      return LHS.IncludeHeader.size() < RHS.IncludeHeader.size();
-    return LHS.References > RHS.References;
-  });
-  llvm::SmallVector<llvm::StringRef, 1> Headers;
-  for (const auto &Include : Includes)
-    Headers.push_back(Include.IncludeHeader);
-  return Headers;
-}
-
 // Runs Sema-based (AST) and Index-based completion, returns merged results.
 //
 // There are a few tricky considerations:
@@ -1253,19 +1213,12 @@
     CodeCompleteResult Output;
     auto RecorderOwner = llvm::make_unique<CompletionRecorder>(Opts, [&]() {
       assert(Recorder && "Recorder is not set");
-      auto Style =
-          format::getStyle(format::DefaultFormatStyle, SemaCCInput.FileName,
-                           format::DefaultFallbackStyle, SemaCCInput.Contents,
-                           SemaCCInput.VFS.get());
-      if (!Style) {
-        log("getStyle() failed for file {0}: {1}. Fallback is LLVM style.",
-            SemaCCInput.FileName, Style.takeError());
-        Style = format::getLLVMStyle();
-      }
+      auto Style = getFormatStyleForFile(
+          SemaCCInput.FileName, SemaCCInput.Contents, SemaCCInput.VFS.get());
       // If preprocessor was run, inclusions from preprocessor callback should
       // already be added to Includes.
       Inserter.emplace(
-          SemaCCInput.FileName, SemaCCInput.Contents, *Style,
+          SemaCCInput.FileName, SemaCCInput.Contents, Style,
           SemaCCInput.Command.Directory,
           Recorder->CCSema->getPreprocessor().getHeaderSearchInfo());
       for (const auto &Inc : Includes.MainFileIncludes)
Index: clangd/ClangdUnit.h
===================================================================
--- clangd/ClangdUnit.h
+++ clangd/ClangdUnit.h
@@ -15,6 +15,7 @@
 #include "Headers.h"
 #include "Path.h"
 #include "Protocol.h"
+#include "index/Index.h"
 #include "clang/Frontend/FrontendAction.h"
 #include "clang/Frontend/PrecompiledPreamble.h"
 #include "clang/Lex/Preprocessor.h"
@@ -61,9 +62,19 @@
 
 /// Information required to run clang, e.g. to parse AST or do code completion.
 struct ParseInputs {
+  ParseInputs(tooling::CompileCommand CompileCommand,
+              IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS,
+              std::string Contents, const SymbolIndex *Index)
+      : CompileCommand(CompileCommand), FS(std::move(FS)),
+        Contents(std::move(Contents)), Index(Index) {}
+
+  ParseInputs() = default;
+
   tooling::CompileCommand CompileCommand;
   IntrusiveRefCntPtr<llvm::vfs::FileSystem> FS;
   std::string Contents;
+  // Used to recover from diagnostics (e.g. find missing includes for symbol).
+  const SymbolIndex *Index = nullptr;
 };
 
 /// Stores and provides access to parsed AST.
@@ -76,7 +87,8 @@
         std::shared_ptr<const PreambleData> Preamble,
         std::unique_ptr<llvm::MemoryBuffer> Buffer,
         std::shared_ptr<PCHContainerOperations> PCHs,
-        IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS);
+        IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS,
+        const SymbolIndex *Index);
 
   ParsedAST(ParsedAST &&Other);
   ParsedAST &operator=(ParsedAST &&Other);
Index: clangd/ClangdUnit.cpp
===================================================================
--- clangd/ClangdUnit.cpp
+++ clangd/ClangdUnit.cpp
@@ -11,9 +11,12 @@
 #include "../clang-tidy/ClangTidyModuleRegistry.h"
 #include "Compiler.h"
 #include "Diagnostics.h"
+#include "Headers.h"
+#include "IncludeFixer.h"
 #include "Logger.h"
 #include "SourceCode.h"
 #include "Trace.h"
+#include "index/Index.h"
 #include "clang/AST/ASTContext.h"
 #include "clang/Basic/LangOptions.h"
 #include "clang/Frontend/CompilerInstance.h"
@@ -30,6 +33,7 @@
 #include "clang/Serialization/ASTWriter.h"
 #include "clang/Tooling/CompilationDatabase.h"
 #include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/raw_ostream.h"
@@ -227,7 +231,8 @@
                  std::shared_ptr<const PreambleData> Preamble,
                  std::unique_ptr<llvm::MemoryBuffer> Buffer,
                  std::shared_ptr<PCHContainerOperations> PCHs,
-                 llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS) {
+                 llvm::IntrusiveRefCntPtr<llvm::vfs::FileSystem> VFS,
+                 const SymbolIndex *Index) {
   assert(CI);
   // Command-line parsing sets DisableFree to true by default, but we don't want
   // to leak memory in clangd.
@@ -236,9 +241,11 @@
       Preamble ? &Preamble->Preamble : nullptr;
 
   StoreDiags ASTDiags;
+  std::string Content = Buffer->getBuffer();
+
   auto Clang =
       prepareCompilerInstance(std::move(CI), PreamblePCH, std::move(Buffer),
-                              std::move(PCHs), std::move(VFS), ASTDiags);
+                              std::move(PCHs), VFS, ASTDiags);
   if (!Clang)
     return None;
 
@@ -285,6 +292,24 @@
     }
   }
 
+  llvm::Optional<IncludeFixer> FixIncludes;
+  auto BuildDir = VFS->getCurrentWorkingDirectory();
+  // Add IncludeFixer if Index is provided.
+  if (Index && !BuildDir.getError()) {
+    auto Style = getFormatStyleForFile(MainInput.getFile(), Content, VFS.get());
+    auto Inserter = llvm::make_unique<IncludeInserter>(
+        MainInput.getFile(), Content, Style, BuildDir.get(),
+        Clang->getPreprocessor().getHeaderSearchInfo());
+    if (Preamble) {
+      for (const auto &Inc : Preamble->Includes.MainFileIncludes)
+        Inserter->addExisting(Inc);
+    }
+    FixIncludes.emplace(*Clang, MainInput.getFile(), std::move(Inserter),
+                        *Index);
+    ASTDiags.setIncludeFixer(*FixIncludes);
+    Clang->setExternalSemaSource(FixIncludes->typoRecorder());
+  }
+
   // Copy over the includes from the preamble, then combine with the
   // non-preamble includes below.
   auto Includes = Preamble ? Preamble->Includes : IncludeStructure{};
@@ -538,7 +563,7 @@
   return ParsedAST::build(llvm::make_unique<CompilerInvocation>(*Invocation),
                           Preamble,
                           llvm::MemoryBuffer::getMemBufferCopy(Inputs.Contents),
-                          PCHs, std::move(VFS));
+                          PCHs, std::move(VFS), Inputs.Index);
 }
 
 SourceLocation getBeginningOfIdentifier(ParsedAST &Unit, const Position &Pos,
Index: clangd/ClangdServer.cpp
===================================================================
--- clangd/ClangdServer.cpp
+++ clangd/ClangdServer.cpp
@@ -152,8 +152,9 @@
   // "PreparingBuild" status to inform users, it is non-trivial given the
   // current implementation.
   WorkScheduler.update(File,
-                       ParseInputs{getCompileCommand(File),
-                                   FSProvider.getFileSystem(), Contents.str()},
+                       ParseInputs(getCompileCommand(File),
+                                   FSProvider.getFileSystem(), Contents.str(),
+                                   Index),
                        WantDiags);
 }
 
Index: clangd/CMakeLists.txt
===================================================================
--- clangd/CMakeLists.txt
+++ clangd/CMakeLists.txt
@@ -40,6 +40,7 @@
   FuzzyMatch.cpp
   GlobalCompilationDatabase.cpp
   Headers.cpp
+  IncludeFixer.cpp
   JSONTransport.cpp
   Logger.cpp
   Protocol.cpp
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to