If a memory fault is caused by an overlay (pkey) permission check, report it
to userspace as SIGSEGV with si_code SEGV_PKUERR, along with the VMA's pkey.

Signed-off-by: Joey Gouly <joey.go...@arm.com>
Cc: Catalin Marinas <catalin.mari...@arm.com>
Cc: Will Deacon <w...@kernel.org>
---
 arch/arm64/include/asm/traps.h |  1 +
 arch/arm64/kernel/traps.c      | 12 ++++++--
 arch/arm64/mm/fault.c          | 56 ++++++++++++++++++++++++++++++++--
 3 files changed, 64 insertions(+), 5 deletions(-)

diff --git a/arch/arm64/include/asm/traps.h b/arch/arm64/include/asm/traps.h
index eefe766d6161..f6f6f2cb7f10 100644
--- a/arch/arm64/include/asm/traps.h
+++ b/arch/arm64/include/asm/traps.h
@@ -25,6 +25,7 @@ try_emulate_armv8_deprecated(struct pt_regs *regs, u32 insn)
 void force_signal_inject(int signal, int code, unsigned long address, unsigned long err);
 void arm64_notify_segfault(unsigned long addr);
 void arm64_force_sig_fault(int signo, int code, unsigned long far, const char *str);
+void arm64_force_sig_fault_pkey(int signo, int code, unsigned long far, const char *str, int pkey);
 void arm64_force_sig_mceerr(int code, unsigned long far, short lsb, const char *str);
 void arm64_force_sig_ptrace_errno_trap(int errno, unsigned long far, const char *str);
 
diff --git a/arch/arm64/kernel/traps.c b/arch/arm64/kernel/traps.c
index 215e6d7f2df8..1bac6c84d3f5 100644
--- a/arch/arm64/kernel/traps.c
+++ b/arch/arm64/kernel/traps.c
@@ -263,16 +263,24 @@ static void arm64_show_signal(int signo, const char *str)
        __show_regs(regs);
 }
 
-void arm64_force_sig_fault(int signo, int code, unsigned long far,
-                          const char *str)
+void arm64_force_sig_fault_pkey(int signo, int code, unsigned long far,
+                          const char *str, int pkey)
 {
        arm64_show_signal(signo, str);
        if (signo == SIGKILL)
                force_sig(SIGKILL);
+       else if (code == SEGV_PKUERR)
+               force_sig_pkuerr((void __user *)far, pkey);
        else
                force_sig_fault(signo, code, (void __user *)far);
 }
 
+void arm64_force_sig_fault(int signo, int code, unsigned long far,
+                          const char *str)
+{
+       arm64_force_sig_fault_pkey(signo, code, far, str, 0);
+}
+
 void arm64_force_sig_mceerr(int code, unsigned long far, short lsb,
                            const char *str)
 {
diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c
index 8251e2fea9c7..585295168918 100644
--- a/arch/arm64/mm/fault.c
+++ b/arch/arm64/mm/fault.c
@@ -23,6 +23,7 @@
 #include <linux/sched/debug.h>
 #include <linux/highmem.h>
 #include <linux/perf_event.h>
+#include <linux/pkeys.h>
 #include <linux/preempt.h>
 #include <linux/hugetlb.h>
 
@@ -489,6 +490,23 @@ static void do_bad_area(unsigned long far, unsigned long esr,
 #define VM_FAULT_BADMAP                ((__force vm_fault_t)0x010000)
 #define VM_FAULT_BADACCESS     ((__force vm_fault_t)0x020000)
 
+/* Should this fault be reported to userspace as a pkey (overlay) fault? */
+static bool fault_from_pkey(unsigned long esr, struct vm_area_struct *vma,
+                       unsigned int mm_flags)
+{
+       unsigned long iss2 = ESR_ELx_ISS2(esr);
+
+       if (!arch_pkeys_enabled())
+               return false;
+       /* ISS2.Overlay is only meaningful when the FSC is a permission fault. */
+       if ((esr & ESR_ELx_FSC_TYPE) == ESR_ELx_FSC_PERM && (iss2 & ESR_ELx_Overlay))
+               return true;
+
+       return !arch_vma_access_permitted(vma, mm_flags & FAULT_FLAG_WRITE,
+                       mm_flags & FAULT_FLAG_INSTRUCTION,
+                       mm_flags & FAULT_FLAG_REMOTE);
+}
+
 static vm_fault_t __do_page_fault(struct mm_struct *mm,
                                  struct vm_area_struct *vma, unsigned long addr,
                                  unsigned int mm_flags, unsigned long vm_flags,
@@ -529,6 +547,8 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
        unsigned int mm_flags = FAULT_FLAG_DEFAULT;
        unsigned long addr = untagged_addr(far);
        struct vm_area_struct *vma;
+       bool pkey_fault = false;
+       int pkey = -1;
 
        if (kprobe_page_fault(regs, esr))
                return 0;
@@ -590,6 +610,12 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
                vma_end_read(vma);
                goto lock_mmap;
        }
+
+       if (fault_from_pkey(esr, vma, mm_flags)) {
+               vma_end_read(vma);
+               goto lock_mmap;
+       }
+
        fault = handle_mm_fault(vma, addr, mm_flags | FAULT_FLAG_VMA_LOCK, regs);
        if (!(fault & (VM_FAULT_RETRY | VM_FAULT_COMPLETED)))
                vma_end_read(vma);
@@ -617,6 +643,11 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
                goto done;
        }
 
+       /* A pkey fault is reported below as SEGV_PKUERR with the VMA's pkey. */
+       pkey_fault = fault_from_pkey(esr, vma, mm_flags);
+       if (pkey_fault)
+               pkey = vma_pkey(vma);
+
        fault = __do_page_fault(mm, vma, addr, mm_flags, vm_flags, regs);
 
        /* Quick path to respond to signals */
@@ -682,9 +713,28 @@ static int __kprobes do_page_fault(unsigned long far, unsigned long esr,
                 * Something tried to access memory that isn't in our memory
                 * map.
                 */
-               arm64_force_sig_fault(SIGSEGV,
-                                     fault == VM_FAULT_BADACCESS ? SEGV_ACCERR : SEGV_MAPERR,
-                                     far, inf->name);
+               int fault_kind;
+               /*
+                * The pkey value that we return to userspace can be different
+                * from the pkey that caused the fault.
+                *
+                * 1. T1   : pkey_mprotect(foo, PAGE_SIZE, prot, pkey=4);
+                * 2. T1   : set POR_EL0 to deny access to pkey=4, touches page
+                * 3. T1   : faults...
+                * 4.    T2: pkey_mprotect(foo, PAGE_SIZE, prot, pkey=5);
+                * 5. T1   : enters fault handler, takes mmap_lock, etc...
+                * 6. T1   : reaches here, sees vma_pkey(vma)=5, when we really
+                *           faulted on a pte with its pkey=4.
+                */
+
+               if (pkey_fault)
+                       fault_kind = SEGV_PKUERR;
+               else
+                       fault_kind = fault == VM_FAULT_BADACCESS ? SEGV_ACCERR : SEGV_MAPERR;
+
+               arm64_force_sig_fault_pkey(SIGSEGV,
+                                     fault_kind,
+                                     far, inf->name, pkey);
        }
 
        return 0;
-- 
2.25.1

Reply via email to