https://github.com/rniwa updated https://github.com/llvm/llvm-project/pull/91876

>From e40017a2750ee39bfd1a87b5ddea620076bc4419 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Sat, 11 May 2024 20:18:52 -0700
Subject: [PATCH 1/6] [analyzer] Allow recursive functions to be trivial.

---
 .../Checkers/WebKit/PtrTypesSemantics.cpp      | 18 +++++++++---------
 .../Checkers/WebKit/uncounted-obj-arg.cpp      |  6 ++++++
 2 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp 
b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 5c797d5233089..2a4da9eeaee6d 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -517,11 +517,9 @@ class TrivialFunctionAnalysisVisitor
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
     const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
-  // If the function isn't in the cache, conservatively assume that
-  // it's not trivial until analysis completes. This makes every recursive
-  // function non-trivial. This also guarantees that each function
-  // will be scanned at most once.
-  auto [It, IsNew] = Cache.insert(std::make_pair(D, false));
+  // Treat every recursive function as trivial until otherwise proven.
+  // This guarantees each function is evaluated at most once.
+  auto [It, IsNew] = Cache.insert(std::make_pair(D, true));
   if (!IsNew)
     return It->second;
 
@@ -535,12 +533,14 @@ bool TrivialFunctionAnalysis::isTrivialImpl(
   }
 
   const Stmt *Body = D->getBody();
-  if (!Body)
-    return false;
+  if (!Body) {
+    Cache[D] = false;
+    return false;    
+  }
 
   bool Result = V.Visit(Body);
-  if (Result)
-    Cache[D] = true;
+  if (!Result)
+    Cache[D] = false;
 
   return Result;
 }
diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp 
b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
index 96986631726fe..18af9e17f78b0 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
@@ -231,6 +231,8 @@ class RefCounted {
   void method();
   void someFunction();
   int otherFunction();
+  unsigned recursiveFunction(int n) { return !n ? 1 : recursiveFunction(n - 
1);  }
+  unsigned recursiveComplexFunction(int n) { return !n ? otherFunction() : 
recursiveComplexFunction(n - 1);  }
 
   int trivial1() { return 123; }
   float trivial2() { return 0.3; }
@@ -498,6 +500,10 @@ class UnrelatedClass {
     RefCounted::singleton().trivial18(); // no-warning
     RefCounted::singleton().someFunction(); // no-warning
 
+    getFieldTrivial().recursiveFunction(7); // no-warning
+    getFieldTrivial().recursiveComplexFunction(9);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
+
     getFieldTrivial().someFunction();
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
     getFieldTrivial().nonTrivial1();

>From 5d7a259c0209a8cbb70c2718518905669db2a885 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Sat, 11 May 2024 20:24:10 -0700
Subject: [PATCH 2/6] Fix formatting.

---
 clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp 
b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 2a4da9eeaee6d..449e56be9984d 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -535,7 +535,7 @@ bool TrivialFunctionAnalysis::isTrivialImpl(
   const Stmt *Body = D->getBody();
   if (!Body) {
     Cache[D] = false;
-    return false;    
+    return false;
   }
 
   bool Result = V.Visit(Body);

>From 9971dcdebc160acb708b548bf67ef5efa76d4fe0 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Sun, 12 May 2024 15:15:33 -0700
Subject: [PATCH 3/6] Fix the TrivialFunctionAnalysis::isTrivialImpl for
 mutually recursive functions.

Instead of assuming every function to be initially trivial, we explicitly track
the set of functions that we're currently visiting. When one of the currently
visited functions is determined to be non-trivial, we clear this set to signal
that all mutually recursive functions are non-trivial. We conclude that a
function is trivial when the Visit() call on the function body returned true
**AND** the set still contains the function.

To implement this new algorithm, a new public function, IsFunctionTrivial,
is introduced to TrivialFunctionAnalysisVisitor, and various Visit functions in
TrivialFunctionAnalysisVisitor have been updated to use this function instead of
TrivialFunctionAnalysis::isTrivialImpl, which is now a wrapper for the function.
---
 .../Checkers/WebKit/PtrTypesSemantics.cpp     | 69 +++++++++++--------
 .../Checkers/WebKit/uncounted-obj-arg.cpp     | 20 +++++-
 2 files changed, 58 insertions(+), 31 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp 
b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 449e56be9984d..86524d6223c53 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -271,6 +271,40 @@ class TrivialFunctionAnalysisVisitor
 
   TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {}
 
+  bool IsFunctionTrivial(const Decl *D) {
+    auto CacheIt = Cache.find(D);
+    if (CacheIt != Cache.end())
+      return CacheIt->second;
+
+    // Treat a recursive function call to be trivial until proven otherwise.
+    auto [RecursiveIt, IsNew] = RecursiveFn.insert(D);
+    if (!IsNew)
+      return true;
+
+    bool Result = [&]() {
+      if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
+        for (auto *CtorInit : CtorDecl->inits()) {
+          if (!Visit(CtorInit->getInit()))
+            return false;
+        }
+      }
+      const Stmt *Body = D->getBody();
+      if (!Body)
+        return false;
+      return Visit(Body);
+    }();
+
+    Cache[D] = Result;
+
+    if (!Result) // D and its mutually recursive callers are non-trivial.
+      RecursiveFn.clear();
+    else // Check if any of mutually recursive functions were non-trivial.
+      Result = RecursiveFn.contains(D);
+    RecursiveFn.erase(D);
+
+    return Result;
+  }
+
   bool VisitStmt(const Stmt *S) {
     // All statements are non-trivial unless overriden later.
     // Don't even recurse into children by default.
@@ -368,7 +402,7 @@ class TrivialFunctionAnalysisVisitor
         Name == "bitwise_cast" || Name.find("__builtin") == 0)
       return true;
 
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool
@@ -403,7 +437,7 @@ class TrivialFunctionAnalysisVisitor
       return true;
 
     // Recursively descend into the callee to confirm that it's trivial as 
well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool VisitCXXOperatorCallExpr(const CXXOperatorCallExpr *OCE) {
@@ -413,7 +447,7 @@ class TrivialFunctionAnalysisVisitor
     if (!Callee)
       return false;
     // Recursively descend into the callee to confirm that it's trivial as 
well.
-    return TrivialFunctionAnalysis::isTrivialImpl(Callee, Cache);
+    return IsFunctionTrivial(Callee);
   }
 
   bool VisitCXXDefaultArgExpr(const CXXDefaultArgExpr *E) {
@@ -439,7 +473,7 @@ class TrivialFunctionAnalysisVisitor
     }
 
     // Recursively descend into the callee to confirm that it's trivial.
-    return TrivialFunctionAnalysis::isTrivialImpl(CE->getConstructor(), Cache);
+    return IsFunctionTrivial(CE->getConstructor());
   }
 
   bool VisitCXXNewExpr(const CXXNewExpr *NE) { return VisitChildren(NE); }
@@ -513,36 +547,13 @@ class TrivialFunctionAnalysisVisitor
 
 private:
   CacheTy &Cache;
+  llvm::DenseSet<llvm::PointerUnion<const Decl *>> RecursiveFn;
 };
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
     const Decl *D, TrivialFunctionAnalysis::CacheTy &Cache) {
-  // Treat every recursive function as trivial until otherwise proven.
-  // This guarantees each function is evaluated at most once.
-  auto [It, IsNew] = Cache.insert(std::make_pair(D, true));
-  if (!IsNew)
-    return It->second;
-
   TrivialFunctionAnalysisVisitor V(Cache);
-
-  if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
-    for (auto *CtorInit : CtorDecl->inits()) {
-      if (!V.Visit(CtorInit->getInit()))
-        return false;
-    }
-  }
-
-  const Stmt *Body = D->getBody();
-  if (!Body) {
-    Cache[D] = false;
-    return false;
-  }
-
-  bool Result = V.Visit(Body);
-  if (!Result)
-    Cache[D] = false;
-
-  return Result;
+  return V.IsFunctionTrivial(D);
 }
 
 bool TrivialFunctionAnalysis::isTrivialImpl(
diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp 
b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
index 18af9e17f78b0..a6e64191131b3 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
@@ -231,8 +231,15 @@ class RefCounted {
   void method();
   void someFunction();
   int otherFunction();
-  unsigned recursiveFunction(int n) { return !n ? 1 : recursiveFunction(n - 
1);  }
+  unsigned recursiveTrivialFunction(int n) { return !n ? 1 : 
recursiveTrivialFunction(n - 1);  }
   unsigned recursiveComplexFunction(int n) { return !n ? otherFunction() : 
recursiveComplexFunction(n - 1);  }
+  unsigned mutuallyRecursiveFunction1(int n) { return n < 0 ? 1 : (n % 2 ? 
mutuallyRecursiveFunction2(n - 2) : mutuallyRecursiveFunction1(n - 1)); }
+  unsigned mutuallyRecursiveFunction2(int n) { return n < 0 ? 1 : (n % 3 ? 
mutuallyRecursiveFunction2(n - 3) : mutuallyRecursiveFunction1(n - 2)); }
+  unsigned mutuallyRecursiveFunction3(int n) { return n < 0 ? 1 : (n % 5 ? 
mutuallyRecursiveFunction3(n - 5) : mutuallyRecursiveFunction4(n - 3)); }
+  unsigned mutuallyRecursiveFunction4(int n) { return n < 0 ? 1 : (n % 7 ? 
otherFunction() : mutuallyRecursiveFunction3(n - 3)); }
+  unsigned mutuallyRecursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 
2 ? mutuallyRecursiveFunction5(n + 1) : mutuallyRecursiveFunction6(n + 2)); }
+  unsigned mutuallyRecursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 
2 ? mutuallyRecursiveFunction6(n % 7) : mutuallyRecursiveFunction7(n % 5)); }
+  unsigned mutuallyRecursiveFunction7(unsigned n) { return n > 100 ? 5 : 
mutuallyRecursiveFunction7(n * 5); }
 
   int trivial1() { return 123; }
   float trivial2() { return 0.3; }
