njames93 created this revision.
njames93 added reviewers: alexfh, klimek, aaron.ballman.
Herald added subscribers: carlosgalvezp, xazax.hun.
Herald added a project: All.
njames93 requested review of this revision.
Herald added a project: clang-tools-extra.
Herald added a subscriber: cfe-commits.

Reimplement most of the matching logic with a RecursiveASTVisitor instead of AST matchers; only the if-assigns-bool pattern is still matched with AST matchers.

Benchmark (running the check over SemaCodeComplete.cpp): before 1.02s, after 0.87s.


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D125026

Files:
  clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp
  clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h

Index: clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h
===================================================================
--- clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h
+++ clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.h
@@ -34,73 +34,51 @@
 private:
   class Visitor;
 
-  void reportBinOp(const ast_matchers::MatchFinder::MatchResult &Result,
-                   const BinaryOperator *Op);
-
-  void matchBoolCondition(ast_matchers::MatchFinder *Finder, bool Value,
-                          StringRef BooleanId);
-
-  void matchTernaryResult(ast_matchers::MatchFinder *Finder, bool Value,
-                          StringRef Id);
-
-  void matchIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value,
-                          StringRef Id);
+  /// Diagnoses a binary operator that has a redundant boolean literal operand.
+  void reportBinOp(const ASTContext &Context, const BinaryOperator *Op);
 
   void matchIfAssignsBool(ast_matchers::MatchFinder *Finder, bool Value,
                           StringRef Id);
 
-  void matchCompoundIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value,
-                                  StringRef Id);
-
-  void matchCaseIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value,
-                              StringRef Id);
+  /// Replaces an `if` whose condition is a true literal with its then branch.
+  void replaceWithThenStatement(const ASTContext &Context,
+                                const IfStmt *IfStatement,
+                                const Expr *BoolLiteral);
 
-  void matchDefaultIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value,
-                                 StringRef Id);
+  /// Replaces an `if` whose condition is a false literal with its else
+  /// branch, or removes it when there is no else branch.
+  void replaceWithElseStatement(const ASTContext &Context,
+                                const IfStmt *IfStatement,
+                                const Expr *BoolLiteral);
 
-  void matchLabelIfReturnsBool(ast_matchers::MatchFinder *Finder, bool Value,
-                               StringRef Id);
+  /// Replaces `C ? true : false` (or `C ? false : true`) with `C` (or `!C`).
+  void replaceWithCondition(const ASTContext &Context,
+                            const ConditionalOperator *Ternary, bool Negated);
 
-  void
-  replaceWithThenStatement(const ast_matchers::MatchFinder::MatchResult &Result,
-                           const Expr *BoolLiteral);
-
-  void
-  replaceWithElseStatement(const ast_matchers::MatchFinder::MatchResult &Result,
-                           const Expr *BoolLiteral);
-
-  void
-  replaceWithCondition(const ast_matchers::MatchFinder::MatchResult &Result,
-                       const ConditionalOperator *Ternary, bool Negated);
-
-  void replaceWithReturnCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result, const IfStmt *If,
-      bool Negated);
+  /// Replaces `if (C) return B; else return !B;` with a return of `C` or `!C`.
+  /// \p BoolLiteral (the then-branch return value) locates the diagnostic.
+  void replaceWithReturnCondition(const ASTContext &Context, const IfStmt *If,
+                                  const Expr *BoolLiteral, bool Negated);
 
   void
   replaceWithAssignment(const ast_matchers::MatchFinder::MatchResult &Result,
                         const IfStmt *If, bool Negated);
 
-  void replaceCompoundReturnWithCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result,
-      const CompoundStmt *Compound, bool Negated);
-
-  void replaceCompoundReturnWithCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result, bool Negated,
-      const IfStmt *If);
-
-  void replaceCaseCompoundReturnWithCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result, bool Negated);
-
-  void replaceDefaultCompoundReturnWithCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result, bool Negated);
+  /// Replaces `if (C) return B;` followed by \p Ret (`return !B;`) within
+  /// \p Compound with a single return of the condition.
+  void replaceCompoundReturnWithCondition(const ASTContext &Context,
+                                          const CompoundStmt *Compound,
+                                          const ReturnStmt *Ret, bool Negated);
 
-  void replaceLabelCompoundReturnWithCondition(
-      const ast_matchers::MatchFinder::MatchResult &Result, bool Negated);
+  /// As above, where \p If is the sub-statement of a case, default or label
+  /// statement that is immediately followed by \p Ret.
+  void replaceCompoundReturnWithCondition(const ASTContext &Context,
+                                          const ReturnStmt *Ret, bool Negated,
+                                          const IfStmt *If);
 
-  void issueDiag(const ast_matchers::MatchFinder::MatchResult &Result,
-                 SourceLocation Loc, StringRef Description,
-                 SourceRange ReplacementRange, StringRef Replacement);
+  void issueDiag(const ASTContext &Context, SourceLocation Loc,
+                 StringRef Description, SourceRange ReplacementRange,
+                 StringRef Replacement);
 
   const bool ChainedConditionalReturn;
   const bool ChainedConditionalAssignment;
Index: clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp
===================================================================
--- clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp
+++ clang-tools-extra/clang-tidy/readability/SimplifyBooleanExprCheck.cpp
@@ -10,6 +10,7 @@
 #include "SimplifyBooleanExprMatchers.h"
 #include "clang/AST/RecursiveASTVisitor.h"
 #include "clang/Lex/Lexer.h"
