Track printk() recursion and limit it to 3 levels per-CPU and per-context.

Signed-off-by: John Ogness <[email protected]>
---
 kernel/printk/printk.c | 80 ++++++++++++++++++++++++++++++++++++++++--
 1 file changed, 77 insertions(+), 3 deletions(-)

diff --git a/kernel/printk/printk.c b/kernel/printk/printk.c
index 2f829fbf0a13..c666e3e43f0c 100644
--- a/kernel/printk/printk.c
+++ b/kernel/printk/printk.c
@@ -1940,6 +1940,71 @@ static void call_console_drivers(const char *ext_text, size_t ext_len,
        }
 }
 
+/*
+ * Recursion is tracked separately on each CPU. If NMIs are supported, NMI
+ * context gets its own counter on each CPU, separate from normal context.
+ * Until per-CPU data is available, a single global "early" set is used.
+ */
+#ifdef CONFIG_PRINTK_NMI
+#define PRINTK_CTX_NUM 2
+#else
+#define PRINTK_CTX_NUM 1
+#endif
+static DEFINE_PER_CPU(char [PRINTK_CTX_NUM], printk_count);
+static char printk_count_early[PRINTK_CTX_NUM];
+
+/*
+ * Limit recursion to keep the output sane. printk() should need at most 1
+ * level (e.g. printk() triggering a WARN), but a higher value allows for
+ * printk-internal errors such as failing ringbuffer validation checks.
+ * Up to this many calls may nest inside the outermost printk() call.
+ */
+#define PRINTK_MAX_RECURSION 3
+
+/* Return a pointer to the dedicated counter for the CPU+context of the caller. */
+static char *printk_recursion_counter(void)
+{
+       int ctx = 0;
+
+#ifdef CONFIG_PRINTK_NMI
+       if (in_nmi())
+               ctx = 1;
+#endif
+       if (!printk_percpu_data_ready())
+               return &printk_count_early[ctx];
+       return &((*this_cpu_ptr(&printk_count))[ctx]);
+}
+
+/*
+ * Enter recursion tracking. Interrupts are disabled to simplify tracking.
+ * Returns false, with interrupts restored, if the recursion limit is hit.
+ * On success, call printk_exit_irqrestore(*flags) when done.
+ */
+static bool printk_enter_irqsave(unsigned long *flags)
+{
+       char *count;
+
+       local_irq_save(*flags);
+       count = printk_recursion_counter();
+       if (*count > PRINTK_MAX_RECURSION) {
+               local_irq_restore(*flags);
+               return false;
+       }
+       (*count)++;
+
+       return true;
+}
+
+/* Undo printk_enter_irqsave() (same CPU+context); restores interrupts. */
+static void printk_exit_irqrestore(unsigned long flags)
+{
+       char *count;
+
+       count = printk_recursion_counter();
+       (*count)--;
+       local_irq_restore(flags);
+}
+
 int printk_delay_msec __read_mostly;
 
 static inline void printk_delay(void)
@@ -2040,11 +2105,13 @@ int vprintk_store(int facility, int level,
        struct prb_reserved_entry e;
        enum log_flags lflags = 0;
        struct printk_record r;
+       unsigned long irqflags;
        u16 trunc_msg_len = 0;
        char prefix_buf[8];
        u16 reserve_size;
        va_list args2;
        u16 text_len;
+       int ret = 0;
        u64 ts_nsec;
 
        /*
@@ -2055,6 +2122,9 @@ int vprintk_store(int facility, int level,
         */
        ts_nsec = local_clock();
 
+       if (!printk_enter_irqsave(&irqflags))
+               return 0;
+
        /*
         * The sprintf needs to come first since the syslog prefix might be
         * passed in as a parameter. An extra byte must be reserved so that
@@ -2092,7 +2162,8 @@ int vprintk_store(int facility, int level,
                                prb_commit(&e);
                        }
 
-                       return text_len;
+                       ret = text_len;
+                       goto out;
                }
        }
 
@@ -2108,7 +2179,7 @@ int vprintk_store(int facility, int level,
 
                prb_rec_init_wr(&r, reserve_size + trunc_msg_len);
                if (!prb_reserve(&e, prb, &r))
-                       return 0;
+                       goto out;
        }
 
        /* fill message */
@@ -2130,7 +2201,10 @@ int vprintk_store(int facility, int level,
        else
                prb_final_commit(&e);
 
-       return (text_len + trunc_msg_len);
+       ret = text_len + trunc_msg_len;
+out:
+       printk_exit_irqrestore(irqflags);
+       return ret;
 }
 
 asmlinkage int vprintk_emit(int facility, int level,
-- 
2.20.1

Reply via email to