@@ -500,9 +507,18 @@ class UnrelatedClass {
     RefCounted::singleton().trivial18(); // no-warning
     RefCounted::singleton().someFunction(); // no-warning
 
-    getFieldTrivial().recursiveFunction(7); // no-warning
+    getFieldTrivial().recursiveTrivialFunction(7); // no-warning
     getFieldTrivial().recursiveComplexFunction(9);
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
+    getFieldTrivial().mutuallyRecursiveFunction1(11); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction2(13); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction3(17);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
+    getFieldTrivial().mutuallyRecursiveFunction4(19);
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
+    getFieldTrivial().mutuallyRecursiveFunction5(23); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction6(29); // no-warning
+    getFieldTrivial().mutuallyRecursiveFunction7(31); // no-warning
 
     getFieldTrivial().someFunction();
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}

>From 99e3b24565ad23286256d74fdd76aa7f7c8cdd93 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Sun, 12 May 2024 23:09:18 -0700
Subject: [PATCH 4/6] Fix a bug that IsFunctionTrivial was not updating the
 cache upon completion.

Also fix a bug that IsFunctionTrivial would do a redundant traversal when
a function had been determined to be non-trivial. Because we were using the
absence of the function in the hash set to indicate the non-triviality of a
function, a function that had not been traversed was indistinguishable from one
that had been found to be non-trivial. Use a hash map of a function to a boolean
instead to explicitly track whether a given function had been visited or not,
and whether a given function had been determined to be non-trivial or not.
---
 .../Checkers/WebKit/PtrTypesSemantics.cpp     | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp 
b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 86524d6223c53..4c5d963f7015b 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -277,9 +277,9 @@ class TrivialFunctionAnalysisVisitor
       return CacheIt->second;
 
     // Treat a recursive function call to be trivial until proven otherwise.
-    auto [RecursiveIt, IsNew] = RecursiveFn.insert(D);
+    auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(D, true));
     if (!IsNew)
-      return true;
+      return RecursiveIt->second;
 
     bool Result = [&]() {
       if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) {
@@ -294,13 +294,14 @@ class TrivialFunctionAnalysisVisitor
       return Visit(Body);
     }();
 
-    Cache[D] = Result;
-
-    if (!Result) // D and its mutually recursive callers are non-trivial.
-      RecursiveFn.clear();
-    else // Check if any of mutually recursive functions were non-trivial.
-      Result = RecursiveFn.contains(D);
+    if (!Result) {
+      // D and its mutually recursive callers are all non-trivial.
+      for (auto& It : RecursiveFn)
+        It.second = false;
+    }
+    assert(RecursiveFn[D] == Result);
     RecursiveFn.erase(D);
+    Cache[D] = Result;
 
     return Result;
   }
@@ -547,7 +548,7 @@ class TrivialFunctionAnalysisVisitor
 
 private:
   CacheTy &Cache;
