================
@@ -0,0 +1,296 @@
+//===--- OverridePureVirtuals.cpp --------------------------------*- 
C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "refactor/Tweak.h"
+
+#include "clang/AST/ASTContext.h"
+#include "clang/AST/DeclCXX.h"
+#include "clang/AST/Type.h"
+#include "clang/AST/TypeLoc.h"
+#include "clang/Basic/LLVM.h"
+#include "clang/Basic/SourceLocation.h"
+#include "clang/Tooling/Core/Replacement.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <string>
+
+namespace clang {
+namespace clangd {
+namespace {
+
+/// clangd tweak titled "Override pure virtual methods". `prepare` inspects the
+/// selected class; the members below hold the state gathered for it.
+class OverridePureVirtuals : public Tweak {
+public:
+  // The definition of id() is supplied by REGISTER_TWEAK.
+  const char *id() const final;
+  bool prepare(const Selection &Sel) override;
+  Expected<Effect> apply(const Selection &Sel) override;
+  std::string title() const override {
+    return "Override pure virtual methods";
+  }
+  llvm::StringLiteral kind() const override {
+    return CodeAction::QUICKFIX_KIND;
+  }
+
+private:
+  // Gathers the information the tweak needs before it is applied.
+  void collectMissingPureVirtuals(const Selection &Sel);
+
+  // Definition of the class that the tweak operates on.
+  const CXXRecordDecl *CurrentDeclDef = nullptr;
+  // Pure virtual methods that still need an override, keyed by the access
+  // specifier they were declared with.
+  llvm::MapVector<AccessSpecifier, llvm::SmallVector<const CXXMethodDecl *>>
+      MissingMethodsByAccess;
+  // Source location recorded for each access specifier that already exists in
+  // CurrentDeclDef.
+  llvm::MapVector<AccessSpecifier, SourceLocation> AccessSpecifierLocations;
+};
+
+REGISTER_TWEAK(OverridePureVirtuals)
+
+// Walks every (direct and indirect) base of CurrentDeclDef and returns each
+// pure virtual method declared there exactly once, de-duplicated by canonical
+// declaration. Returns an empty vector when CurrentDeclDef is null or has no
+// definition.
+// NOTE(review): a pure virtual that an intermediate base already overrides is
+// still collected, because only isPureVirtual() on the declaring base is
+// checked here - confirm this is acceptable for callers.
+llvm::SmallVector<const clang::CXXMethodDecl *>
+getAllUniquePureVirtualsFromBaseHierarchy(
+    const clang::CXXRecordDecl *CurrentDeclDef) {
+  llvm::SmallVector<const clang::CXXMethodDecl *> PureVirtuals;
+
+  const clang::CXXRecordDecl *Definition =
+      CurrentDeclDef ? CurrentDeclDef->getDefinition() : nullptr;
+  if (!Definition)
+    return PureVirtuals;
+
+  llvm::DenseSet<const clang::CXXMethodDecl *> SeenCanonicals;
+  const auto CollectFromBase = [&](const clang::CXXRecordDecl *Base) {
+    for (const clang::CXXMethodDecl *Candidate : Base->methods()) {
+      if (!Candidate->isPureVirtual())
+        continue;
+      const bool FirstSighting =
+          SeenCanonicals.insert(Candidate->getCanonicalDecl()).second;
+      if (FirstSighting)
+        PureVirtuals.emplace_back(Candidate);
+    }
+    // Returning true keeps the traversal going over the remaining bases.
+    return true;
+  };
+  Definition->forallBases(CollectFromBase);
+
+  return PureVirtuals;
+}
+
+// Returns the canonical declarations of every base method that some method
+// declared in class D overrides, directly or through a chain of intermediate
+// overriders.
+//
+// The override chain is followed transitively: overridden_methods() reports
+// only the immediately overridden methods, so for A::f -> B::f -> D::f a
+// single-level lookup would record B::f but miss A::f.
+llvm::SetVector<const CXXMethodDecl *>
+getImplementedOrOverriddenCanonicals(const CXXRecordDecl *D) {
+  llvm::SetVector<const CXXMethodDecl *> ImplementedSet;
+
+  // Seed the worklist with the methods declared in D itself.
+  llvm::SmallVector<const CXXMethodDecl *> Worklist;
+  for (const CXXMethodDecl *M : D->methods())
+    Worklist.push_back(M);
+
+  while (!Worklist.empty()) {
+    const CXXMethodDecl *M = Worklist.pop_back_val();
+    for (const CXXMethodDecl *OverriddenM : M->overridden_methods()) {
+      // insert() returns false for methods already recorded, so each method
+      // is expanded at most once.
+      if (ImplementedSet.insert(OverriddenM->getCanonicalDecl()))
+        Worklist.push_back(OverriddenM);
+    }
+  }
+  return ImplementedSet;
+}
+
+// Maps each access specifier written in D to the location of its colon. When
+// the same specifier appears more than once, the location of the last
+// occurrence is the one kept.
+llvm::MapVector<AccessSpecifier, SourceLocation>
+getSpecifierLocations(const CXXRecordDecl *D) {
+  llvm::MapVector<AccessSpecifier, SourceLocation> ColonByAccess;
+  for (const Decl *Member : D->decls()) {
+    const auto *Spec = llvm::dyn_cast<AccessSpecDecl>(Member);
+    if (!Spec)
+      continue;
+    ColonByAccess[Spec->getAccess()] = Spec->getColonLoc();
+  }
+  return ColonByAccess;
+}
+
+// Returns true if at least one *direct* base of CurrentDecl has a definition
+// that is abstract. Indirect bases are not inspected here. Returns false when
+// CurrentDecl is null or has no definition, and ignores bases whose type does
+// not resolve to a defined CXXRecordDecl.
+bool hasAbstractBaseAncestor(const clang::CXXRecordDecl *CurrentDecl) {
+  const clang::CXXRecordDecl *Definition =
+      CurrentDecl ? CurrentDecl->getDefinition() : nullptr;
+  if (!Definition)
+    return false;
+
+  // Take the specifier by const reference to avoid copying it for every base.
+  return llvm::any_of(
+      Definition->bases(), [](const CXXBaseSpecifier &BaseSpec) {
+        const auto *D = BaseSpec.getType()->getAsCXXRecordDecl();
+        const auto *Def = D ? D->getDefinition() : nullptr;
+        return Def && Def->isAbstract();
+      });
+}
+
+// Decides whether the tweak is available: the selection's common ancestor
+// must be a class with a definition, and that definition must be abstract and
+// have an abstract base. On success CurrentDeclDef points at the definition.
+bool OverridePureVirtuals::prepare(const Selection &Sel) {
+  const SelectionTree::Node *Selected = Sel.ASTSelection.commonAncestor();
+  if (!Selected)
+    return false;
+
+  // The selected node must be a class that has a definition.
+  CurrentDeclDef = Selected->ASTNode.get<CXXRecordDecl>();
+  const CXXRecordDecl *Definition =
+      CurrentDeclDef ? CurrentDeclDef->getDefinition() : nullptr;
+  if (!Definition)
+    return false;
+
+  // Everything after this point operates on the definition.
+  CurrentDeclDef = Definition;
+
+  // The tweak is offered only for abstract classes that have abstract bases.
+  if (!CurrentDeclDef->isAbstract())
+    return false;
+  return hasAbstractBaseAncestor(CurrentDeclDef);
+}
+
+// Collects all pure virtual methods that are missing an override in
+// CurrentDeclDef, grouped by their original access specifier, and records the
+// locations of the access specifiers already present in the class.
+// Both result containers are reset on every call, including when
+// CurrentDeclDef is null. The Sel parameter is currently unused.
+void OverridePureVirtuals::collectMissingPureVirtuals(const Selection &Sel) {
+  // Reset first so that the early return below cannot leave stale results
+  // from a previous run behind.
+  MissingMethodsByAccess.clear();
+  AccessSpecifierLocations.clear();
+  if (!CurrentDeclDef)
+    return;
+
+  AccessSpecifierLocations = getSpecifierLocations(CurrentDeclDef);
+
+  // Get all unique pure virtual methods from the entire base class hierarchy.
+  llvm::SmallVector<const CXXMethodDecl *> AllPureVirtualsInHierarchy =
+      getAllUniquePureVirtualsFromBaseHierarchy(CurrentDeclDef);
+
+  // Get methods already implemented or overridden in CurrentDeclDef.
+  const auto ImplementedOrOverriddenSet =
+      getImplementedOrOverriddenCanonicals(CurrentDeclDef);
+
+  // Keep only the pure virtuals that are not in ImplementedOrOverriddenSet;
+  // those are the ones that still need to be overridden.
+  for (const CXXMethodDecl *BaseMethod : AllPureVirtualsInHierarchy) {
+    // The set holds canonical declarations, so the lookup must use the
+    // canonical declaration as well.
+    if (!ImplementedOrOverriddenSet.contains(BaseMethod->getCanonicalDecl()))
+      MissingMethodsByAccess[BaseMethod->getAccess()].emplace_back(BaseMethod);
+  }
+}
+
+// Free function to generate the string for a group of method overrides.
+std::string generateOverridesStringForGroup(
+    llvm::SmallVector<const CXXMethodDecl *> Methods,
+    const LangOptions &LangOpts) {
+  const auto GetParamString = [&LangOpts](const ParmVarDecl *P) {
+    std::string TypeStr = P->getType().getAsString(LangOpts);
+    // Unnamed parameter.
+    if (P->getNameAsString().empty())
+      return TypeStr;
+
+    return llvm::formatv("{0} {1}", std::move(TypeStr), P->getNameAsString())
+        .str();
+  };
+
+  std::string MethodsString;
+  for (const auto *Method : Methods) {
+    llvm::SmallVector<std::string> ParamsAsString;
+    ParamsAsString.reserve(Method->parameters().size());
+    llvm::transform(Method->parameters(), std::back_inserter(ParamsAsString),
+                    GetParamString);
+    auto Params = llvm::join(ParamsAsString, ", ");
+
+    MethodsString +=
+        llvm::formatv(
+            "  {0} {1}({2}) {3}override {{\n"
----------------
marcogmaia wrote:

I've changed the way we are dealing with the method string.

Now we are getting the string representation of the method as it is declared, 
and using the lexer to remove the pure virtual syntax. 

This should deal with the above cases you've mentioned; more tests were added.

https://github.com/llvm/llvm-project/pull/139348
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to