klimek created this revision.
klimek added a reviewer: bkramer.
klimek added a subscriber: cfe-commits.

Instead of just using popularity, we also take into account how similar the
path of the current file is to the path of the header.
Our first approach is to bring popularity down to a reasonably small scale by
taking log2 (which roughly matches how humans intuitively bucket popularity),
and multiply that by the number of leading path segments of the included header
that match a run of path segments in the current file's path (at least 1, so
unrelated headers are still ranked by popularity alone).
Note that we currently do not specially handle unclean paths containing "../"
or "./".


https://reviews.llvm.org/D28548

Files:
  include-fixer/IncludeFixer.cpp
  include-fixer/SymbolIndexManager.cpp
  include-fixer/SymbolIndexManager.h
  include-fixer/tool/ClangIncludeFixer.cpp
  test/include-fixer/Inputs/fake_yaml_db.yaml
  test/include-fixer/ranking.cpp

Index: test/include-fixer/ranking.cpp
===================================================================
--- test/include-fixer/ranking.cpp
+++ test/include-fixer/ranking.cpp
@@ -1,6 +1,9 @@
 // RUN: clang-include-fixer -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s
+// RUN: clang-include-fixer -query-symbol bar -db=yaml -input=%S/Inputs/fake_yaml_db.yaml -output-headers %s -- | FileCheck %s
 
 // CHECK:     "HeaderInfos": [
+// CHECK-NEXT:  {"Header": "\"test/include-fixer/baz.h\"",
+// CHECK-NEXT:   "QualifiedName": "c::bar"},
 // CHECK-NEXT:  {"Header": "\"../include/bar.h\"",
 // CHECK-NEXT:   "QualifiedName": "b::a::bar"},
 // CHECK-NEXT:  {"Header": "\"../include/zbar.h\"",
Index: test/include-fixer/Inputs/fake_yaml_db.yaml
===================================================================
--- test/include-fixer/Inputs/fake_yaml_db.yaml
+++ test/include-fixer/Inputs/fake_yaml_db.yaml
@@ -9,7 +9,6 @@
 LineNumber:      1
 Type:            Class
 NumOccurrences:  1
-...
 ---
 Name:           bar
 Contexts:
@@ -21,7 +20,7 @@
 LineNumber:      1
 Type:            Class
 NumOccurrences:  1
-...
+---
 Name:           bar
 Contexts:
   - ContextType:     Namespace
@@ -32,7 +31,7 @@
 LineNumber:      2
 Type:            Class
 NumOccurrences:  3
-...
+---
 Name:           bar
 Contexts:
   - ContextType:     Namespace
@@ -50,4 +49,12 @@
 LineNumber:      1
 Type:            Variable
 NumOccurrences:  1
