================
@@ -3068,12 +3124,242 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
   C.addTransition(state->set<RegionState>(RS), N);
 }
 
+// Returns true when Name is one of the recognized owning smart pointer class
+// names: "shared_ptr" or "unique_ptr".
+static bool isSmartPtrName(StringRef Name) {
+  if (Name == "shared_ptr")
+    return true;
+  return Name == "unique_ptr";
+}
+
+// Allowlist of owning smart pointers we want to recognize: any class named
+// unique_ptr or shared_ptr (see isSmartPtrName), whether it is the std one or
+// a custom implementation. weak_ptr is intentionally excluded.
+//
+// Matching is done on the record declaration. A non-dependent specialization
+// such as std::unique_ptr<int> canonicalizes to a RecordType, so searching the
+// canonical type for a TemplateSpecializationType would never find it.
+static bool isSmartOwningPtrType(QualType QT) {
+  // getAsCXXRecordDecl() looks through type sugar (typedefs, aliases,
+  // elaborated and substituted types), so no explicit canonicalization is
+  // needed. It yields null for types that are not C++ classes.
+  const CXXRecordDecl *RD = QT->getAsCXXRecordDecl();
+  if (!RD)
+    return false;
+
+  // No namespace check on purpose: custom unique_ptr/shared_ptr
+  // implementations are accepted as well.
+  return isSmartPtrName(RD->getName());
+}
+
+/// Returns true if CRD declares a field of owning smart pointer type (as
+/// decided by isSmartOwningPtrType), or if any of its base classes does,
+/// searching the base classes recursively.
+static bool hasSmartPtrField(const CXXRecordDecl *CRD) {
+  for (const FieldDecl *Field : CRD->fields())
+    if (isSmartOwningPtrType(Field->getType()))
+      return true;
+
+  // Bases whose type does not resolve to a CXXRecordDecl are ignored.
+  return llvm::any_of(CRD->bases(), [](const CXXBaseSpecifier &Spec) {
+    const CXXRecordDecl *BaseRD = Spec.getType()->getAsCXXRecordDecl();
+    return BaseRD && hasSmartPtrField(BaseRD);
+  });
+}
+
+/// Returns true if AE is a non-glvalue expression of (non-reference) record
+/// type, i.e. a record passed by value, and AE itself is one of a small set of
+/// temporary/construction expression kinds.
+static bool isRvalueByValueRecord(const Expr *AE) {
+  const QualType Ty = AE->getType();
+  const bool IsByValueRecord =
+      !AE->isGLValue() && Ty->isRecordType() && !Ty->isReferenceType();
+  if (!IsByValueRecord)
+    return false;
+
+  // Accept common temp/construct forms but don't overfit.
+  // NOTE(review): MaterializeTemporaryExpr is normally a glvalue in Clang's
+  // AST, so it looks unreachable after the check above — confirm intent.
+  return isa<CXXTemporaryObjectExpr, MaterializeTemporaryExpr,
+             CXXConstructExpr, InitListExpr, ImplicitCastExpr,
+             CXXBindTemporaryExpr>(AE);
+}
+
+/// Returns true if AE is a by-value record expression (per
+/// isRvalueByValueRecord) whose record type, or one of its base classes, has
+/// an owning smart pointer field.
+static bool isRvalueByValueRecordWithSmartPtr(const Expr *AE) {
+  if (isRvalueByValueRecord(AE))
+    if (const auto *Record = AE->getType()->getAsCXXRecordDecl())
+      return hasSmartPtrField(Record);
+  return false;
+}
+
+/// Returns true if RD is non-null and its name is a recognized smart pointer
+/// name (see isSmartPtrName). Both std and custom implementations are
+/// accepted; no namespace check is performed.
+static bool isSmartOwningPtrRecord(const CXXRecordDecl *RD) {
+  return RD && isSmartPtrName(RD->getName());
+}
+
+/// Returns true if Call invokes a constructor of a recognized smart pointer
+/// class (see isSmartOwningPtrRecord) that has at least one parameter of
+/// pointer type, excluding function pointers and void pointers.
+static bool isSmartPtrCall(const CallEvent &Call) {
+  // Only constructor calls are of interest here.
+  const auto *Ctor = dyn_cast_or_null<CXXConstructorDecl>(Call.getDecl());
+  if (!Ctor || !isSmartOwningPtrRecord(Ctor->getParent()))
+    return false;
+
+  return llvm::any_of(Ctor->parameters(), [](const auto *Param) {
+    const QualType Ty = Param->getType();
+    return Ty->isPointerType() && !Ty->isFunctionPointerType() &&
+           !Ty->isVoidPointerType();
+  });
+}
+
+static void collectSmartOwningPtrFieldRegions(
----------------
NagyDonat wrote:

This function duplicates the traversal logic which also appears in 
`hasSmartPtrField`, and I fear that it would be troublesome to preserve 
consistency between them. (Note that there are already some inconsistencies: as
far as I can see, it's possible that `hasSmartPtrField` returns `true` but
`collectSmartOwningPtrFieldRegions` doesn't find any regions. This is not 
necessarily a problem, but the code that uses these functions must be careful 
to cover it.)

I can imagine three potential solutions for this problem:
1. Placing comments like `// WARNING: Keep the traversal in this function in 
sync with ...` at the beginning of these two functions. Not elegant, but I can 
accept this if you don't like the other options.
2. Eliminating the preliminary traversal step which is performed by 
`hasSmartPtrField` slightly before calling `collectSmartOwningPtrFieldRegions`. 
This would significantly simplify the logic – you won't need to do the same 
traversal twice – but could potentially worsen the performance (if the
preliminary traversal step lets us avoid constructing the base object regions).
I strongly suspect that this performance loss is negligible, especially if we 
compare it to the overall analyzer runtime, so – referring to "Premature 
optimization is the root of all evil." – I'd primarily suggest this approach.
3. I also tried to refactor these two functions to ensure that they share the 
same "core logic" which can optionally pass along a context struct object:
```c++
struct FieldConsumer {
  const MemRegion *Base;
  CheckerContext &C;
  llvm::SmallPtrSetImpl<const MemRegion *> &Out;
  
  void consume(const FieldDecl *FD) {
    SVal L = C.getState()->getLValue(FD, loc::MemRegionVal(Base));
    if (const MemRegion *FR = L.getAsRegion())
      Out.insert(FR);
  }
  std::optional<FieldConsumer> switchToBase(const CXXRecordDecl *BaseDecl, bool 
IsVirtual) {
    // Get the base class region
    SVal BaseL = C.getState()->getLValue(BaseDecl, Base->getAs<SubRegion>(),
                                         IsVirtual);
    if (const MemRegion *BaseRegion = BaseL.getAsRegion()) {
      // Return a consumer 
      return FieldConsumer{BaseRegion, C, Out};
    }
    return std::nullopt;
  }
};

// Without FC: returns true as soon as CRD, or one of its base classes
// (searched recursively), has an owning smart pointer field.
// With FC: passes every such field to FC, traversing the whole hierarchy, and
// always returns false.
static bool hasSmartPtrField(const CXXRecordDecl *CRD,
                             std::optional<FieldConsumer> FC = std::nullopt) {
  for (const FieldDecl *FD : CRD->fields()) {
    if (isSmartOwningPtrType(FD->getType())) {
      if (!FC)
        return true;
      FC->consume(FD);
    }
  }

  // Collect fields from base classes
  for (const CXXBaseSpecifier &BaseSpec : CRD->bases()) {
    const CXXRecordDecl *BaseDecl = BaseSpec.getType()->getAsCXXRecordDecl();
    if (!BaseDecl)
      continue;

    // Initialize NewFC instead of assigning to it. FieldConsumer has reference
    // members, so it is not assignable, and neither is
    // std::optional<FieldConsumer>.
    std::optional<FieldConsumer> NewFC =
        FC ? FC->switchToBase(BaseDecl, BaseSpec.isVirtual()) : std::nullopt;
    if (FC && !NewFC)
      continue; // The base subobject region could not be formed; skip it.

    if (hasSmartPtrField(BaseDecl, NewFC) && !FC)
      return true;
  }
  return false;
}

static void collectSmartOwningPtrFieldRegions(const MemRegion *Base, QualType 
RecQT,
    CheckerContext &C, llvm::SmallPtrSetImpl<const MemRegion *> &Out) {
  if (!Base)
    return;
  const auto *CRD = RecQT->getAsCXXRecordDecl();
  if (!CRD)
    return;
  FieldConsumer FC{Base, C, Out};
  hasSmartPtrField(CRD, FC);
}  
```
[Disclaimer: I wrote this code in the browser, so it probably contains some 
typos.] I feel that this is a bit too verbose, but I would still prefer this 
over the status quo, because I value maintainability highly (I wouldn't
like to troubleshoot bugs introduced by a discrepancy between these two 
functions).

https://github.com/llvm/llvm-project/pull/152751
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to