This is difficult to get correct: we add the r12 register to the interrupt
state restore path so that every interrupt context knows which mode to
return to for swapgs.  Testing whether the gs base is negative is enough to
know which mode we entered from, since user space currently cannot set a
negative gs base.  (That would require the FSGSBASE CPU feature to be
enabled, but for now we leave it disabled.)
syscall64 enters with interrupts disabled and requires swapgs at the
beginning and at the end, as it is known to be called only from user space.

---
 i386/i386/gdt.c    | 18 +++++++++
 i386/i386/ldt.c    |  2 +-
 i386/i386/pcb.c    |  8 ++--
 i386/i386/thread.h |  1 +
 x86_64/boothdr.S   |  2 -
 x86_64/locore.S    | 92 ++++++++++++++++++++++++++++++++++++----------
 6 files changed, 98 insertions(+), 25 deletions(-)

diff --git a/i386/i386/gdt.c b/i386/i386/gdt.c
index 8c4a59c7..585583b7 100644
--- a/i386/i386/gdt.c
+++ b/i386/i386/gdt.c
@@ -40,6 +40,7 @@
 
 #include "vm_param.h"
 #include "seg.h"
+#include "msr.h"
 #include "gdt.h"
 #include "mp_desc.h"
 
@@ -109,6 +110,17 @@ gdt_fill(int cpu, struct real_descriptor *mygdt)
 #endif /* MACH_PV_DESCRIPTORS */
 }
 
+#ifdef __x86_64__
+static void
+reload_gs_base(int cpu)
+{
+       /* KGSBASE is kernels gs base while in userspace,
+        * but when in kernel, GSBASE must point to percpu area. */
+       wrmsr(MSR_REG_GSBASE, (uint64_t)&percpu_array[cpu]);
+       wrmsr(MSR_REG_KGSBASE, 0);
+}
+#endif
+
 static void
 reload_segs(void)
 {
@@ -138,6 +150,9 @@ gdt_init(void)
        gdt_fill(0, gdt);
 
        reload_segs();
+#ifdef __x86_64__
+       reload_gs_base(0);
+#endif
 
 #ifdef MACH_PV_PAGETABLES
 #if VM_MIN_KERNEL_ADDRESS != LINEAR_MIN_KERNEL_ADDRESS
@@ -157,5 +172,8 @@ ap_gdt_init(int cpu)
        gdt_fill(cpu, mp_gdt[cpu]);
 
        reload_segs();
+#ifdef __x86_64__
+       reload_gs_base(cpu);
+#endif
 }
 #endif
diff --git a/i386/i386/ldt.c b/i386/i386/ldt.c
index 7db67f99..6a13a7f8 100644
--- a/i386/i386/ldt.c
+++ b/i386/i386/ldt.c
@@ -72,7 +72,7 @@ ldt_fill(struct real_descriptor *myldt, struct 
real_descriptor *mygdt)
 #if defined(__x86_64__) && ! defined(USER32)
         if (!CPU_HAS_FEATURE(CPU_FEATURE_SEP))
             panic("syscall support is missing on 64 bit");
-        /* Enable 64-bit syscalls */
+        /* Enable 64-bit syscalls with interrupts disabled on entry */
         wrmsr(MSR_REG_EFER, rdmsr(MSR_REG_EFER) | MSR_EFER_SCE);
         wrmsr(MSR_REG_LSTAR, (vm_offset_t)syscall64);
         wrmsr(MSR_REG_STAR, ((((long)USER_CS - 16) << 16) | (long)KERNEL_CS) 
<< 32);
diff --git a/i386/i386/pcb.c b/i386/i386/pcb.c
index d845b2b2..e4ac2bb9 100644
--- a/i386/i386/pcb.c
+++ b/i386/i386/pcb.c
@@ -230,7 +230,7 @@ void switch_ktss(pcb_t pcb)
 
 #if defined(__x86_64__) && !defined(USER32)
        wrmsr(MSR_REG_FSBASE, pcb->ims.sbs.fsbase);
-       wrmsr(MSR_REG_GSBASE, pcb->ims.sbs.gsbase);
+       wrmsr(MSR_REG_KGSBASE, pcb->ims.sbs.gsbase);
 #endif
 
        db_load_context(pcb);
