================
@@ -3068,12 +3124,240 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
   C.addTransition(state->set<RegionState>(RS), N);
 }
 
+// Returns true if Name is one of the owning smart pointer class names this
+// checker recognizes ("unique_ptr" or "shared_ptr"). Only the unqualified
+// name is compared; callers decide whether the namespace matters.
+static bool isSmartPtrName(StringRef Name) {
+  for (StringRef Known : {"unique_ptr", "shared_ptr"})
+    if (Name == Known)
+      return true;
+  return false;
+}
+
+// Allowlist check for owning smart pointer types: unique_ptr and shared_ptr
+// (weak_ptr is intentionally excluded).
+//
+// Two paths are tried on the canonical, unqualified type:
+//  * a TemplateSpecializationType is accepted only if its template is
+//    declared within namespace std (isWithinStdNamespace) and has a
+//    recognized name;
+//  * otherwise a C++ record is accepted by name alone, from any namespace,
+//    so custom unique_ptr/shared_ptr implementations match as well.
+// NOTE(review): canonicalization presumably turns non-dependent
+// specializations such as std::unique_ptr<int> into record types, so the
+// first path may only fire for dependent specializations -- confirm.
+static bool isSmartOwningPtrType(QualType QT) {
+  const QualType Canon = QT->getCanonicalTypeUnqualified();
+
+  // Template-specialization path: the template must live in std.
+  if (const auto *Spec = Canon->getAs<TemplateSpecializationType>()) {
+    const TemplateDecl *Tmpl = Spec->getTemplateName().getAsTemplateDecl();
+    const auto *Pattern =
+        Tmpl ? dyn_cast_or_null<NamedDecl>(Tmpl->getTemplatedDecl())
+             : nullptr;
+    return Pattern && isWithinStdNamespace(Pattern) &&
+           isSmartPtrName(Pattern->getName());
+  }
+
+  // Record path: any namespace, matched by name only.
+  const CXXRecordDecl *Record = Canon->getAsCXXRecordDecl();
+  return Record && isSmartPtrName(Record->getName());
+}
+
+/// Returns true if \p CRD declares a field whose type is an owning smart
+/// pointer (per isSmartOwningPtrType), or if any of its base classes does,
+/// searched recursively. Own fields are examined before base classes.
+static bool hasSmartPtrField(const CXXRecordDecl *CRD) {
+  auto IsOwningField = [](const FieldDecl *FD) {
+    return isSmartOwningPtrType(FD->getType());
+  };
+  if (llvm::any_of(CRD->fields(), IsOwningField))
+    return true;
+
+  // Recurse into each base class that resolves to a C++ record.
+  return llvm::any_of(CRD->bases(), [](const CXXBaseSpecifier &Spec) {
+    const CXXRecordDecl *BaseRD = Spec.getType()->getAsCXXRecordDecl();
+    return BaseRD && hasSmartPtrField(BaseRD);
+  });
+}
+
+/// Check if an expression is an rvalue record type passed by value.
+///
+/// Returns false for glvalues and for expressions whose type is not a record
+/// type. Otherwise only a fixed set of temporary/construction expression kinds
+/// is accepted. A record type is never a reference type, so no separate
+/// reference check is needed.
+static bool isRvalueByValueRecord(const Expr *AE) {
+  if (AE->isGLValue())
+    return false;
+
+  if (!AE->getType()->isRecordType())
+    return false;
+
+  // Accept common temp/construct forms but don't overfit.
+  // NOTE(review): MaterializeTemporaryExpr is a glvalue, so it presumably
+  // never gets past the isGLValue() check above -- confirm and drop it.
+  return isa<CXXTemporaryObjectExpr, MaterializeTemporaryExpr,
+             CXXConstructExpr, InitListExpr, ImplicitCastExpr,
+             CXXBindTemporaryExpr>(AE);
+}
+
+/// Returns true if \p AE passes isRvalueByValueRecord and its type resolves
+/// to a C++ record that has an owning smart pointer field, either declared
+/// directly or inherited from a base class (see hasSmartPtrField).
+static bool isRvalueByValueRecordWithSmartPtr(const Expr *AE) {
+  if (!isRvalueByValueRecord(AE))
+    return false;
+
+  const CXXRecordDecl *Record = AE->getType()->getAsCXXRecordDecl();
+  return Record ? hasSmartPtrField(Record) : false;
+}
+
+/// Returns true if \p RD is a class whose name is one of the recognized
+/// owning smart pointer names (see isSmartPtrName). The match is by name
+/// only, so both std:: and user-defined implementations are accepted.
+/// Returns false for a null \p RD.
+static bool isSmartOwningPtrRecord(const CXXRecordDecl *RD) {
+  return RD && isSmartPtrName(RD->getName());
+}
+
+/// Returns true if \p Call invokes a constructor of a recognized owning smart
+/// pointer class (matched via isSmartOwningPtrRecord) and at least one of the
+/// constructor's declared parameters is a pointer that is neither a function
+/// pointer nor a void pointer. Non-constructor calls yield false.
+static bool isSmartPtrCall(const CallEvent &Call) {
+  const auto *Ctor = dyn_cast_or_null<CXXConstructorDecl>(Call.getDecl());
+  if (!Ctor || !isSmartOwningPtrRecord(Ctor->getParent()))
+    return false;
+
+  return llvm::any_of(Ctor->parameters(), [](const ParmVarDecl *P) {
+    QualType T = P->getType();
+    return T->isPointerType() && !T->isFunctionPointerType() &&
+           !T->isVoidPointerType();
+  });
+}
+
+/// Appends to \p Out the regions of every field of \p RecQT whose type is an
+/// owning smart pointer (see isSmartOwningPtrType), treating \p Base as an
+/// object of type \p RecQT. Despite "Direct" in the name, fields inherited
+/// from base classes (virtual or not) are collected too, by recursing into the
+/// corresponding base-object regions. Smart pointers nested inside other
+/// (non-smart-pointer) member objects are not collected.
+static void collectDirectSmartOwningPtrFieldRegions(
+    const MemRegion *Base, QualType RecQT, CheckerContext &C,
+    SmallVectorImpl<const MemRegion *> &Out) {
+  if (!Base)
+    return;
+  const auto *CRD = RecQT->getAsCXXRecordDecl();
+  if (!CRD)
+    return;
+
+  // The state is only read here; fetch it once for all lookups.
+  ProgramStateRef State = C.getState();
+
+  // Collect direct fields.
+  for (const FieldDecl *FD : CRD->fields()) {
+    if (!isSmartOwningPtrType(FD->getType()))
+      continue;
+    SVal L = State->getLValue(FD, loc::MemRegionVal(Base));
+    if (const MemRegion *FR = L.getAsRegion())
+      Out.push_back(FR);
+  }
+
+  // Base-object regions must be layered on a SubRegion; stop here rather
+  // than pass a null super-region to getLValue().
+  const auto *SuperRegion = Base->getAs<SubRegion>();
+  if (!SuperRegion)
+    return;
+
+  // Collect fields from base classes.
+  for (const CXXBaseSpecifier &BaseSpec : CRD->bases()) {
+    const CXXRecordDecl *BaseDecl = BaseSpec.getType()->getAsCXXRecordDecl();
+    if (!BaseDecl)
+      continue;
+    SVal BaseL =
+        State->getLValue(BaseDecl, SuperRegion, BaseSpec.isVirtual());
+    if (const MemRegion *BaseRegion = BaseL.getAsRegion())
+      collectDirectSmartOwningPtrFieldRegions(BaseRegion, BaseSpec.getType(),
+                                              C, Out);
+  }
+}
+
+/// Handle smart pointer constructor calls by escaping allocated symbols
+/// that are passed as pointer arguments to the constructor.
+///
+/// Only arguments bound to a declared parameter of pointer type (excluding
+/// function and void pointers) are considered. A symbol tracked in
+/// RegionState as allocated (including allocated-of-size-zero) is switched
+/// to its escaped state. Arguments matched by a C-style variadic "..." have
+/// no ParmVarDecl and are skipped.
+///
+/// \pre Call.getDecl() is a CXXConstructorDecl (as checked by
+///      isSmartPtrCall).
+ProgramStateRef MallocChecker::handleSmartPointerConstructorArguments(
+    const CallEvent &Call, ProgramStateRef State) const {
+  const auto *CD = cast<CXXConstructorDecl>(Call.getDecl());
+  // A variadic constructor can receive more arguments than it declares
+  // parameters; getParamDecl() must not be indexed past getNumParams().
+  const unsigned NumChecked = Call.getNumArgs() < CD->getNumParams()
+                                  ? Call.getNumArgs()
+                                  : CD->getNumParams();
+  for (unsigned I = 0; I != NumChecked; ++I) {
+    if (!Call.getArgExpr(I))
+      continue;
+
+    QualType ParamType = CD->getParamDecl(I)->getType();
+    if (!ParamType->isPointerType() || ParamType->isFunctionPointerType() ||
+        ParamType->isVoidPointerType())
+      continue;
+
+    // This argument is a pointer being passed to smart pointer constructor.
+    SymbolRef Sym = Call.getArgSVal(I).getAsSymbol();
+    if (!Sym)
+      continue;
+
+    // get<> yields null for untracked symbols, so no separate contains<>
+    // lookup is needed.
+    const RefState *RS = State->get<RegionState>(Sym);
+    if (RS && (RS->isAllocated() || RS->isAllocatedOfSizeZero()))
+      State = State->set<RegionState>(Sym, RefState::getEscaped(RS));
+  }
+  return State;
+}
+
+/// Entry point for the smart-pointer-specific bookkeeping done after a call.
+///
+/// - If \p Call is a smart pointer constructor taking a pointer parameter
+///   (isSmartPtrCall), the work is delegated to
+///   handleSmartPointerConstructorArguments and nothing else is done.
+/// - Otherwise, each argument that (after stripping parens and implicit
+///   casts) is a by-value record rvalue with an owning smart pointer field
+///   contributes those field regions as roots, and tracked symbols reachable
+///   from the roots are escaped.
+///
+/// Arguments whose value is not a region are skipped rather than escaped
+/// wholesale. NOTE(review): by-value record arguments may evaluate to a
+/// non-region SVal (e.g. a lazy compound value), in which case nothing is
+/// escaped for them -- confirm this path triggers on the intended cases.
+ProgramStateRef MallocChecker::handleSmartPointerRelatedCalls(
+    const CallEvent &Call, CheckerContext &C, ProgramStateRef State) const {
+  if (isSmartPtrCall(Call))
+    return handleSmartPointerConstructorArguments(Call, State);
+
+  SmallVector<const MemRegion *, 8> Roots;
+  for (unsigned ArgIdx = 0, NumArgs = Call.getNumArgs(); ArgIdx < NumArgs;
+       ++ArgIdx) {
+    const Expr *Arg = Call.getArgExpr(ArgIdx);
+    if (!Arg)
+      continue;
+    const Expr *Stripped = Arg->IgnoreParenImpCasts();
+    if (!isRvalueByValueRecordWithSmartPtr(Stripped))
+      continue;
+
+    // Without a concrete region there is nothing precise to escape, and
+    // escaping more broadly would hide genuine leaks.
+    if (const MemRegion *ArgRegion = Call.getArgSVal(ArgIdx).getAsRegion())
+      collectDirectSmartOwningPtrFieldRegions(ArgRegion, Stripped->getType(),
+                                              C, Roots);
+  }
+
+  if (Roots.empty())
+    return State;
+  return EscapeTrackedCallback::EscapeTrackedRegionsReachableFrom(Roots,
+                                                                  State);
+}
+
 void MallocChecker::checkPostCall(const CallEvent &Call,
                                   CheckerContext &C) const {
+  // Handle existing post-call handlers first
   if (const auto *PostFN = PostFnMap.lookup(Call)) {
     (*PostFN)(this, C.getState(), Call, C);
-    return;
+    return; // Post-handler already called addTransition, we're done
   }
+
+  // Handle smart pointer related processing only if no post-handler was called
+  ProgramStateRef State = handleSmartPointerRelatedCalls(Call, C, 
C.getState());
+  C.addTransition(State);
----------------
steakhal wrote:

```suggestion
  C.addTransition(handleSmartPointerRelatedCalls(Call, C, C.getState()));
```

https://github.com/llvm/llvm-project/pull/152751
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to