+#include "llvm/ADT/PointerIntPair.h"
 
 #include <string>
 #include <utility>
@@ -22,45 +23,25 @@
 
 namespace {
 
-StringRef getText(const MatchFinder::MatchResult &Result, SourceRange Range) {
+/// Returns the source text spanned by the token range \p Range.
+StringRef getText(const ASTContext &Context, SourceRange Range) {
   return Lexer::getSourceText(CharSourceRange::getTokenRange(Range),
-                              *Result.SourceManager,
-                              Result.Context->getLangOpts());
+                              Context.getSourceManager(),
+                              Context.getLangOpts());
 }
 
-template <typename T>
-StringRef getText(const MatchFinder::MatchResult &Result, T &Node) {
-  return getText(Result, Node.getSourceRange());
+/// Returns the source text spanned by \p Node's source range.
+template <typename T> StringRef getText(const ASTContext &Context, T &Node) {
+  return getText(Context, Node.getSourceRange());
 }
 
 } // namespace
 
-static constexpr char ConditionThenStmtId[] = "if-bool-yields-then";
-static constexpr char ConditionElseStmtId[] = "if-bool-yields-else";
-static constexpr char TernaryId[] = "ternary-bool-yields-condition";
-static constexpr char TernaryNegatedId[] = "ternary-bool-yields-not-condition";
-static constexpr char IfReturnsBoolId[] = "if-return";
-static constexpr char IfReturnsNotBoolId[] = "if-not-return";
-static constexpr char ThenLiteralId[] = "then-literal";
 static constexpr char IfAssignVariableId[] = "if-assign-lvalue";
 static constexpr char IfAssignLocId[] = "if-assign-loc";
 static constexpr char IfAssignBoolId[] = "if-assign";
 static constexpr char IfAssignNotBoolId[] = "if-assign-not";
 static constexpr char IfAssignVarId[] = "if-assign-var";
-static constexpr char CompoundReturnId[] = "compound-return";
-static constexpr char CompoundIfId[] = "compound-if";
-static constexpr char CompoundBoolId[] = "compound-bool";
-static constexpr char CompoundNotBoolId[] = "compound-bool-not";
-static constexpr char CaseId[] = "case";
-static constexpr char CaseCompoundBoolId[] = "case-compound-bool";
-static constexpr char CaseCompoundNotBoolId[] = "case-compound-bool-not";
-static constexpr char DefaultId[] = "default";
-static constexpr char DefaultCompoundBoolId[] = "default-compound-bool";
-static constexpr char DefaultCompoundNotBoolId[] = "default-compound-bool-not";
-static constexpr char LabelId[] = "label";
-static constexpr char LabelCompoundBoolId[] = "label-compound-bool";
-static constexpr char LabelCompoundNotBoolId[] = "label-compound-bool-not";
-static constexpr char IfStmtId[] = "if";
 
 static constexpr char SimplifyOperatorDiagnostic[] =
     "redundant boolean literal supplied to boolean operator";
@@ -69,18 +48,6 @@
 static constexpr char SimplifyConditionalReturnDiagnostic[] =
     "redundant boolean literal in conditional return statement";
 
-static const Expr *getBoolLiteral(const MatchFinder::MatchResult &Result,
-                                  StringRef Id) {
-  if (const Expr *Literal = Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(Id))
-    return Literal->getBeginLoc().isMacroID() ? nullptr : Literal;
-  if (const auto *Negated = Result.Nodes.getNodeAs<UnaryOperator>(Id)) {
-    if (Negated->getOpcode() == UO_LNot &&
-        isa<CXXBoolLiteralExpr>(Negated->getSubExpr()))
-      return Negated->getBeginLoc().isMacroID() ? nullptr : Negated;
-  }
-  return nullptr;
-}
-
 static internal::BindableMatcher<Stmt> literalOrNegatedBool(bool Value) {
   return expr(
       anyOf(cxxBoolLiteral(equals(Value)),
@@ -88,14 +55,6 @@
                           hasOperatorName("!"))));
 }
 
-static internal::Matcher<Stmt> returnsBool(bool Value,
-                                           StringRef Id = "ignored") {
-  auto SimpleReturnsBool = returnStmt(has(literalOrNegatedBool(Value).bind(Id)))
-                               .bind("returns-bool");
-  return anyOf(SimpleReturnsBool,
-               compoundStmt(statementCountIs(1), has(SimpleReturnsBool)));
-}
-
 static bool needsParensAfterUnaryNegation(const Expr *E) {
   E = E->IgnoreImpCasts();
   if (isa<BinaryOperator>(E) || isa<ConditionalOperator>(E))
@@ -192,32 +151,34 @@
   return !E->getType()->isBooleanType();
 }
 
