Allow the use of multiple flags per conversion, as most flags are not
mutually exclusive.

Remove all dynamic memory usage, relying instead on stack buffers and
variables for the generated format strings and character arguments.

Fix reading past the end of the format string (add a bounds check before
strchr), and parse the field width and precision with strtoll.

Only accept %% as a verbatim percent sign, simplify setting cooldown, and
process each run of literal data in a single iteration of the loop.

Fix the incorrect default field width (now 0 instead of -1), and allow
negative widths supplied via *.

Skip over any length modifiers (h and l).
---
 printf.c | 121 +++++++++++++++++++++++++------------------------------
 1 file changed, 54 insertions(+), 67 deletions(-)

diff --git a/printf.c b/printf.c
index 039dac7..aa0a2df 100644
--- a/printf.c
+++ b/printf.c
@@ -9,6 +9,8 @@
 #include "utf.h"
 #include "util.h"
 
+static const char FLAGS[] = "#0- +";
+
 static void
 usage(void)
 {
@@ -18,12 +20,13 @@ usage(void)
 int
 main(int argc, char *argv[])
 {
-       Rune *rarg;
+       Rune rarg;
+       char fmtbuf[sizeof(FLAGS) + sizeof("%*.*ll")];
        size_t i, j, argi, lastargi, formatlen, blen;
        long long num;
        double dou;
        int cooldown = 0, width, precision, ret = 0;
-       char *format, *tmp, *arg, *fmt, flag;
+       char *format, *tmp, *arg;
 
        argv0 = argv[0];
        if (argc < 2)
@@ -38,40 +41,46 @@ main(int argc, char *argv[])
        if (formatlen == 0)
                return 0;
        lastargi = 0;
-       for (i = 0, argi = 2; !cooldown || i < formatlen; i++, i = cooldown ? i 
: (i % formatlen)) {
+       for (i = 0, argi = 2;
+            !cooldown || i < formatlen;
+            cooldown = argi >= argc, i = (i + 1) % (formatlen + cooldown)) {
+               char flags[sizeof(FLAGS)] = { 0 };
                if (i == 0) {
                        if (lastargi == argi)
                                break;
                        lastargi = argi;
                }
                if (format[i] != '%') {
-                       putchar(format[i]);
+                       size_t n = strcspn(format + i + 1, "%");
+                       fwrite(format + i, sizeof(*format), n + 1, stdout);
+                       i += n;
                        continue;
                }
 
-               /* flag */
-               for (flag = '\0', i++; strchr("#-+ 0", format[i]); i++) {
-                       flag = format[i];
+               if (format[++i] == '%') {
+                       putchar('%');
+                       continue;
                }
 
+               /* flag */
+               for (; i < formatlen && strchr(FLAGS, format[i]); i++)
+                       flags[strcspn(flags, (const char[]){ format[i], '\0' 
})] = format[i];
+
                /* field width */
-               width = -1;
+               width = 0;
                if (format[i] == '*') {
                        if (argi < argc)
-                               width = estrtonum(argv[argi++], 0, INT_MAX);
-                       else
-                               cooldown = 1;
+                               width = estrtonum(argv[argi++], INT_MIN, 
INT_MAX);
                        i++;
                } else {
-                       j = i;
-                       for (; strchr("+-0123456789", format[i]); i++);
-                       if (j != i) {
-                               tmp = estrndup(format + j, i - j);
-                               width = estrtonum(tmp, 0, INT_MAX);
-                               free(tmp);
-                       } else {
-                               width = 0;
+                       num = strtoll(format + i, &tmp, 10);
+                       if (num > INT_MAX) {
+                               if (tmp - format - i > INT_MAX)
+                                       tmp = format + i + INT_MAX;
+                               eprintf("field width %.*s not in range\n", 
(int)(tmp - format - i), format + i);
                        }
+                       width = num;
+                       i = tmp - format;
                }
 
                /* field precision */
@@ -80,33 +89,25 @@ main(int argc, char *argv[])
                        if (format[++i] == '*') {
                                if (argi < argc)
                                        precision = estrtonum(argv[argi++], 0, 
INT_MAX);
-                               else
-                                       cooldown = 1;
                                i++;
                        } else {
-                               j = i;
-                               for (; strchr("+-0123456789", format[i]); i++);
-                               if (j != i) {
-                                       tmp = estrndup(format + j, i - j);
-                                       precision = estrtonum(tmp, 0, INT_MAX);
-                                       free(tmp);
-                               } else {
-                                       precision = 0;
+                               num = strtoll(format + i, &tmp, 10);
+                               if (num < 0 || num > INT_MAX) {
+                                       if (tmp - format - i > INT_MAX)
+                                               tmp = format + i + INT_MAX;
+                                       eprintf("field precision %.*s not in 
range\n", (int)(tmp - format - i), format + i);
                                }
+                               precision = num;
+                               i = tmp - format;
                        }
                }
 
-               if (format[i] != '%') {
-                       if (argi < argc)
-                               arg = argv[argi++];
-                       else {
-                               arg = "";
-                               cooldown = 1;
-                       }
-               } else {
-                       putchar('%');
-                       continue;
-               }
+               if (argi < argc)
+                       arg = argv[argi++];
+               else
+                       arg = "";
+
+               i += strspn(format + i, "hl");
 
                switch (format[i]) {
                case 'b':
@@ -121,26 +122,20 @@ main(int argc, char *argv[])
                        break;
                case 'c':
                        unescape(arg);
-                       rarg = ereallocarray(NULL, utflen(arg) + 1, 
sizeof(*rarg));
-                       utftorunestr(arg, rarg);
-                       efputrune(rarg, stdout, "<stdout>");
-                       free(rarg);
+                       chartorune(&rarg, arg);
+                       efputrune(&rarg, stdout, "<stdout>");
                        break;
                case 's':
-                       fmt = estrdup(flag ? "%#*.*s" : "%*.*s");
-                       if (flag)
-                               fmt[1] = flag;
-                       printf(fmt, width, precision, arg);
-                       free(fmt);
+                       snprintf(fmtbuf, sizeof(fmtbuf), "%%%s*.*s", flags);
+                       printf(fmtbuf, width, precision, arg);
                        break;
                case 'd': case 'i': case 'o': case 'u': case 'x': case 'X':
                        for (j = 0; isspace(arg[j]); j++);
                        if (arg[j] == '\'' || arg[j] == '\"') {
                                arg += j + 1;
                                unescape(arg);
-                               rarg = ereallocarray(NULL, utflen(arg) + 1, 
sizeof(*rarg));
-                               utftorunestr(arg, rarg);
-                               num = rarg[0];
+                               chartorune(&rarg, arg);
+                               num = rarg;
                        } else if (arg[0]) {
                                errno = 0;
                                if (format[i] == 'd' || format[i] == 'i')
@@ -161,27 +156,19 @@ main(int argc, char *argv[])
                        } else {
                                        num = 0;
                        }
-                       fmt = estrdup(flag ? "%#*.*ll#" : "%*.*ll#");
-                       if (flag)
-                               fmt[1] = flag;
-                       fmt[flag ? 7 : 6] = format[i];
-                       printf(fmt, width, precision, num);
-                       free(fmt);
+                       snprintf(fmtbuf, sizeof(fmtbuf), "%%%s*.*ll%c", flags, 
format[i]);
+                       printf(fmtbuf, width, precision, num);
                        break;
                case 'a': case 'A': case 'e': case 'E': case 'f': case 'F': 
case 'g': case 'G':
-                       fmt = estrdup(flag ? "%#*.*#" : "%*.*#");
-                       if (flag)
-                               fmt[1] = flag;
-                       fmt[flag ? 5 : 4] = format[i];
-                       dou = (strlen(arg) > 0) ? estrtod(arg) : 0;
-                       printf(fmt, width, precision, dou);
-                       free(fmt);
+                       snprintf(fmtbuf, sizeof(fmtbuf), "%%%s*.*%c", flags, 
format[i]);
+                       dou = *arg ? estrtod(arg) : 0;
+                       printf(fmtbuf, width, precision, dou);
                        break;
+               case '\0':
+                       eprintf("Missing format specifier.\n");
                default:
                        eprintf("Invalid format specifier '%c'.\n", format[i]);
                }
-               if (argi >= argc)
-                       cooldown = 1;
        }
 
        return fshut(stdout, "<stdout>") | ret;
-- 
2.25.1


Reply via email to