The fault-inject-make-fail-nth-read-write-interface-symmetric.patch in the
-mm tree allows users to set task->fail_nth for a task other than current
via procfs.  On the other hand, the current task's fail_nth is decremented
down to zero in the fault-injection path without holding any lock.

So we need to prevent task->fail_nth from taking an unexpected value
because of data races (for example, procfs setting task->fail_nth to zero
while the task itself is decrementing current->fail_nth).  This fix uses
READ_ONCE() and WRITE_ONCE() so that the compiler cannot tear, merge,
refetch, or invent accesses to fail_nth.

Cc: Dmitry Vyukov <dvyu...@google.com>
Reported-by: Dmitry Vyukov <dvyu...@google.com>
Signed-off-by: Akinobu Mita <akinobu.m...@gmail.com>
---
 fs/proc/base.c     |  5 +++--
 lib/fault-inject.c | 15 +++++++++++----
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/fs/proc/base.c b/fs/proc/base.c
index ecc8a25..719c2e9 100644
--- a/fs/proc/base.c
+++ b/fs/proc/base.c
@@ -1370,7 +1370,7 @@ static ssize_t proc_fail_nth_write(struct file *file, const char __user *buf,
        task = get_proc_task(file_inode(file));
        if (!task)
                return -ESRCH;
-       task->fail_nth = n;
+       WRITE_ONCE(task->fail_nth, n);
        put_task_struct(task);
 
        return count;
@@ -1386,7 +1386,8 @@ static ssize_t proc_fail_nth_read(struct file *file, char __user *buf,
        task = get_proc_task(file_inode(file));
        if (!task)
                return -ESRCH;
-       len = snprintf(numbuf, sizeof(numbuf), "%u\n", task->fail_nth);
+       len = snprintf(numbuf, sizeof(numbuf), "%u\n",
+                       READ_ONCE(task->fail_nth));
        len = simple_read_from_buffer(buf, count, ppos, numbuf, len);
        put_task_struct(task);
 
diff --git a/lib/fault-inject.c b/lib/fault-inject.c
index 09ac73c1..7d315fd 100644
--- a/lib/fault-inject.c
+++ b/lib/fault-inject.c
@@ -107,9 +107,16 @@ static inline bool fail_stacktrace(struct fault_attr *attr)
 
 bool should_fail(struct fault_attr *attr, ssize_t size)
 {
-       if (in_task() && current->fail_nth) {
-               if (--current->fail_nth == 0)
-                       goto fail;
-               return false;
+       if (in_task()) {
+               unsigned int fail_nth = READ_ONCE(current->fail_nth);
+
+               if (fail_nth) {
+                       fail_nth--;
+                       WRITE_ONCE(current->fail_nth, fail_nth);
+                       if (!fail_nth)
+                               goto fail;
+
+                       return false;
+               }
        }
 
-- 
2.7.4

Reply via email to