The current implementations of unsafe_get_user/unsafe_put_user default to
get_user/put_user, which toggle PAN around each access, even though the
caller has already indicated that multiple accesses to user memory are
about to happen.

Provide implementations of user_access_begin/end that turn PAN off/on, and
implement unsafe accessors that assume PAN has already been turned off.

Signed-off-by: Julien Thierry <[email protected]>
Cc: Catalin Marinas <[email protected]>
Cc: Will Deacon <[email protected]>

---
 arch/arm64/include/asm/uaccess.h | 79 ++++++++++++++++++++++++++++++----------
 1 file changed, 59 insertions(+), 20 deletions(-)

diff --git a/arch/arm64/include/asm/uaccess.h b/arch/arm64/include/asm/uaccess.h
index 8e40808..6a70c75 100644
--- a/arch/arm64/include/asm/uaccess.h
+++ b/arch/arm64/include/asm/uaccess.h
@@ -270,31 +270,26 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)

 #define __raw_get_user(x, ptr, err)                                    \
 do {                                                                   \
-       unsigned long __gu_val;                                         \
-       __chk_user_ptr(ptr);                                            \
-       uaccess_enable_not_uao();                                       \
        switch (sizeof(*(ptr))) {                                       \
        case 1:                                                         \
-               __get_user_asm("ldrb", "ldtrb", "%w", __gu_val, (ptr),  \
+               __get_user_asm("ldrb", "ldtrb", "%w", (x), (ptr),       \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 2:                                                         \
-               __get_user_asm("ldrh", "ldtrh", "%w", __gu_val, (ptr),  \
+               __get_user_asm("ldrh", "ldtrh", "%w", (x), (ptr),       \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 4:                                                         \
-               __get_user_asm("ldr", "ldtr", "%w", __gu_val, (ptr),    \
+               __get_user_asm("ldr", "ldtr", "%w", (x), (ptr),         \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 8:                                                         \
-               __get_user_asm("ldr", "ldtr", "%x",  __gu_val, (ptr),   \
+               __get_user_asm("ldr", "ldtr", "%x",  (x), (ptr),        \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        default:                                                        \
                BUILD_BUG();                                            \
        }                                                               \
-       uaccess_disable_not_uao();                                      \
-       (x) = (__force __typeof__(*(ptr)))__gu_val;                     \
 } while (0)

 #define __get_user_error(x, ptr, err)                                  \
@@ -302,8 +297,13 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)
        __typeof__(*(ptr)) __user *__p = (ptr);                         \
        might_fault();                                                  \
        if (access_ok(__p, sizeof(*__p))) {                             \
+               unsigned long __gu_val;                                 \
+               __chk_user_ptr(__p);                                    \
                __p = uaccess_mask_ptr(__p);                            \
-               __raw_get_user((x), __p, (err));                        \
+               uaccess_enable_not_uao();                               \
+               __raw_get_user(__gu_val, __p, (err));                   \
+               uaccess_disable_not_uao();                              \
+               (x) = (__force __typeof__(*__p)) __gu_val;              \
        } else {                                                        \
                (x) = 0; (err) = -EFAULT;                               \
        }                                                               \
@@ -334,30 +334,26 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)

 #define __raw_put_user(x, ptr, err)                                    \
 do {                                                                   \
-       __typeof__(*(ptr)) __pu_val = (x);                              \
-       __chk_user_ptr(ptr);                                            \
-       uaccess_enable_not_uao();                                       \
        switch (sizeof(*(ptr))) {                                       \
        case 1:                                                         \
-               __put_user_asm("strb", "sttrb", "%w", __pu_val, (ptr),  \
+               __put_user_asm("strb", "sttrb", "%w", (x), (ptr),       \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 2:                                                         \
-               __put_user_asm("strh", "sttrh", "%w", __pu_val, (ptr),  \
+               __put_user_asm("strh", "sttrh", "%w", (x), (ptr),       \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 4:                                                         \
-               __put_user_asm("str", "sttr", "%w", __pu_val, (ptr),    \
+               __put_user_asm("str", "sttr", "%w", (x), (ptr),         \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        case 8:                                                         \
-               __put_user_asm("str", "sttr", "%x", __pu_val, (ptr),    \
+               __put_user_asm("str", "sttr", "%x", (x), (ptr),         \
                               (err), ARM64_HAS_UAO);                   \
                break;                                                  \
        default:                                                        \
                BUILD_BUG();                                            \
        }                                                               \
-       uaccess_disable_not_uao();                                      \
 } while (0)

 #define __put_user_error(x, ptr, err)                                  \
@@ -365,9 +361,13 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)
        __typeof__(*(ptr)) __user *__p = (ptr);                         \
        might_fault();                                                  \
        if (access_ok(__p, sizeof(*__p))) {                             \
+               __typeof__(*(ptr)) __pu_val = (x);                      \
+               __chk_user_ptr(__p);                                    \
                __p = uaccess_mask_ptr(__p);                            \
-               __raw_put_user((x), __p, (err));                        \
-       } else  {                                                       \
+               uaccess_enable_not_uao();                               \
+               __raw_put_user(__pu_val, __p, (err));                   \
+               uaccess_disable_not_uao();                              \
+       } else {                                                        \
                (err) = -EFAULT;                                        \
        }                                                               \
 } while (0)
@@ -381,6 +381,45 @@ static inline void __user *__uaccess_mask_ptr(const void __user *ptr)

 #define put_user       __put_user

+static __must_check inline bool user_access_begin(const void __user *ptr,
+                                                 size_t len)
+{
+       if (unlikely(!access_ok(ptr, len)))
+               return false;
+       /* Range checked: turn PAN off until the matching user_access_end(). */
+       uaccess_enable_not_uao();
+       return true;
+}
+#define user_access_begin(ptr, len)    user_access_begin(ptr, len)
+#define user_access_end()              uaccess_disable_not_uao()
+/* Only between user_access_begin()/end(); jumps to @err on failure. */
+#define unsafe_get_user(x, ptr, err)                                   \
+do {                                                                   \
+       __typeof__(*(ptr)) __user *__p = (ptr);                         \
+       unsigned long __gu_val;                                         \
+       int __gu_err = 0;                                               \
+       might_fault();                                                  \
+       __chk_user_ptr(__p);                                            \
+       __p = uaccess_mask_ptr(__p);                                    \
+       __raw_get_user(__gu_val, __p, __gu_err);                        \
+       (x) = (__force __typeof__(*__p)) __gu_val;                      \
+       if (__gu_err != 0)                                              \
+               goto err;                                               \
+} while (0)
+/* Store counterpart of unsafe_get_user(); same PAN and @err rules. */
+#define unsafe_put_user(x, ptr, err)                                   \
+do {                                                                   \
+       __typeof__(*(ptr)) __user *__p = (ptr);                         \
+       __typeof__(*(ptr)) __pu_val = (x);                              \
+       int __pu_err = 0;                                               \
+       might_fault();                                                  \
+       __chk_user_ptr(__p);                                            \
+       __p = uaccess_mask_ptr(__p);                                    \
+       __raw_put_user(__pu_val, __p, __pu_err);                        \
+       if (__pu_err != 0)                                              \
+               goto err;                                               \
+} while (0)
+
 extern unsigned long __must_check __arch_copy_from_user(void *to, const void __user *from, unsigned long n);
 #define raw_copy_from_user(to, from, n)                                        \
 ({                                                                     \
--
1.9.1

Reply via email to