-...
+---
+Name:            bar
+Contexts:
+  - ContextType:    Namespace
+    ContextName:    c
+FilePath:        test/include-fixer/baz.h
+LineNumber:      1
+Type:            Class
+NumOccurrences:  1
Index: include-fixer/tool/ClangIncludeFixer.cpp
===================================================================
--- include-fixer/tool/ClangIncludeFixer.cpp
+++ include-fixer/tool/ClangIncludeFixer.cpp
@@ -332,7 +332,8 @@
 
   // Query symbol mode.
   if (!QuerySymbol.empty()) {
-    auto MatchedSymbols = SymbolIndexMgr->search(QuerySymbol);
+    auto MatchedSymbols = SymbolIndexMgr->search(
+        QuerySymbol, /*IsNestedSearch=*/true, SourceFilePath);
     for (auto &Symbol : MatchedSymbols) {
       std::string HeaderPath = Symbol.getFilePath().str();
       Symbol.SetFilePath(((HeaderPath[0] == '"' || HeaderPath[0] == '<')
Index: include-fixer/SymbolIndexManager.h
===================================================================
--- include-fixer/SymbolIndexManager.h
+++ include-fixer/SymbolIndexManager.h
@@ -42,7 +42,8 @@
   ///
   /// \returns A list of symbol candidates.
   std::vector<find_all_symbols::SymbolInfo>
-  search(llvm::StringRef Identifier, bool IsNestedSearch = true) const;
+  search(llvm::StringRef Identifier, bool IsNestedSearch = true,
+         llvm::StringRef FileName = "") const;
 
 private:
   std::vector<std::shared_future<std::unique_ptr<SymbolIndex>>> SymbolIndices;
Index: include-fixer/SymbolIndexManager.cpp
===================================================================
--- include-fixer/SymbolIndexManager.cpp
+++ include-fixer/SymbolIndexManager.cpp
@@ -12,38 +12,66 @@
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringMap.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/Path.h"
+#include <cmath>

 #define DEBUG_TYPE "include-fixer"

 namespace clang {
 namespace include_fixer {

 using clang::find_all_symbols::SymbolInfo;

-/// Sorts SymbolInfos based on the popularity info in SymbolInfo.
-static void rankByPopularity(std::vector<SymbolInfo> &Symbols) {
-  // First collect occurrences per header file.
-  llvm::DenseMap<llvm::StringRef, unsigned> HeaderPopularity;
-  for (const SymbolInfo &Symbol : Symbols) {
-    unsigned &Popularity = HeaderPopularity[Symbol.getFilePath()];
-    Popularity = std::max(Popularity, Symbol.getNumOccurrences());
+// Calculate a score based on whether we think the given header is closely
+// related to the given source file.
+static double similarityScore(llvm::StringRef FileName,
+                              llvm::StringRef Header) {
+  // Compute the maximum number of leading path segments Header shares with any
+  // suffix of FileName. We do not do a full longest common substring
+  // computation, as Header specifies the path we would directly #include, so
+  // we assume it is rooted relatively to a subproject of the repository.
+  int MaxSegments = 1; // Floor: unrelated headers rank by popularity alone.
+  for (auto FileI = llvm::sys::path::begin(FileName),
+            FileE = llvm::sys::path::end(FileName);
+       FileI != FileE; ++FileI) {
+    int Segments = 0;
+    // Test I != FileE first so the end iterator is never dereferenced.
+    for (auto HeaderI = llvm::sys::path::begin(Header),
+              HeaderE = llvm::sys::path::end(Header), I = FileI;
+         I != FileE && HeaderI != HeaderE && *I == *HeaderI; ++I, ++HeaderI)
+      ++Segments;
+    MaxSegments = std::max(Segments, MaxSegments);
   }
+  return MaxSegments;
+}

-  // Sort by the gathered popularities. Use file name as a tie breaker so we can
+static void rank(std::vector<SymbolInfo> &Symbols, llvm::StringRef FileName) {
+  // Own the keys: std::sort moves Symbols, so StringRefs into them go stale.
+  llvm::StringMap<double> Score;
+  for (const SymbolInfo &Symbol : Symbols) {
+    // Score = path similarity * (1 + log2(1 + popularity)), best per header.
+    double NewScore = similarityScore(FileName, Symbol.getFilePath()) *
+                      (1.0 + std::log2(1.0 + Symbol.getNumOccurrences()));
+    double &S = Score[Symbol.getFilePath()];
+    S = std::max(S, NewScore);
+  }
+  // Sort by the gathered scores. Use file name as a tie breaker so we can
   // deduplicate.
   std::sort(Symbols.begin(), Symbols.end(),
             [&](const SymbolInfo &A, const SymbolInfo &B) {
-              auto APop = HeaderPopularity[A.getFilePath()];
-              auto BPop = HeaderPopularity[B.getFilePath()];
-              if (APop != BPop)
-                return APop > BPop;
+              double AS = Score.lookup(A.getFilePath());
+              double BS = Score.lookup(B.getFilePath());
+              if (AS != BS)
+                return AS > BS;
               return A.getFilePath() < B.getFilePath();
             });
 }
 
 std::vector<find_all_symbols::SymbolInfo>
 SymbolIndexManager::search(llvm::StringRef Identifier,
-                           bool IsNestedSearch) const {
+                           bool IsNestedSearch,
+                           llvm::StringRef FileName) const {
   // The identifier may be fully qualified, so split it and get all the context
   // names.
   llvm::SmallVector<llvm::StringRef, 8> Names;
@@ -119,7 +147,7 @@
     TookPrefix = true;
   } while (MatchedSymbols.empty() && !Names.empty() && IsNestedSearch);
 
-  rankByPopularity(MatchedSymbols);
+  rank(MatchedSymbols, FileName);
   return MatchedSymbols;
 }
 
Index: include-fixer/IncludeFixer.cpp
===================================================================
--- include-fixer/IncludeFixer.cpp
+++ include-fixer/IncludeFixer.cpp
@@ -365,6 +365,9 @@
             .getLocWithOffset(Range.getOffset())
             .print(llvm::dbgs(), CI->getSourceManager()));
   DEBUG(llvm::dbgs() << " ...");
+  llvm::StringRef FileName = CI->getSourceManager().getFilename(
+      CI->getSourceManager().getLocForStartOfFile(
+          CI->getSourceManager().getMainFileID()));
 
   QuerySymbolInfos.push_back({Query.str(), ScopedQualifiers, Range});
 
@@ -385,9 +388,10 @@
   // context, it might treat the identifier as a nested class of the scoped
   // namespace.
   std::vector<find_all_symbols::SymbolInfo> MatchedSymbols =
-      SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false);
+      SymbolIndexMgr.search(QueryString, /*IsNestedSearch=*/false, FileName);
   if (MatchedSymbols.empty())
-    MatchedSymbols = SymbolIndexMgr.search(Query);
+    MatchedSymbols =
+        SymbolIndexMgr.search(Query, /*IsNestedSearch=*/true, FileName);
   DEBUG(llvm::dbgs() << "Having found " << MatchedSymbols.size()
                      << " symbols\n");
   // We store a copy of MatchedSymbols in a place where it's globally reachable.
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
http://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to