-static std::string
-compareExpressionToConstant(const MatchFinder::MatchResult &Result,
-                            const Expr *E, bool Negated, const char *Constant) {
+/// Builds `E == Constant` (or `E != Constant` when \p Negated) as source text.
+/// Binary operators are parenthesized.
+static std::string compareExpressionToConstant(const ASTContext &Context,
+                                               const Expr *E, bool Negated,
+                                               const char *Constant) {
   E = E->IgnoreImpCasts();
   const std::string ExprText =
-      (isa<BinaryOperator>(E) ? ("(" + getText(Result, *E) + ")")
-                              : getText(Result, *E))
+      (isa<BinaryOperator>(E) ? ("(" + getText(Context, *E) + ")")
+                              : getText(Context, *E))
           .str();
   return ExprText + " " + (Negated ? "!=" : "==") + " " + Constant;
 }
 
-static std::string
-compareExpressionToNullPtr(const MatchFinder::MatchResult &Result,
-                           const Expr *E, bool Negated) {
-  const char *NullPtr =
-      Result.Context->getLangOpts().CPlusPlus11 ? "nullptr" : "NULL";
-  return compareExpressionToConstant(Result, E, Negated, NullPtr);
+/// Compares \p E against `nullptr` (C++11 and later) or `NULL`.
+static std::string compareExpressionToNullPtr(const ASTContext &Context,
+                                              const Expr *E, bool Negated) {
+  const char *NullPtr = Context.getLangOpts().CPlusPlus11 ? "nullptr" : "NULL";
+  return compareExpressionToConstant(Context, E, Negated, NullPtr);
 }
 
-static std::string
-compareExpressionToZero(const MatchFinder::MatchResult &Result, const Expr *E,
-                        bool Negated) {
-  return compareExpressionToConstant(Result, E, Negated, "0");
+/// Compares \p E against `0`.
+static std::string compareExpressionToZero(const ASTContext &Context,
+                                           const Expr *E, bool Negated) {
+  return compareExpressionToConstant(Context, E, Negated, "0");
 }
 
-static std::string replacementExpression(const MatchFinder::MatchResult &Result,
+/// Builds replacement text for condition \p E, negated when \p Negated is set.
+static std::string replacementExpression(const ASTContext &Context,
                                          bool Negated, const Expr *E) {
   E = E->IgnoreParenBaseCasts();
   if (const auto *EC = dyn_cast<ExprWithCleanups>(E))
@@ -228,20 +184,20 @@
     if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
       if (UnOp->getOpcode() == UO_LNot) {
         if (needsNullPtrComparison(UnOp->getSubExpr()))
-          return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), true);
+          return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), true);
 
         if (needsZeroComparison(UnOp->getSubExpr()))
-          return compareExpressionToZero(Result, UnOp->getSubExpr(), true);
+          return compareExpressionToZero(Context, UnOp->getSubExpr(), true);
 
-        return replacementExpression(Result, false, UnOp->getSubExpr());
+        return replacementExpression(Context, false, UnOp->getSubExpr());
       }
     }
 
     if (needsNullPtrComparison(E))
-      return compareExpressionToNullPtr(Result, E, false);
+      return compareExpressionToNullPtr(Context, E, false);
 
     if (needsZeroComparison(E))
-      return compareExpressionToZero(Result, E, false);
+      return compareExpressionToZero(Context, E, false);
 
     StringRef NegatedOperator;
     const Expr *LHS = nullptr;
@@ -258,20 +214,20 @@
       }
     }
     if (!NegatedOperator.empty() && LHS && RHS)
-      return (asBool((getText(Result, *LHS) + " " + NegatedOperator + " " +
-                      getText(Result, *RHS))
+      return (asBool((getText(Context, *LHS) + " " + NegatedOperator + " " +
+                      getText(Context, *RHS))
                          .str(),
                      NeedsStaticCast));
 
-    StringRef Text = getText(Result, *E);
+    StringRef Text = getText(Context, *E);
     if (!NeedsStaticCast && needsParensAfterUnaryNegation(E))
       return ("!(" + Text + ")").str();
 
     if (needsNullPtrComparison(E))
-      return compareExpressionToNullPtr(Result, E, false);
+      return compareExpressionToNullPtr(Context, E, false);
 
     if (needsZeroComparison(E))
-      return compareExpressionToZero(Result, E, false);
+      return compareExpressionToZero(Context, E, false);
 
     return ("!" + asBool(Text, NeedsStaticCast));
   }
@@ -279,20 +235,20 @@
   if (const auto *UnOp = dyn_cast<UnaryOperator>(E)) {
     if (UnOp->getOpcode() == UO_LNot) {
       if (needsNullPtrComparison(UnOp->getSubExpr()))
-        return compareExpressionToNullPtr(Result, UnOp->getSubExpr(), false);
+        return compareExpressionToNullPtr(Context, UnOp->getSubExpr(), false);
 
       if (needsZeroComparison(UnOp->getSubExpr()))
-        return compareExpressionToZero(Result, UnOp->getSubExpr(), false);
+        return compareExpressionToZero(Context, UnOp->getSubExpr(), false);
     }
   }
 
   if (needsNullPtrComparison(E))
-    return compareExpressionToNullPtr(Result, E, true);
+    return compareExpressionToNullPtr(Context, E, true);
 
   if (needsZeroComparison(E))
-    return compareExpressionToZero(Result, E, true);
+    return compareExpressionToZero(Context, E, true);
 
-  return asBool(getText(Result, *E), NeedsStaticCast);
+  return asBool(getText(Context, *E), NeedsStaticCast);
 }
 
 static const Expr *stmtReturnsBool(const ReturnStmt *Ret, bool Negated) {
@@ -330,14 +286,14 @@
   return nullptr;
 }
 
