================
@@ -1236,12 +1222,98 @@ void 
Sema::checkFortifiedBuiltinMemoryFunction(FunctionDecl *FD,
     const Expr *ObjArg = TheCall->getArg(NewIndex);
 
     if (std::optional<uint64_t> Result =
-            ObjArg->tryEvaluateStrLen(getASTContext())) {
+            ObjArg->tryEvaluateStrLen(S.getASTContext())) {
       // Add 1 for null byte.
       return llvm::APSInt::getUnsigned(*Result + 1).extOrTrunc(SizeTypeWidth);
     }
     return std::nullopt;
-  };
+  }
+
+  /// Accessor for the DiagnoseAsBuiltinAttr pointer stored in this checker.
+  const DiagnoseAsBuiltinAttr *getDABAttr() const {
+    return DABAttr;
+  }
+  /// Accessor for the SizeTypeWidth value stored in this checker.
+  unsigned getSizeTypeWidth() const {
+    return SizeTypeWidth;
+  }
+
+  /// Return the builtin's name with its compiler-internal affixes removed.
+  ///
+  /// \param BuiltinID    ID whose name is looked up via BuiltinInfo.getName.
+  /// \param IsChkVariant When true, strip a leading "__builtin___" and a
+  ///                     trailing "_chk" (e.g. "__builtin___memcpy_chk" ->
+  ///                     "memcpy"); otherwise strip a leading "__builtin_".
+  /// \returns The stripped name. An affix that is not actually present is
+  ///          left in place rather than cutting off characters of the name.
+  std::string GetFunctionName(unsigned BuiltinID, bool IsChkVariant) const {
+    std::string Name = S.getASTContext().BuiltinInfo.getName(BuiltinID);
+    llvm::StringRef Ref = Name;
+    if (IsChkVariant) {
+      // consume_front/consume_back only strip on an exact match, unlike
+      // drop_front/drop_back, which remove a fixed count of characters
+      // whatever they are. For well-formed "__builtin___X_chk" names the
+      // result is the same.
+      Ref.consume_front("__builtin___");
+      Ref.consume_back("_chk");
+    } else {
+      Ref.consume_front("__builtin_");
+    }
+    return Ref.str();
+  }
+
+private:
+  Sema &S;
+  CallExpr *TheCall;
+  FunctionDecl *FD;
+  const DiagnoseAsBuiltinAttr *DABAttr;
+  unsigned SizeTypeWidth;
+};
+} // anonymous namespace
+
+void Sema::checkSourceBufferOverread(FunctionDecl *FD, CallExpr *TheCall,
+                                     unsigned SrcArgIdx, unsigned SizeArgIdx) {
+  if (isConstantEvaluatedContext())
+    return;
+
+  const Expr *SrcArg = TheCall->getArg(SrcArgIdx);
+  const Expr *SizeArg = TheCall->getArg(SizeArgIdx);
+  if (SrcArg->isValueDependent() || SrcArg->isTypeDependent() ||
+      SizeArg->isValueDependent() || SizeArg->isTypeDependent())
+    return;
+
+  FortifiedBufferChecker Checker(*this, FD, TheCall);
+
+  std::optional<llvm::APSInt> CopyLen =
+      Checker.ComputeExplicitObjectSizeArgument(SizeArgIdx);
+  std::optional<llvm::APSInt> SrcBufSize =
+      Checker.ComputeSizeArgument(SrcArgIdx);
+
+  if (!CopyLen || !SrcBufSize)
+    return;
+
+  // Warn only if copy length exceeds source buffer size.
+  if (llvm::APSInt::compareValues(*CopyLen, *SrcBufSize) <= 0)
+    return;
+
+  llvm::StringRef FuncName = "memory function";
+  if (const FunctionDecl *CalleeDecl = TheCall->getDirectCallee()) {
+    FuncName = CalleeDecl->getName();
+    // __builtin___memcpy_chk -> memcpy, __builtin_memcpy -> memcpy.
+    // The _chk variants have a different prefix so try that one first.
+    if (!(FuncName.consume_front("__builtin___") &&
----------------
erichkeane wrote:

Why are we doing this again?  Isn't this effectively `GetFunctionName`?  

Also, we should probably make sure that `FuncName` isn't empty, which can 
happen in cases where the callee is a special function of some sort (though 
probably just an assert, since we expect these to be builtins?).

https://github.com/llvm/llvm-project/pull/183004
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to