Author: Ryosuke Niwa
Date: 2024-05-24T17:48:36-07:00
New Revision: 1c90de5fe3d9f3d4048ba7e4aba2fd1613843f34

URL: 
https://github.com/llvm/llvm-project/commit/1c90de5fe3d9f3d4048ba7e4aba2fd1613843f34
DIFF: 
https://github.com/llvm/llvm-project/commit/1c90de5fe3d9f3d4048ba7e4aba2fd1613843f34.diff

LOG: [analyzer] Allow recursive functions to be trivial. (#91876)

Added: 
    

Modified: 
    clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
    clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp

Removed: 
    


################################################################################
diff  --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 5c797d5233089..49bbff1942167 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -271,6 +271,43 @@ class TrivialFunctionAnalysisVisitor
 
   TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
 
+  bool IsFunctionTrivial(const Decl *D) {
+    auto CacheIt = Cache.find(D);
+    if (CacheIt != Cache.end())
+      return CacheIt->second;
+
+    // Treat a recursive function call to be trivial until proven otherwise.
+    auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(D, true));
+    if (!IsNew)
+      return RecursiveIt->second;
+
+    bool Result = [&]() {
+      if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
+        for (auto *CtorInit : CtorDecl->inits()) {
+          if (!Visit(CtorInit->getInit()))
+            return false;
+        }
+      }
+      const Stmt *Body = D->getBody();
+      if (!Body)
+        return false;
+      return Visit(Body);
+    }();
+
+    if (!Result) {
+      // D and its mutually recursive callers are all non-trivial.
+      for (auto &It : RecursiveFn)
+        It.second = false;
+    }
+    RecursiveIt = RecursiveFn.find(D);
+    assert(RecursiveIt != RecursiveFn.end());
+    Result = RecursiveIt->second;
+    RecursiveFn.erase(RecursiveIt);
+    Cache[D] = Result;
+
+    return Result;
+  }
+
   bool VisitStmt(const Stmt *S) {
     // All statements are non-trivial unless overriden later.
     // Don't even recurse into children by default.
@@ -368,7 +405,7 @@ class TrivialFunctionAnalysisVisitor
         Name == "bitwise_cast" || Name.find("__builtin") == 0)
       return true;
 
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool
@@ -403,7 +440,7 @@ class TrivialFunctionAnalysisVisitor
       return true;
 
     // Recursively descend into the callee to confirm that it's trivial as well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
@@ -413,7 +450,7 @@ class TrivialFunctionAnalysisVisitor
     if (!Callee)
       return false;
     // Recursively descend into the callee to confirm that it's trivial as well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -439,7 +476,7 @@ class TrivialFunctionAnalysisVisitor
     }
 
     // Recursively descend into the callee to confirm that it's trivial.
-    return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
+    return IsFunctionTrivial(CE->getConstructor());
   }
 
   bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }
@@ -513,36 +550,13 @@ class TrivialFunctionAnalysisVisitor
 
 private:
   CacheTy &Cache;
+  CacheTy RecursiveFn;
 };
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
     const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
-  // If the function isn't in the cache, conservatively assume that
-  // it's not trivial until analysis completes. This makes every recursive
-  // function non-trivial. This also guarantees that each function
-  // will be scanned at most once.
-  auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
-  if (!IsNew)
-    return It->second;
-
   TrivialFunctionAnalysisVisitor V(Cache);
-
-  if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
-    for (auto *CtorInit : CtorDecl->inits()) {
-      if (!V.Visit(CtorInit->getInit()))
-        return false;
-    }
-  }
-
-  const Stmt *Body = D->getBody();
-  if (!Body)
-    return false;
-
-  bool Result = V.Visit(Body);
-  if (Result)
-    Cache[D] = true;
-
-  return Result;
+  return V.IsFunctionTrivial(D);
 }
 
 bool TrivialFunctionAnalysis::isTrivialImpl(

diff  --git a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
index 96986631726fe..a98c6eb9c84d9 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
@@ -231,6 +231,18 @@ class RefCounted {
   void method();
   void someFunction();
   int otherFunction();
+  unsigned recursiveTrivialFunction(int n) { return !n ? 1 : recursiveTrivialFunction(n - 1);  }
+  unsigned recursiveComplexFunction(int n) { return !n ? otherFunction() : recursiveComplexFunction(n - 1);  }
+  unsigned mutuallyRecursiveFunction1(int n) { return n < 0 ? 1 : (n % 2 ? mutuallyRecursiveFunction2(n - 2) : mutuallyRecursiveFunction1(n - 1)); }
+  unsigned mutuallyRecursiveFunction2(int n) { return n < 0 ? 1 : (n % 3 ? mutuallyRecursiveFunction2(n - 3) : mutuallyRecursiveFunction1(n - 2)); }
+  unsigned mutuallyRecursiveFunction3(int n) { return n < 0 ? 1 : (n % 5 ? mutuallyRecursiveFunction3(n - 5) : mutuallyRecursiveFunction4(n - 3)); }
+  unsigned mutuallyRecursiveFunction4(int n) { return n < 0 ? 1 : (n % 7 ? otherFunction() : mutuallyRecursiveFunction3(n - 3)); }
+  unsigned recursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 2 ? recursiveFunction5(n + 1) : recursiveFunction6(n + 2)); }
+  unsigned recursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 2 ? recursiveFunction6(n % 7) : recursiveFunction7(n % 5)); }
+  unsigned recursiveFunction7(unsigned n) { return n > 100 ? 5 : recursiveFunction7(n * 5); }
+
+  void mutuallyRecursive8() { mutuallyRecursive9(); someFunction(); }
+  void mutuallyRecursive9() { mutuallyRecursive8(); }
 
   int trivial1() { return 123; }
   float trivial2() { return 0.3; }
@@ -498,6 +510,24 @@ class UnrelatedClass {
     RefCounted::singleton().trivial18(); // no-warning
     RefCounted::singleton().someFunction(); // no-warning
 
+    getFieldTrivial().recursiveTrivialFunction(7); // no-warning
+    getFieldTrivial().recursiveComplexFunction(9);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
+    getFieldTrivial().mutuallyRecursiveFunction1(11); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction2(13); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction3(17);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
+    getFieldTrivial().mutuallyRecursiveFunction4(19);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
+    getFieldTrivial().recursiveFunction5(23); // no-warning
+    getFieldTrivial().recursiveFunction6(29); // no-warning
+    getFieldTrivial().recursiveFunction7(31); // no-warning
+
+    getFieldTrivial().mutuallyRecursive8();
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
+    getFieldTrivial().mutuallyRecursive9();
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
+
     getFieldTrivial().someFunction();
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}}
     getFieldTrivial().nonTrivial1();


        
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to