-static bool containsDiscardedTokens(const MatchFinder::MatchResult &Result,
+static bool containsDiscardedTokens(const ASTContext &Context,
                                     CharSourceRange CharRange) {
   std::string ReplacementText =
-      Lexer::getSourceText(CharRange, *Result.SourceManager,
-                           Result.Context->getLangOpts())
+      Lexer::getSourceText(CharRange, Context.getSourceManager(),
+                           Context.getLangOpts())
           .str();
-  Lexer Lex(CharRange.getBegin(), Result.Context->getLangOpts(),
-            ReplacementText.data(), ReplacementText.data(),
+  Lexer Lex(CharRange.getBegin(), Context.getLangOpts(), ReplacementText.data(),
+            ReplacementText.data(),
             ReplacementText.data() + ReplacementText.size());
   Lex.SetCommentRetentionState(true);
 
@@ -352,18 +308,174 @@
 
 class SimplifyBooleanExprCheck::Visitor : public RecursiveASTVisitor<Visitor> {
 public:
-  Visitor(SimplifyBooleanExprCheck *Check,
-          const MatchFinder::MatchResult &Result)
-      : Check(Check), Result(Result) {}
+  Visitor(SimplifyBooleanExprCheck *Check, ASTContext &Context)
+      : Check(Check), Context(Context) {}
+
+  /// Traverses the whole AST owned by the context.
+  bool traverse() { return TraverseAST(Context); }
 
   bool VisitBinaryOperator(const BinaryOperator *Op) const {
-    Check->reportBinOp(Result, Op);
+    Check->reportBinOp(Context, Op);
+    return true;
+  }
+
+  /// Returns the value of \p E if it is a bool literal, possibly wrapped in
+  /// logical negations. With \p FilterMacro set, a literal or negation that
+  /// begins in a macro expansion is rejected.
+  static Optional<bool> getAsBoolLiteral(const Expr *E, bool FilterMacro) {
+    if (const auto *Bool = dyn_cast<CXXBoolLiteralExpr>(E)) {
+      if (FilterMacro && Bool->getBeginLoc().isMacroID())
+        return llvm::None;
+      return Bool->getValue();
+    }
+    if (const auto *UOp = dyn_cast<UnaryOperator>(E)) {
+      if (FilterMacro && UOp->getBeginLoc().isMacroID())
+        return llvm::None;
+      if (UOp->getOpcode() == UO_LNot)
+        if (Optional<bool> Res = getAsBoolLiteral(
+                UOp->getSubExpr()->IgnoreImplicit(), FilterMacro))
+          return !*Res;
+    }
+    return llvm::None;
+  }
+
+  /// If \p S is a return of a (possibly negated) bool literal, returns the
+  /// returned expression paired with its boolean value; otherwise returns a
+  /// pair holding a null pointer.
+  static llvm::PointerIntPair<const Expr *, 1, bool>
+  parseReturnLiteralBool(const Stmt *S) {
+    const auto *RS = dyn_cast<ReturnStmt>(S);
+    if (!RS || !RS->getRetValue())
+      return {};
+    if (auto Ret =
+            getAsBoolLiteral(RS->getRetValue()->IgnoreImplicit(), false)) {
+      return {RS->getRetValue(), *Ret};
+    }
+    return {};
+  }
+
+  /// Applies \p F to \p S or, if \p S is a compound statement, to its sole
+  /// child. Compound statements with zero or several children yield a
+  /// default-constructed result.
+  template <typename Functor>
+  static auto checkSingleStatement(Stmt *S, Functor F) -> decltype(F(S)) {
+    if (auto *CS = dyn_cast<CompoundStmt>(S)) {
+      if (CS->size() == 1)
+        return F(CS->body_front());
+      return {};
+    }
+    return F(S);
+  }
+
+  /// Returns true if the first parent of \p If is itself an if statement.
+  bool doesIfHaveIfParent(const IfStmt *If) {
+    auto Parents = Context.getParents(*If);
+    if (Parents.empty())
+      return false;
+    return Parents[0].get<IfStmt>() != nullptr;
+  }
+
+  /// Returns false for ifs that cannot be simplified safely: those with an
+  /// init statement or condition variable, which the rewrite would drop, and
+  /// `if consteval`, which has no condition expression.
+  static bool isRewritableIf(const IfStmt *If) {
+    return !If->hasInitStorage() && !If->hasVarStorage() && !If->isConsteval();
+  }
+
+  /// Handles an if with a literal condition, and
+  /// `if (C) return B; else return !B;`.
+  bool VisitIfStmt(IfStmt *If) {
+    if (!isRewritableIf(If))
+      return true;
+
+    Expr *Cond = If->getCond()->IgnoreImplicit();
+    if (auto Bool = getAsBoolLiteral(Cond, true)) {
+      if (*Bool)
+        Check->replaceWithThenStatement(Context, If, Cond);
+      else
+        Check->replaceWithElseStatement(Context, If, Cond);
+    }
+
+    if (If->getElse()) {
+      auto ThenReturnBool =
+          checkSingleStatement(If->getThen(), parseReturnLiteralBool);
+      if (ThenReturnBool.getPointer()) {
+        auto ElseReturnBool =
+            checkSingleStatement(If->getElse(), parseReturnLiteralBool);
+        if (ElseReturnBool.getPointer() &&
+            ThenReturnBool.getInt() != ElseReturnBool.getInt()) {
+          if (Check->ChainedConditionalReturn || !doesIfHaveIfParent(If)) {
+            Check->replaceWithReturnCondition(Context, If,
+                                              ThenReturnBool.getPointer(),
+                                              !ThenReturnBool.getInt());
+          }
+        }
+      }
+    }
+    return true;
+  }
+
+  /// Handles `C ? B : !B` where both arms are (possibly negated) literals.
+  bool VisitConditionalOperator(ConditionalOperator *Cond) {
+    if (auto Then =
+            getAsBoolLiteral(Cond->getTrueExpr()->IgnoreImplicit(), false)) {
+      if (auto Else =
+              getAsBoolLiteral(Cond->getFalseExpr()->IgnoreImplicit(), false)) {
+        if (*Then != *Else)
+          Check->replaceWithCondition(Context, Cond, *Else);
+      }
+    }
+    return true;
+  }
+
+  /// Looks for `if (C) return B;` -- directly or as the sub-statement of a
+  /// label, case or default statement -- immediately followed by
+  /// `return !B;`.
+  bool VisitCompoundStmt(CompoundStmt *CS) {
+    if (CS->size() < 2)
+      return true;
+    for (auto Second = CS->body_rbegin(), First = std::next(Second),
+              End = CS->body_rend();
+         First != End; ++Second, ++First) {
+      auto RetStmt = parseReturnLiteralBool(*Second);
+      if (!RetStmt.getPointer())
+        continue;
+
+      if (auto *If = dyn_cast<IfStmt>(*First)) {
+        // Folding the if into the following return would lose any else branch.
+        if (If->getElse() || !isRewritableIf(If))
+          continue;
+        auto ThenReturnBool =
+            checkSingleStatement(If->getThen(), parseReturnLiteralBool);
+        if (ThenReturnBool.getPointer() &&
+            ThenReturnBool.getInt() != RetStmt.getInt()) {
+          Check->replaceCompoundReturnWithCondition(
+              Context, CS, cast<ReturnStmt>(*Second), RetStmt.getInt());
+        }
+      } else if (isa<LabelStmt, CaseStmt, DefaultStmt>(*First)) {
+        Stmt *SubStmt =
+            isa<LabelStmt>(*First)  ? cast<LabelStmt>(*First)->getSubStmt()
+            : isa<CaseStmt>(*First) ? cast<CaseStmt>(*First)->getSubStmt()
+                                    : cast<DefaultStmt>(*First)->getSubStmt();
+        if (auto *SubIf = dyn_cast<IfStmt>(SubStmt)) {
+          if (!SubIf->getElse() && isRewritableIf(SubIf)) {
+            auto ThenReturnBool =
+                checkSingleStatement(SubIf->getThen(), parseReturnLiteralBool);
+            if (ThenReturnBool.getPointer() &&
+                ThenReturnBool.getInt() != RetStmt.getInt()) {
+              Check->replaceCompoundReturnWithCondition(
+                  Context, cast<ReturnStmt>(*Second), RetStmt.getInt(), SubIf);
+            }
+          }
+        }
+      }
+    }
     return true;
   }
 
 private:
   SimplifyBooleanExprCheck *Check;
-  const MatchFinder::MatchResult &Result;
+  ASTContext &Context;
 };
 
 SimplifyBooleanExprCheck::SimplifyBooleanExprCheck(StringRef Name,
@@ -387,8 +472,9 @@
   return false;
 }
 
-void SimplifyBooleanExprCheck::reportBinOp(
-    const MatchFinder::MatchResult &Result, const BinaryOperator *Op) {
+/// Diagnoses a redundant boolean literal operand of the binary operator \p Op.
+void SimplifyBooleanExprCheck::reportBinOp(const ASTContext &Context,
+                                           const BinaryOperator *Op) {
   const auto *LHS = Op->getLHS()->IgnoreParenImpCasts();
   const auto *RHS = Op->getRHS()->IgnoreParenImpCasts();
 
@@ -410,12 +495,12 @@
 
   bool BoolValue = Bool->getValue();
 
-  auto ReplaceWithExpression = [this, &Result, LHS, RHS,
+  auto ReplaceWithExpression = [this, &Context, LHS, RHS,
                                 Bool](const Expr *ReplaceWith, bool Negated) {
     std::string Replacement =
-        replacementExpression(Result, Negated, ReplaceWith);
+        replacementExpression(Context, Negated, ReplaceWith);
     SourceRange Range(LHS->getBeginLoc(), RHS->getEndLoc());
-    issueDiag(Result, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range,
+    issueDiag(Context, Bool->getBeginLoc(), SimplifyOperatorDiagnostic, Range,
               Replacement);
   };
 
@@ -449,39 +534,8 @@
   }
 }
 
-void SimplifyBooleanExprCheck::matchBoolCondition(MatchFinder *Finder,
-                                                  bool Value,
-                                                  StringRef BooleanId) {
-  Finder->addMatcher(
-      ifStmt(hasCondition(literalOrNegatedBool(Value).bind(BooleanId)))
-          .bind(IfStmtId),
-      this);
-}
-
-void SimplifyBooleanExprCheck::matchTernaryResult(MatchFinder *Finder,
-                                                  bool Value, StringRef Id) {
-  Finder->addMatcher(
-      conditionalOperator(hasTrueExpression(literalOrNegatedBool(Value)),
-                          hasFalseExpression(literalOrNegatedBool(!Value)))
-          .bind(Id),
-      this);
-}
-
-void SimplifyBooleanExprCheck::matchIfReturnsBool(MatchFinder *Finder,
-                                                  bool Value, StringRef Id) {
-  if (ChainedConditionalReturn)
-    Finder->addMatcher(ifStmt(hasThen(returnsBool(Value, ThenLiteralId)),
-                              hasElse(returnsBool(!Value)))
-                           .bind(Id),
-                       this);
-  else
-    Finder->addMatcher(ifStmt(unless(hasParent(ifStmt())),
-                              hasThen(returnsBool(Value, ThenLiteralId)),
-                              hasElse(returnsBool(!Value)))
-                           .bind(Id),
-                       this);
-}
-
+// The if-assigns-bool pattern is still detected with AST matchers; the other
+// patterns are found by SimplifyBooleanExprCheck::Visitor.
 void SimplifyBooleanExprCheck::matchIfAssignsBool(MatchFinder *Finder,
                                                   bool Value, StringRef Id) {
   auto VarAssign = declRefExpr(hasDeclaration(decl().bind(IfAssignVarId)));
@@ -508,68 +560,6 @@
         this);
 }
 
-static internal::Matcher<Stmt> ifReturnValue(bool Value) {
-  return ifStmt(hasThen(returnsBool(Value)), unless(hasElse(stmt())))
-      .bind(CompoundIfId);
-}
-
-static internal::Matcher<Stmt> returnNotValue(bool Value) {
-  return returnStmt(has(literalOrNegatedBool(!Value))).bind(CompoundReturnId);
-}
-
-void SimplifyBooleanExprCheck::matchCompoundIfReturnsBool(MatchFinder *Finder,
-                                                          bool Value,
-                                                          StringRef Id) {
-  if (ChainedConditionalReturn)
-    Finder->addMatcher(
-        compoundStmt(hasSubstatementSequence(ifReturnValue(Value),
-                                             returnNotValue(Value)))
-            .bind(Id),
-        this);
-  else
-    Finder->addMatcher(
-        compoundStmt(hasSubstatementSequence(ifStmt(hasThen(returnsBool(Value)),
-                                                    unless(hasElse(stmt())),
-                                                    unless(hasParent(ifStmt())))
-                                                 .bind(CompoundIfId),
-                                             returnNotValue(Value)))
-            .bind(Id),
-        this);
-}
-
-void SimplifyBooleanExprCheck::matchCaseIfReturnsBool(MatchFinder *Finder,
-                                                      bool Value,
-                                                      StringRef Id) {
-  internal::Matcher<Stmt> CaseStmt =
-      caseStmt(hasSubstatement(ifReturnValue(Value))).bind(CaseId);
-  internal::Matcher<Stmt> CompoundStmt =
-      compoundStmt(hasSubstatementSequence(CaseStmt, returnNotValue(Value)))
-          .bind(Id);
-  Finder->addMatcher(switchStmt(has(CompoundStmt)), this);
-}
-
-void SimplifyBooleanExprCheck::matchDefaultIfReturnsBool(MatchFinder *Finder,
-                                                         bool Value,
-                                                         StringRef Id) {
-  internal::Matcher<Stmt> DefaultStmt =
-      defaultStmt(hasSubstatement(ifReturnValue(Value))).bind(DefaultId);
-  internal::Matcher<Stmt> CompoundStmt =
-      compoundStmt(hasSubstatementSequence(DefaultStmt, returnNotValue(Value)))
-          .bind(Id);
-  Finder->addMatcher(switchStmt(has(CompoundStmt)), this);
-}
-
-void SimplifyBooleanExprCheck::matchLabelIfReturnsBool(MatchFinder *Finder,
-                                                       bool Value,
-                                                       StringRef Id) {
-  internal::Matcher<Stmt> LabelStmt =
-      labelStmt(hasSubstatement(ifReturnValue(Value))).bind(LabelId);
-  internal::Matcher<Stmt> CompoundStmt =
-      compoundStmt(hasSubstatementSequence(LabelStmt, returnNotValue(Value)))
-          .bind(Id);
-  Finder->addMatcher(CompoundStmt, this);
-}
-
 void SimplifyBooleanExprCheck::storeOptions(ClangTidyOptions::OptionMap &Opts) {
   Options.store(Opts, "ChainedConditionalReturn", ChainedConditionalReturn);
   Options.store(Opts, "ChainedConditionalAssignment",
@@ -579,135 +569,80 @@
 void SimplifyBooleanExprCheck::registerMatchers(MatchFinder *Finder) {
   Finder->addMatcher(translationUnitDecl().bind("top"), this);
 
-  matchBoolCondition(Finder, true, ConditionThenStmtId);
-  matchBoolCondition(Finder, false, ConditionElseStmtId);
-
-  matchTernaryResult(Finder, true, TernaryId);
-  matchTernaryResult(Finder, false, TernaryNegatedId);
-
-  matchIfReturnsBool(Finder, true, IfReturnsBoolId);
-  matchIfReturnsBool(Finder, false, IfReturnsNotBoolId);
-
   matchIfAssignsBool(Finder, true, IfAssignBoolId);
   matchIfAssignsBool(Finder, false, IfAssignNotBoolId);
-
-  matchCompoundIfReturnsBool(Finder, true, CompoundBoolId);
-  matchCompoundIfReturnsBool(Finder, false, CompoundNotBoolId);
-
-  matchCaseIfReturnsBool(Finder, true, CaseCompoundBoolId);
-  matchCaseIfReturnsBool(Finder, false, CaseCompoundNotBoolId);
-
-  matchDefaultIfReturnsBool(Finder, true, DefaultCompoundBoolId);
-  matchDefaultIfReturnsBool(Finder, false, DefaultCompoundNotBoolId);
-
-  matchLabelIfReturnsBool(Finder, true, LabelCompoundBoolId);
-  matchLabelIfReturnsBool(Finder, false, LabelCompoundNotBoolId);
 }
 
 void SimplifyBooleanExprCheck::check(const MatchFinder::MatchResult &Result) {
   if (Result.Nodes.getNodeAs<TranslationUnitDecl>("top"))
-    Visitor(this, Result).TraverseAST(*Result.Context);
-  else if (const Expr *TrueConditionRemoved =
-               getBoolLiteral(Result, ConditionThenStmtId))
-    replaceWithThenStatement(Result, TrueConditionRemoved);
-  else if (const Expr *FalseConditionRemoved =
-               getBoolLiteral(Result, ConditionElseStmtId))
-    replaceWithElseStatement(Result, FalseConditionRemoved);
-  else if (const auto *Ternary =
-               Result.Nodes.getNodeAs<ConditionalOperator>(TernaryId))
-    replaceWithCondition(Result, Ternary, false);
-  else if (const auto *TernaryNegated =
-               Result.Nodes.getNodeAs<ConditionalOperator>(TernaryNegatedId))
-    replaceWithCondition(Result, TernaryNegated, true);
-  else if (const auto *If = Result.Nodes.getNodeAs<IfStmt>(IfReturnsBoolId))
-    replaceWithReturnCondition(Result, If, false);
-  else if (const auto *IfNot =
-               Result.Nodes.getNodeAs<IfStmt>(IfReturnsNotBoolId))
-    replaceWithReturnCondition(Result, IfNot, true);
+    Visitor(this, *Result.Context).traverse();
   else if (const auto *IfAssign =
                Result.Nodes.getNodeAs<IfStmt>(IfAssignBoolId))
     replaceWithAssignment(Result, IfAssign, false);
   else if (const auto *IfAssignNot =
                Result.Nodes.getNodeAs<IfStmt>(IfAssignNotBoolId))
     replaceWithAssignment(Result, IfAssignNot, true);
-  else if (const auto *Compound =
-               Result.Nodes.getNodeAs<CompoundStmt>(CompoundBoolId))
-    replaceCompoundReturnWithCondition(Result, Compound, false);
-  else if (const auto *CompoundNot =
-               Result.Nodes.getNodeAs<CompoundStmt>(CompoundNotBoolId))
-    replaceCompoundReturnWithCondition(Result, CompoundNot, true);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(CaseCompoundBoolId))
-    replaceCaseCompoundReturnWithCondition(Result, false);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(CaseCompoundNotBoolId))
-    replaceCaseCompoundReturnWithCondition(Result, true);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(DefaultCompoundBoolId))
-    replaceDefaultCompoundReturnWithCondition(Result, false);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(DefaultCompoundNotBoolId))
-    replaceDefaultCompoundReturnWithCondition(Result, true);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(LabelCompoundBoolId))
-    replaceLabelCompoundReturnWithCondition(Result, false);
-  else if (Result.Nodes.getNodeAs<CompoundStmt>(LabelCompoundNotBoolId))
-    replaceLabelCompoundReturnWithCondition(Result, true);
-  else if (const auto TU = Result.Nodes.getNodeAs<Decl>("top"))
-    Visitor(this, Result).TraverseDecl(const_cast<Decl *>(TU));
-}
-
-void SimplifyBooleanExprCheck::issueDiag(const MatchFinder::MatchResult &Result,
+}
+
+/// Reports \p Description at \p Loc, attaching a fix-it that replaces
+/// \p ReplacementRange with \p Replacement unless containsDiscardedTokens()
+/// flags the range.
+void SimplifyBooleanExprCheck::issueDiag(const ASTContext &Context,
                                          SourceLocation Loc,
                                          StringRef Description,
                                          SourceRange ReplacementRange,
                                          StringRef Replacement) {
   CharSourceRange CharRange =
       Lexer::makeFileCharRange(CharSourceRange::getTokenRange(ReplacementRange),
-                               *Result.SourceManager, getLangOpts());
+                               Context.getSourceManager(), getLangOpts());
 
   DiagnosticBuilder Diag = diag(Loc, Description);
-  if (!containsDiscardedTokens(Result, CharRange))
+  if (!containsDiscardedTokens(Context, CharRange))
     Diag << FixItHint::CreateReplacement(CharRange, Replacement);
 }
 
 void SimplifyBooleanExprCheck::replaceWithThenStatement(
-    const MatchFinder::MatchResult &Result, const Expr *BoolLiteral) {
-  const auto *IfStatement = Result.Nodes.getNodeAs<IfStmt>(IfStmtId);
-  issueDiag(Result, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
+    const ASTContext &Context, const IfStmt *IfStatement,
+    const Expr *BoolLiteral) {
+  issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
             IfStatement->getSourceRange(),
-            getText(Result, *IfStatement->getThen()));
+            getText(Context, *IfStatement->getThen()));
 }
 
 void SimplifyBooleanExprCheck::replaceWithElseStatement(
-    const MatchFinder::MatchResult &Result, const Expr *BoolLiteral) {
-  const auto *IfStatement = Result.Nodes.getNodeAs<IfStmt>(IfStmtId);
+    const ASTContext &Context, const IfStmt *IfStatement,
+    const Expr *BoolLiteral) {
   const Stmt *ElseStatement = IfStatement->getElse();
-  issueDiag(Result, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
+  issueDiag(Context, BoolLiteral->getBeginLoc(), SimplifyConditionDiagnostic,
             IfStatement->getSourceRange(),
-            ElseStatement ? getText(Result, *ElseStatement) : "");
+            ElseStatement ? getText(Context, *ElseStatement) : "");
 }
 
 void SimplifyBooleanExprCheck::replaceWithCondition(
-    const MatchFinder::MatchResult &Result, const ConditionalOperator *Ternary,
+    const ASTContext &Context, const ConditionalOperator *Ternary,
     bool Negated) {
   std::string Replacement =
-      replacementExpression(Result, Negated, Ternary->getCond());
-  issueDiag(Result, Ternary->getTrueExpr()->getBeginLoc(),
+      replacementExpression(Context, Negated, Ternary->getCond());
+  issueDiag(Context, Ternary->getTrueExpr()->getBeginLoc(),
             "redundant boolean literal in ternary expression result",
             Ternary->getSourceRange(), Replacement);
 }
 
 void SimplifyBooleanExprCheck::replaceWithReturnCondition(
-    const MatchFinder::MatchResult &Result, const IfStmt *If, bool Negated) {
+    const ASTContext &Context, const IfStmt *If, const Expr *BoolLiteral,
+    bool Negated) {
   StringRef Terminator = isa<CompoundStmt>(If->getElse()) ? ";" : "";
-  std::string Condition = replacementExpression(Result, Negated, If->getCond());
+  std::string Condition =
+      replacementExpression(Context, Negated, If->getCond());
   std::string Replacement = ("return " + Condition + Terminator).str();
-  SourceLocation Start =
-      Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(ThenLiteralId)->getBeginLoc();
-  issueDiag(Result, Start, SimplifyConditionalReturnDiagnostic,
+  SourceLocation Start = BoolLiteral->getBeginLoc();
+  issueDiag(Context, Start, SimplifyConditionalReturnDiagnostic,
             If->getSourceRange(), Replacement);
 }
 
 void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition(
-    const MatchFinder::MatchResult &Result, const CompoundStmt *Compound,
-    bool Negated) {
-  const auto *Ret = Result.Nodes.getNodeAs<ReturnStmt>(CompoundReturnId);
+    const ASTContext &Context, const CompoundStmt *Compound,
+    const ReturnStmt *Ret, bool Negated) {
 
   // Scan through the CompoundStmt to look for a chained-if construct.
   const IfStmt *BeforeIf = nullptr;
@@ -722,9 +654,10 @@
             continue;
 
           std::string Replacement =
-              "return " + replacementExpression(Result, Negated, If->getCond());
+              "return " +
+              replacementExpression(Context, Negated, If->getCond());
           issueDiag(
-              Result, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
+              Context, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
               SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement);
           return;
         }
@@ -738,51 +671,29 @@
 }
 
 void SimplifyBooleanExprCheck::replaceCompoundReturnWithCondition(
-    const MatchFinder::MatchResult &Result, bool Negated, const IfStmt *If) {
+    const ASTContext &Context, const ReturnStmt *Ret, bool Negated,
+    const IfStmt *If) { // Ret: end of fix range.
   const auto *Lit = stmtReturnsBool(If, Negated);
-  const auto *Ret = Result.Nodes.getNodeAs<ReturnStmt>(CompoundReturnId);
   const std::string Replacement =
-      "return " + replacementExpression(Result, Negated, If->getCond());
-  issueDiag(Result, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
+      "return " + replacementExpression(Context, Negated, If->getCond());
+  issueDiag(Context, Lit->getBeginLoc(), SimplifyConditionalReturnDiagnostic,
             SourceRange(If->getBeginLoc(), Ret->getEndLoc()), Replacement);
 }
 
-void SimplifyBooleanExprCheck::replaceCaseCompoundReturnWithCondition(
-    const MatchFinder::MatchResult &Result, bool Negated) {
-  const auto *CaseDefault = Result.Nodes.getNodeAs<CaseStmt>(CaseId);
-  const auto *If = cast<IfStmt>(CaseDefault->getSubStmt());
-  replaceCompoundReturnWithCondition(Result, Negated, If);
-}
-
-void SimplifyBooleanExprCheck::replaceDefaultCompoundReturnWithCondition(
-    const MatchFinder::MatchResult &Result, bool Negated) {
-  const SwitchCase *CaseDefault =
-      Result.Nodes.getNodeAs<DefaultStmt>(DefaultId);
-  const auto *If = cast<IfStmt>(CaseDefault->getSubStmt());
-  replaceCompoundReturnWithCondition(Result, Negated, If);
-}
-
-void SimplifyBooleanExprCheck::replaceLabelCompoundReturnWithCondition(
-    const MatchFinder::MatchResult &Result, bool Negated) {
-  const auto *Label = Result.Nodes.getNodeAs<LabelStmt>(LabelId);
-  const auto *If = cast<IfStmt>(Label->getSubStmt());
-  replaceCompoundReturnWithCondition(Result, Negated, If);
-}
-
 void SimplifyBooleanExprCheck::replaceWithAssignment(
     const MatchFinder::MatchResult &Result, const IfStmt *IfAssign,
     bool Negated) {
   SourceRange Range = IfAssign->getSourceRange();
-  StringRef VariableName =
-      getText(Result, *Result.Nodes.getNodeAs<Expr>(IfAssignVariableId));
+  StringRef VariableName = getText(
+      *Result.Context, *Result.Nodes.getNodeAs<Expr>(IfAssignVariableId));
   StringRef Terminator = isa<CompoundStmt>(IfAssign->getElse()) ? ";" : "";
   std::string Condition =
-      replacementExpression(Result, Negated, IfAssign->getCond());
+      replacementExpression(*Result.Context, Negated, IfAssign->getCond());
   std::string Replacement =
       (VariableName + " = " + Condition + Terminator).str();
   SourceLocation Location =
       Result.Nodes.getNodeAs<CXXBoolLiteralExpr>(IfAssignLocId)->getBeginLoc();
-  issueDiag(Result, Location,
+  issueDiag(*Result.Context, Location, // NOTE(review): still MatchResult.
             "redundant boolean literal in conditional assignment", Range,
             Replacement);
 }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to