@@ -710,11 +710,13 @@ kern_return_t thread_setstatus(
                             return KERN_INVALID_ARGUMENT;
 
                     state = (struct i386_fsgs_base_state *) tstate;
+                    if (state->gs_base & 0x8000000000000000UL)
+                            printf("WARNING: negative gs base not allowed\n");
                     thread->pcb->ims.sbs.fsbase = state->fs_base;
-                    thread->pcb->ims.sbs.gsbase = state->gs_base;
+                    thread->pcb->ims.sbs.gsbase = state->gs_base & 
0x7fffffffffffffffUL;
                     if (thread == current_thread()) {
                             wrmsr(MSR_REG_FSBASE, state->fs_base);
-                            wrmsr(MSR_REG_GSBASE, state->gs_base);
+                            wrmsr(MSR_REG_KGSBASE, state->gs_base);
                     }
                     break;
             }
diff --git a/i386/i386/thread.h b/i386/i386/thread.h
index 9c88d09a..5112bc83 100644
--- a/i386/i386/thread.h
+++ b/i386/i386/thread.h
@@ -183,6 +183,7 @@ struct i386_interrupt_state {
        long    ds;
 #endif
 #ifdef __x86_64__
+       long    r12;
        long    r11;
        long    r10;
        long    r9;
diff --git a/x86_64/boothdr.S b/x86_64/boothdr.S
index 45d59c06..98f1cab2 100644
--- a/x86_64/boothdr.S
+++ b/x86_64/boothdr.S
@@ -185,7 +185,6 @@ boot_entry64:
        andq    $(~15),%rax
        movq    %rax,%rsp
 
-#if NCPUS > 1
        /* Set GS base address for kernel */
        movq    $percpu_array, %rdx
        movl    %edx, %eax
@@ -198,7 +197,6 @@ boot_entry64:
        xorl    %edx, %edx
        movl    $MSR_REG_KGSBASE, %ecx
        wrmsr
-#endif
 
        /* Reset EFLAGS to a known state.  */
        pushq   $0
diff --git a/x86_64/locore.S b/x86_64/locore.S
index 6afda87a..907502ef 100644
--- a/x86_64/locore.S
+++ b/x86_64/locore.S
@@ -88,7 +88,8 @@
        pushq   %r8     ;\
        pushq   %r9     ;\
        pushq   %r10     ;\
-       pushq   %r11
+       pushq   %r11    ;\
+       pushq   %r12
 
 #define PUSH_AREGS_ISR \
        pushq   %rax    ;\
@@ -96,6 +97,7 @@
 
 
 #define POP_REGS_ISR   \
+       popq    %r12    ;\
        popq    %r11    ;\
        popq    %r10     ;\
        popq    %r9     ;\
@@ -163,21 +165,54 @@
 #define POP_SEGMENTS_ISR(reg)
 #endif
 
-#if NCPUS > 1
-#define SET_KERNEL_SEGMENTS(reg)                \
-       ud2             /* TODO: use swapgs or similar */
-#else // NCPUS > 1
 #ifdef USER32
-#define SET_KERNEL_SEGMENTS(reg)              \
-       mov     %ss,reg /* switch to kernel segments */ ;\
-       mov     reg,%ds /* (same as kernel stack segment) */ ;\
-       mov     reg,%es                 ;\
-       mov     reg,%fs                 ;\
-       mov     reg,%gs
-#else // USER32
+#define SET_KERNEL_SEGMENTS(reg)       \
+       mov     %ss,reg /* switch to kernel segments */ ;\
+       mov     reg,%ds /* (same as kernel stack segment) */ ;\
+       mov     reg,%es                 ;\
+       mov     reg,%fs                 ;\
+       mov     reg,%gs
+#else
 #define SET_KERNEL_SEGMENTS(reg)
-#endif // USER32
-#endif // NCPUS > 1
+#endif
+
+#define RETURN_TO_KERN 0x7eadbeef
+#define RETURN_TO_USER 0x66666666
+
+#ifdef USER32
+# define SWAPGS_ENTRY_IF_NEEDED_R12
+# define SWAPGS_EXIT_IF_NEEDED_R12
+#else
+/* Keeps %r12 (callee-saved) value throughout interrupt context */
+# define SWAPGS_ENTRY_IF_NEEDED_R12    \
+       pushf                           ;\
+       cli                             ;\
+       pushq   %rax                    ;\
+       pushq   %rcx                    ;\
+       pushq   %rdx                    ;\
+       movl    $MSR_REG_GSBASE, %ecx   ;\
+       rdmsr                           ;\
+       testl   %edx, %edx      /* gs base sign bit set ? */ ;\
+       js      0f              /* yes, dont swap then return to kernel mode */ 
;\
+       swapgs                  /* no, swap then return to user mode */ ;\
+       movq    $RETURN_TO_USER, %r12   ;\
+       jmp     1f                      ;\
+0:     movq    $RETURN_TO_KERN, %r12   ;\
+1:     popq    %rdx                    ;\
+       popq    %rcx                    ;\
+       popq    %rax                    ;\
+       popf
+
+# define SWAPGS_EXIT_IF_NEEDED_R12     \
+       cmpq    $RETURN_TO_USER, %r12   ;\
+       je      0f              /* return to user with swap */ ;\
+       cmpq    $RETURN_TO_KERN, %r12   ;\
+       je      1f              /* return to kern without swap */ ;\
+       ud2                     /* or die */ ;\
+0:     swapgs                          ;\
+1:
+
+#endif
 
 /*
  * Fault recovery.
@@ -617,6 +652,7 @@ ENTRY(alltraps)
        pusha                           /* save the general registers */
 trap_push_segs:
        PUSH_SEGMENTS(%rax)             /* and the segment registers */
+       SWAPGS_ENTRY_IF_NEEDED_R12
        SET_KERNEL_SEGMENTS(%rax)       /* switch to kernel data segment */
 trap_set_segs:
        cld                             /* clear direction flag */
@@ -673,6 +709,7 @@ _return_to_user:
  */
 
 _return_from_kernel:
+       SWAPGS_EXIT_IF_NEEDED_R12
 #ifdef USER32
 _kret_popl_gs:
        popq    %gs                     /* restore segment registers */
@@ -738,6 +775,7 @@ ENTRY(thread_bootstrap_return)
        movq    %rsp,%rcx                       /* get kernel stack */
        or      $(KERNEL_STACK_SIZE-1),%rcx
        movq    -7-IKS_SIZE(%rcx),%rsp          /* switch back to PCB stack */
+       movq    $RETURN_TO_USER, %r12
        jmp     _return_from_trap
 
 /*
@@ -752,6 +790,7 @@ ENTRY(thread_syscall_return)
        or      $(KERNEL_STACK_SIZE-1),%rcx
        movq    -7-IKS_SIZE(%rcx),%rsp          /* switch back to PCB stack */
        movq    %rax,R_EAX(%rsp)                /* save return value */
+       movq    $RETURN_TO_USER, %r12
        jmp     _return_from_trap
 
 ENTRY(call_continuation)
@@ -829,8 +868,8 @@ INTERRUPT(255)
 ENTRY(all_intrs)
        PUSH_REGS_ISR                   /* save registers */
        cld                             /* clear direction flag */
-
        PUSH_SEGMENTS_ISR(%rdx)         /* save segment registers */
+       SWAPGS_ENTRY_IF_NEEDED_R12
 
        CPU_NUMBER_NO_GS(%rcx)
        movq    %rsp,%rdx               /* on an interrupt stack? */
@@ -887,6 +926,7 @@ LEXT(return_to_iret)                        /* to find the 
return from calling interrupt) */
        cmpq    $0,CX(EXT(need_ast),%rdx)
        jnz     ast_from_interrupt      /* take it if so */
 1:
+       SWAPGS_EXIT_IF_NEEDED_R12
        POP_SEGMENTS_ISR(%rdx)          /* restore segment regs */
        POP_AREGS_ISR                   /* restore registers */
 
@@ -898,10 +938,10 @@ int_from_intstack:
        jb      stack_overflowed        /* if not: */
        call    EXT(interrupt)          /* call interrupt routine */
 _return_to_iret_i:                     /* ( label for kdb_kintr) */
+       SWAPGS_EXIT_IF_NEEDED_R12
        POP_SEGMENTS_ISR(%rdx)
        POP_AREGS_ISR                   /* restore registers */
                                        /* no ASTs */
-
        iretq
 
 stack_overflowed:
@@ -951,6 +991,7 @@ ast_from_interrupt:
  *             saved SPL
  *             saved IRQ
  *             return address == return_to_iret_i
+ *             saved %r12
  *             saved %r11
  *             saved %r10
  *             saved %r9
@@ -973,6 +1014,7 @@ ast_from_interrupt:
  *             saved %fs
  *             saved %es
  *             saved %ds
+ *             saved %r12
  *             saved %r11
  *             saved %r10
  *             saved %r9
@@ -1168,6 +1210,7 @@ syscall_entry_2:
 
        pusha                           /* save the general registers */
        PUSH_SEGMENTS(%rdx)             /* and the segment registers */
+       SWAPGS_ENTRY_IF_NEEDED_R12
        SET_KERNEL_SEGMENTS(%rdx)       /* switch to kernel data segment */
 
 /*
@@ -1303,6 +1346,7 @@ mach_call_addr:
                                        /* set page-fault trap */
        movq    $(T_PF_USER),R_ERR(%rbx)
                                        /* set error code - read user space */
+       movq    $RETURN_TO_USER, %r12
        jmp     _take_trap              /* treat as a trap */
 
 /*
@@ -1313,6 +1357,7 @@ mach_call_range:
        movq    $(T_INVALID_OPCODE),R_TRAPNO(%rbx)
                                        /* set invalid-operation trap */
        movq    $0,R_ERR(%rbx)          /* clear error code */
+       movq    $RETURN_TO_USER, %r12
        jmp     _take_trap              /* treat as a trap */
 
 /*
@@ -1356,6 +1401,7 @@ syscall_addr:
                                        /* set page-fault trap */
        movq    $(T_PF_USER),R_ERR(%rbx)
                                        /* set error code - read user space */
+       movq    $RETURN_TO_USER, %r12
        jmp     _take_trap              /* treat as a trap */
 END(syscall)
 
@@ -1367,6 +1413,8 @@ END(syscall)
  * the syscall.
  * Note: emulated syscalls seem to not be used anymore in GNU/Hurd, so they
  * are not handled here.
+ * Note: added complication: need gs base to be in kernel mode during execution
+ * to read the active thread twice.  Call swapgs twice, once at start and at 
end.
  * TODO:
      - for now we assume the return address is canonical, but apparently there
        can be cases where it's not (see how Linux handles this). Does it apply
@@ -1375,6 +1423,8 @@ END(syscall)
        iretq from return_from_trap, works fine in all combinations
  */
 ENTRY(syscall64)
+       /* interrupts are already disabled */
+       swapgs
        /* RFLAGS[32:63] are reserved, so combine syscall num (32 bit) and
         * eflags in RAX to allow using r11 as temporary register
         */
@@ -1420,8 +1470,8 @@ ENTRY(syscall64)
        mov     %r11,%rbx               /* prepare for error handling */
        mov     %r10,%rcx               /* fix arg3 location according to C ABI 
*/
 
-       /* switch to kernel stack, then we can enable interrupts */
-       CPU_NUMBER_NO_STACK(%r8b, %r8d, %r8, %r11d, %r11)
+       /* switch to kernel stack then enable interrupts */
+       CPU_NUMBER(%r11d)               /* we can call the fast version here */
        movq    CX(EXT(kernel_stack),%r11),%rsp
        sti
 
@@ -1464,7 +1514,8 @@ _syscall64_call:
 
 _syscall64_check_for_ast:
        /* Check for ast. */
-       CPU_NUMBER_NO_GS(%r11)
+       CPU_NUMBER(%r11d)
+
        cmpl    $0,CX(EXT(need_ast),%r11)
        jz      _syscall64_restore_state
 
@@ -1513,6 +1564,7 @@ _syscall64_restore_state:
        mov     R_R15(%r11),%r15        /* callee-preserved register */
        mov     R_EFLAGS(%r11),%r11     /* sysret convention */
 
+       swapgs
        sysretq         /* fast return to user-space, the thread didn't block */
 
 /* Error handling fragments, from here we jump directly to the trap handler */
@@ -1520,12 +1572,14 @@ _syscall64_addr_push:
        movq    %r11,R_CR2(%rbx)        /* set fault address */
        movq    $(T_PAGE_FAULT),R_TRAPNO(%rbx)  /* set page-fault trap */
        movq    $(T_PF_USER),R_ERR(%rbx) /* set error code - read user space */
+       movq    $RETURN_TO_USER, %r12
        jmp     _take_trap              /* treat as a trap */
 
 _syscall64_range:
        movq    $(T_INVALID_OPCODE),R_TRAPNO(%rbx)
                                        /* set invalid-operation trap */
        movq    $0,R_ERR(%rbx)          /* clear error code */
+       movq    $RETURN_TO_USER, %r12
        jmp     _take_trap              /* treat as a trap */
 
 END(syscall64)
-- 
2.51.0



Reply via email to