-  llvm::DenseSet<llvm::PointerUnion<const Decl *>> RecursiveFn;
+  CacheTy RecursiveFn;
 };
 
 bool TrivialFunctionAnalysis::isTrivialImpl(

>From f638d18b005a6925a06c22fabd14e8bcf927d073 Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Sun, 12 May 2024 23:48:25 -0700
Subject: [PATCH 5/6] Fix the cache updating logic.

---
 .../StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp  | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp 
b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
index 4c5d963f7015b..49bbff1942167 100644
--- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
+++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp
@@ -296,11 +296,13 @@ class TrivialFunctionAnalysisVisitor
 
     if (!Result) {
       // D and its mutually recursive callers are all non-trivial.
-      for (auto& It : RecursiveFn)
+      for (auto &It : RecursiveFn)
         It.second = false;
     }
-    assert(RecursiveFn[D] == Result);
-    RecursiveFn.erase(D);
+    RecursiveIt = RecursiveFn.find(D);
+    assert(RecursiveIt != RecursiveFn.end());
+    Result = RecursiveIt->second;
+    RecursiveFn.erase(RecursiveIt);
     Cache[D] = Result;
 
     return Result;

>From dea8371992492aeae1e060b609fc72a41786a1ec Mon Sep 17 00:00:00 2001
From: Ryosuke Niwa <rn...@apple.com>
Date: Thu, 23 May 2024 22:28:56 -0700
Subject: [PATCH 6/6] Add an actually mutually recursive function, and rename
 existing recursive functions.

---
 .../Checkers/WebKit/uncounted-obj-arg.cpp     | 20 +++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp 
b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
index a6e64191131b3..a98c6eb9c84d9 100644
--- a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
+++ b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp
@@ -237,9 +237,12 @@ class RefCounted {
   unsigned mutuallyRecursiveFunction2(int n) { return n < 0 ? 1 : (n % 3 ? 
mutuallyRecursiveFunction2(n - 3) : mutuallyRecursiveFunction1(n - 2)); }
   unsigned mutuallyRecursiveFunction3(int n) { return n < 0 ? 1 : (n % 5 ? 
mutuallyRecursiveFunction3(n - 5) : mutuallyRecursiveFunction4(n - 3)); }
   unsigned mutuallyRecursiveFunction4(int n) { return n < 0 ? 1 : (n % 7 ? 
otherFunction() : mutuallyRecursiveFunction3(n - 3)); }
-  unsigned mutuallyRecursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 
2 ? mutuallyRecursiveFunction5(n + 1) : mutuallyRecursiveFunction6(n + 2)); }
-  unsigned mutuallyRecursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 
2 ? mutuallyRecursiveFunction6(n % 7) : mutuallyRecursiveFunction7(n % 5)); }
-  unsigned mutuallyRecursiveFunction7(unsigned n) { return n > 100 ? 5 : 
mutuallyRecursiveFunction7(n * 5); }
+  unsigned recursiveFunction5(unsigned n) { return n > 100 ? 2 : (n % 2 ? 
recursiveFunction5(n + 1) : recursiveFunction6(n + 2)); }
+  unsigned recursiveFunction6(unsigned n) { return n > 100 ? 3 : (n % 2 ? 
recursiveFunction6(n % 7) : recursiveFunction7(n % 5)); }
+  unsigned recursiveFunction7(unsigned n) { return n > 100 ? 5 : 
recursiveFunction7(n * 5); }
+
+  void mutuallyRecursive8() { mutuallyRecursive9(); someFunction(); }
+  void mutuallyRecursive9() { mutuallyRecursive8(); }
 
   int trivial1() { return 123; }
   float trivial2() { return 0.3; }
@@ -516,9 +519,14 @@ class UnrelatedClass {
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
     getFieldTrivial().mutuallyRecursiveFunction4(19);
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
-    getFieldTrivial().mutuallyRecursiveFunction5(23); // no-warning
-    getFieldTrivial().mutuallyRecursiveFunction6(29); // no-warning
-    getFieldTrivial().mutuallyRecursiveFunction7(31); // no-warning
+    getFieldTrivial().recursiveFunction5(23); // no-warning
+    getFieldTrivial().recursiveFunction6(29); // no-warning
+    getFieldTrivial().recursiveFunction7(31); // no-warning
+
+    getFieldTrivial().mutuallyRecursive8();
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
+    getFieldTrivial().mutuallyRecursive9();
+    // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}
 
     getFieldTrivial().someFunction();
     // expected-warning@-1{{Call argument for 'this' parameter is uncounted 
and unsafe}}

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to