Author: Ziqing Luo
Date: 2026-01-02T11:10:31-08:00
New Revision: 81b46646fb5eb34559ef1e31d0ee83a69c18a301

URL: 
https://github.com/llvm/llvm-project/commit/81b46646fb5eb34559ef1e31d0ee83a69c18a301
DIFF: 
https://github.com/llvm/llvm-project/commit/81b46646fb5eb34559ef1e31d0ee83a69c18a301.diff

LOG: [-Wunsafe-buffer-usage] Add check for custom printf/scanf functions 
(#173096)

This commit adds support for functions annotated with
`__attribute__((__format__(__printf__, ...)))` (or `__scanf__`).
`-Wunsafe-buffer-usage` will treat these functions the same way as the
printf/scanf functions in the standard C library.

rdar://143233737

Added: 
    

Modified: 
    clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def
    clang/lib/Analysis/UnsafeBufferUsage.cpp
    clang/test/SemaCXX/warn-unsafe-buffer-usage-libc-functions.cpp

Removed: 
    


################################################################################
diff  --git 
a/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def 
b/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def
index fae5f8b8aa8e3..f9bba5d54e9c7 100644
--- a/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def
+++ b/clang/include/clang/Analysis/Analyses/UnsafeBufferUsageGadgets.def
@@ -40,6 +40,7 @@ WARNING_GADGET(UnsafeBufferUsageCtorAttr)
 WARNING_GADGET(DataInvocation)
 WARNING_GADGET(UniquePtrArrayAccess)
 WARNING_OPTIONAL_GADGET(UnsafeLibcFunctionCall)
+WARNING_OPTIONAL_GADGET(UnsafeFormatAttributedFunctionCall)
 WARNING_OPTIONAL_GADGET(SpanTwoParamConstructor) // Uses of `std::span(arg0, 
arg1)`
 FIXABLE_GADGET(ULCArraySubscript)          // `DRE[any]` in an Unspecified 
Lvalue Context
 FIXABLE_GADGET(DerefSimplePtrArithFixable)

diff  --git a/clang/lib/Analysis/UnsafeBufferUsage.cpp 
b/clang/lib/Analysis/UnsafeBufferUsage.cpp
index 7ef20726d0ab9..620da756c3a9c 100644
--- a/clang/lib/Analysis/UnsafeBufferUsage.cpp
+++ b/clang/lib/Analysis/UnsafeBufferUsage.cpp
@@ -825,9 +825,13 @@ struct LibcFunNamePrefixSuffixParser {
 //
 // `UnsafeArg` is the output argument that will be set only if this function
 // returns true.
-static bool hasUnsafeFormatOrSArg(const CallExpr *Call, const Expr *&UnsafeArg,
-                                  const unsigned FmtArgIdx, ASTContext &Ctx,
-                                  bool isKprintf = false) {
+//
+// Format arguments start at `FmtIdx` + 1 if `FmtArgIdx` is not provided.
+static bool
+hasUnsafeFormatOrSArg(ASTContext &Ctx, const CallExpr *Call,
+                      const Expr *&UnsafeArg, const unsigned FmtIdx,
+                      std::optional<const unsigned> FmtArgIdx = std::nullopt,
+                      bool isKprintf = false) {
   class StringFormatStringHandler
       : public analyze_format_string::FormatStringHandler {
     const CallExpr *Call;
@@ -847,18 +851,18 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, 
const Expr *&UnsafeArg,
     const Expr *
     getPrecisionAsExpr(const analyze_printf::OptionalAmount &Precision,
                        const CallExpr *Call) {
-      unsigned PArgIdx = -1;
-
-      if (Precision.hasDataArgument())
-        PArgIdx = Precision.getPositionalArgIndex() + FmtArgIdx;
-      if (0 < PArgIdx && PArgIdx < Call->getNumArgs()) {
-        const Expr *PArg = Call->getArg(PArgIdx);
-
-        // Strip the cast if `PArg` is a cast-to-int expression:
-        if (auto *CE = dyn_cast<CastExpr>(PArg);
-            CE && CE->getType()->isSignedIntegerType())
-          PArg = CE->getSubExpr();
-        return PArg;
+      if (Precision.hasDataArgument()) {
+        unsigned PArgIdx = Precision.getArgIndex() + FmtArgIdx;
+
+        if (PArgIdx < Call->getNumArgs()) {
+          const Expr *PArg = Call->getArg(PArgIdx);
+
+          // Strip the cast if `PArg` is a cast-to-int expression:
+          if (auto *CE = dyn_cast<CastExpr>(PArg);
+              CE && CE->getType()->isSignedIntegerType())
+            PArg = CE->getSubExpr();
+          return PArg;
+        }
       }
       if (Precision.getHowSpecified() ==
           analyze_printf::OptionalAmount::HowSpecified::Constant) {
@@ -886,9 +890,9 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, 
const Expr *&UnsafeArg,
           analyze_printf::PrintfConversionSpecifier::sArg)
         return true; // continue parsing
 
-      unsigned ArgIdx = FS.getPositionalArgIndex() + FmtArgIdx;
+      unsigned ArgIdx = FS.getArgIndex() + FmtArgIdx;
 
-      if (!(0 < ArgIdx && ArgIdx < Call->getNumArgs()))
+      if (ArgIdx >= Call->getNumArgs())
         // If the `ArgIdx` is invalid, give up.
         return true; // continue parsing
 
@@ -921,12 +925,15 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, 
const Expr *&UnsafeArg,
     bool isUnsafeArgSet() { return UnsafeArgSet; }
   };
 
-  const Expr *Fmt = Call->getArg(FmtArgIdx);
+  const Expr *Fmt = Call->getArg(FmtIdx);
+  unsigned FmtArgStartingIdx =
+      FmtArgIdx.has_value() ? static_cast<unsigned>(*FmtArgIdx) : FmtIdx + 1;
 
   if (auto *SL = dyn_cast<clang::StringLiteral>(Fmt->IgnoreParenImpCasts())) {
     if (SL->getCharByteWidth() == 1) {
       StringRef FmtStr = SL->getString();
-      StringFormatStringHandler Handler(Call, FmtArgIdx, UnsafeArg, Ctx);
+      StringFormatStringHandler Handler(Call, FmtArgStartingIdx, UnsafeArg,
+                                        Ctx);
 
       return analyze_format_string::ParsePrintfString(
                  Handler, FmtStr.begin(), FmtStr.end(), Ctx.getLangOpts(),
@@ -935,7 +942,8 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, 
const Expr *&UnsafeArg,
     }
 
     if (auto FmtStr = SL->tryEvaluateString(Ctx)) {
-      StringFormatStringHandler Handler(Call, FmtArgIdx, UnsafeArg, Ctx);
+      StringFormatStringHandler Handler(Call, FmtArgStartingIdx, UnsafeArg,
+                                        Ctx);
       return analyze_format_string::ParsePrintfString(
                  Handler, FmtStr->data(), FmtStr->data() + FmtStr->size(),
                  Ctx.getLangOpts(), Ctx.getTargetInfo(), isKprintf) &&
@@ -946,7 +954,7 @@ static bool hasUnsafeFormatOrSArg(const CallExpr *Call, 
const Expr *&UnsafeArg,
   // In this case, this call is considered unsafe if at least one argument
   // (including the format argument) is unsafe pointer.
   return llvm::any_of(
-      llvm::make_range(Call->arg_begin() + FmtArgIdx, Call->arg_end()),
+      llvm::make_range(Call->arg_begin() + FmtArgStartingIdx, Call->arg_end()),
       [&UnsafeArg, &Ctx](const Expr *Arg) -> bool {
         if (Arg->getType()->isPointerType() && !isNullTermPointer(Arg, Ctx)) {
           UnsafeArg = Arg;
@@ -1161,7 +1169,7 @@ static bool hasUnsafePrintfStringArg(const CallExpr 
&Node, ASTContext &Ctx,
     // It is a fprintf:
     const Expr *UnsafeArg;
 
-    if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 1, Ctx, false)) {
+    if (hasUnsafeFormatOrSArg(Ctx, &Node, UnsafeArg, /* FmtIdx= */ 1)) {
       Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
       return true;
     }
@@ -1175,7 +1183,8 @@ static bool hasUnsafePrintfStringArg(const CallExpr 
&Node, ASTContext &Ctx,
 
     if (auto *II = FD->getIdentifier())
       isKprintf = II->getName() == "kprintf";
-    if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 0, Ctx, isKprintf)) {
+    if (hasUnsafeFormatOrSArg(Ctx, &Node, UnsafeArg, /* FmtIdx= */ 0,
+                              /* FmtArgIdx= */ std::nullopt, isKprintf)) {
       Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
       return true;
     }
@@ -1190,7 +1199,7 @@ static bool hasUnsafePrintfStringArg(const CallExpr 
&Node, ASTContext &Ctx,
       // second is an integer, it is a snprintf:
       const Expr *UnsafeArg;
 
-      if (hasUnsafeFormatOrSArg(&Node, UnsafeArg, 2, Ctx, false)) {
+      if (hasUnsafeFormatOrSArg(Ctx, &Node, UnsafeArg, /* FmtIdx= */ 2)) {
         Result.addNode(Tag, DynTypedNode::create(*UnsafeArg));
         return true;
       }
@@ -2068,6 +2077,7 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget 
{
   constexpr static const char *const UnsafeVaListTag =
       "UnsafeLibcFunctionCall_va_list";
 
+public:
   enum UnsafeKind {
     OTHERS = 0,  // no specific information, the callee function is unsafe
     SPRINTF = 1, // never call `-sprintf`s, call `-snprintf`s instead.
@@ -2080,7 +2090,6 @@ class UnsafeLibcFunctionCallGadget : public WarningGadget 
{
                  // considered unsafe as it is not compile-time check
   } WarnedFunKind = OTHERS;
 
-public:
   UnsafeLibcFunctionCallGadget(const MatchResult &Result)
       : WarningGadget(Kind::UnsafeLibcFunctionCall),
         Call(Result.getNodeAs<CallExpr>(Tag)) {
@@ -2171,6 +2180,86 @@ class UnsafeLibcFunctionCallGadget : public 
WarningGadget {
   SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
 };
 
+class UnsafeFormatAttributedFunctionCallGadget : public WarningGadget {
+  const CallExpr *const Call;
+  const Expr *UnsafeArg = nullptr;
+  constexpr static const char *const Tag = 
"UnsafeFormatAttributedFunctionCall";
+  constexpr static const char *const UnsafeStringTag =
+      "UnsafeFormatAttributedFunctionCall_string";
+
+public:
+  UnsafeFormatAttributedFunctionCallGadget(const MatchResult &Result)
+      : WarningGadget(Kind::UnsafeLibcFunctionCall),
+        Call(Result.getNodeAs<CallExpr>(Tag)),
+        UnsafeArg(Result.getNodeAs<Expr>(UnsafeStringTag)) {}
+
+  static bool matches(const Stmt *S, ASTContext &Ctx,
+                      const UnsafeBufferUsageHandler *Handler,
+                      MatchResult &Result) {
+    if (ignoreUnsafeLibcCall(Ctx, *S, Handler))
+      return false;
+    auto *CE = dyn_cast<CallExpr>(S);
+    if (!CE || !CE->getDirectCallee())
+      return false;
+    const auto *FD = dyn_cast<FunctionDecl>(CE->getDirectCallee());
+    if (!FD)
+      return false;
+
+    const FormatAttr *Attr = nullptr;
+    bool IsPrintf = false;
+    bool AnyAttr = llvm::any_of(
+        FD->specific_attrs<FormatAttr>(),
+        [&Attr, &IsPrintf](const FormatAttr *FA) -> bool {
+          if (const auto *II = FA->getType()) {
+            if (II->getName() == "printf" || II->getName() == "scanf") {
+              Attr = FA;
+              IsPrintf = II->getName() == "printf";
+              return true;
+            }
+          }
+          return false;
+        });
+    const Expr *UnsafeArg;
+
+    if (AnyAttr && !IsPrintf &&
+        (CE->getNumArgs() >= static_cast<unsigned>(Attr->getFirstArg()))) {
+      // for scanf-like functions, any format argument is considered unsafe:
+      Result.addNode(Tag, DynTypedNode::create(*CE));
+      return true;
+    }
+    if (AnyAttr && libc_func_matchers::hasUnsafeFormatOrSArg(
+                       Ctx, CE, UnsafeArg,
+                       // FormatAttribute indexes are 1-based:
+                       /* FmtIdx= */ Attr->getFormatIdx() - 1,
+                       /* FmtArgIdx= */ Attr->getFirstArg() - 1)) {
+      Result.addNode(Tag, DynTypedNode::create(*CE));
+      Result.addNode(UnsafeStringTag, DynTypedNode::create(*UnsafeArg));
+      return true;
+    }
+    return false;
+  }
+
+  const Stmt *getBaseStmt() const { return Call; }
+
+  SourceLocation getSourceLoc() const override { return Call->getBeginLoc(); }
+
+  void handleUnsafeOperation(UnsafeBufferUsageHandler &Handler,
+                             bool IsRelatedToDecl,
+                             ASTContext &Ctx) const override {
+    if (UnsafeArg)
+      Handler.handleUnsafeLibcCall(
+          Call, UnsafeLibcFunctionCallGadget::UnsafeKind::STRING, Ctx,
+          UnsafeArg);
+    else
+      Handler.handleUnsafeLibcCall(
+          Call, UnsafeLibcFunctionCallGadget::UnsafeKind::OTHERS, Ctx);
+  }
+
+  DeclUseList getClaimedVarUseSites() const override { return {}; }
+
+  SmallVector<const Expr *, 1> getUnsafePtrs() const override { return {}; }
+};
+
 // Represents expressions of the form `DRE[*]` in the Unspecified Lvalue
 // Context (see `findStmtsInUnspecifiedLvalueContext`).
 // Note here `[]` is the built-in subscript operator.

diff  --git a/clang/test/SemaCXX/warn-unsafe-buffer-usage-libc-functions.cpp 
b/clang/test/SemaCXX/warn-unsafe-buffer-usage-libc-functions.cpp
index 4f1af79609223..a68ecf128c3d9 100644
--- a/clang/test/SemaCXX/warn-unsafe-buffer-usage-libc-functions.cpp
+++ b/clang/test/SemaCXX/warn-unsafe-buffer-usage-libc-functions.cpp
@@ -1,10 +1,10 @@
-// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage \
+// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage -Wno-gcc-compat\
 // RUN:            -verify %s
-// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage \
+// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage -Wno-gcc-compat\
 // RUN:            -verify %s -x objective-c++
-// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage-in-libc-call \
+// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage-in-libc-call 
-Wno-gcc-compat\
 // RUN:            -verify %s
-// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage-in-libc-call \
+// RUN: %clang_cc1 -std=c++20 -Wno-all -Wunsafe-buffer-usage-in-libc-call 
-Wno-gcc-compat\
 // RUN:            -verify %s -DTEST_STD_NS
 
 typedef struct {} FILE;
@@ -255,3 +255,34 @@ void dontCrashForInvalidFormatString() {
   snprintf((char*)0, 0, "%");
   snprintf((char*)0, 0, "\0");
 }
+
+
+// Also warn about unsafe printf/scanf-like functions:
+void myprintf(const char *, ...) __attribute__((__format__ (__printf__, 1, 
2)));
+void myprintf_2(const char *, int, const char *) __attribute__((__format__ 
(__printf__, 1, 3)));
+void myprintf_3(const char *, const char *, int, const char *) 
__attribute__((__format__ (__printf__, 2, 4)));
+void myscanf(const char *, ...) __attribute__((__format__ (__scanf__, 1, 2)));
+
+void test_myprintf(char * Str, std::string StdStr) {
+  myprintf("hello", Str);
+  myprintf("hello %s", StdStr.c_str());
+  myprintf("hello %s", Str);  // expected-warning{{function 'myprintf' is 
unsafe}} \
+                                expected-note{{string argument is not 
guaranteed to be null-terminated}}
+
+  myprintf_2("hello", 0, Str);
+  myprintf_2("hello %s", 0, StdStr.c_str());
+  myprintf_2("hello %s", 0, Str);  // expected-warning{{function 'myprintf_2' 
is unsafe}} \
+                                     expected-note{{string argument is not 
guaranteed to be null-terminated}}
+
+  myprintf_3("irrelevant", "hello", 0, Str);
+  myprintf_3("irrelevant", "hello %s", 0, StdStr.c_str());
+  myprintf_3("irrelevant", "hello %s", 0, Str);  // expected-warning{{function 
'myprintf_3' is unsafe}} \
+                                      expected-note{{string argument is not 
guaranteed to be null-terminated}}
+  
+  myscanf("hello %s");
+  myscanf("hello %s", Str); // expected-warning{{function 'myscanf' is unsafe}}
+
+  int X;
+
+  myscanf("hello %d", &X); // expected-warning{{function 'myscanf' is unsafe}}
+}


        
_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to