Add the emit_kasan_check() function, which emits KASAN shadow memory
checks before memory accesses in JIT-compiled BPF programs. The
implementation relies on the existing __asan_{load,store}X functions
from the KASAN subsystem. The helper:
- ensures that the KASAN instrumentation is actually needed: if the
  instruction being processed accesses the program stack, we skip the
  instrumentation, as those accesses are already protected with page
  guards
- saves registers. This includes caller-saved registers, but also
  temporary registers, as those may hold values still in use by the
  affected program
- computes the accessed address and stores it in %rdi
- calls the relevant function, depending on the instruction being a load
  or a store, and the size of the access.
- restores registers

The special care needed when inserting this instrumentation comes at the
cost of a non-negligible increase in JITed code size. For example, a
bare

  mov   0x0(%rsi),%rbx # Load into rbx the content at the address stored in rsi

becomes

  push    %rax
  push    %rcx
  push    %rdx
  push    %rsi
  push    %rdi
  push    %r8
  push    %r9
  mov     %rsi,%rdi
  call    0xffffffff81da0a60 <__asan_load8>
  pop     %r9
  pop     %r8
  pop     %rdi
  pop     %rsi
  pop     %rdx
  pop     %rcx
  pop     %rax
  mov     0x0(%rsi),%rbx

Signed-off-by: Alexis Lothoré (eBPF Foundation) <[email protected]>
---
Changes in v2:
- move the asan function declarations directly into the JIT compiler,
  and guard them with IS_ENABLED
- remove faulty stack alignment: no argument is passed to kasan funcs
  on the stack anyway
- make sure to emit call depth accounting code
- do not save unneeded registers
- update helper signature to let the caller configure some values (e.g.
  is_write)

Signed-off-by: Alexis Lothoré (eBPF Foundation) <[email protected]>
---
 arch/x86/net/bpf_jit_comp.c | 93 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index a0c541a441cf..0981791014eb 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -21,6 +21,19 @@
 #include <asm/unwind.h>
 #include <asm/cfi.h>
 
+#if IS_ENABLED(CONFIG_BPF_JIT_KASAN)
+/* KASAN access-check entry points targeted by emit_kasan_check() */
+void __asan_load1(void *p);
+void __asan_load2(void *p);
+void __asan_load4(void *p);
+void __asan_load8(void *p);
+void __asan_load16(void *p);
+void __asan_store1(void *p);
+void __asan_store2(void *p);
+void __asan_store4(void *p);
+void __asan_store8(void *p);
+void __asan_store16(void *p);
+#endif /* CONFIG_BPF_JIT_KASAN */
+
 static bool all_callee_regs_used[4] = {true, true, true, true};
 
 static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
@@ -1330,6 +1343,86 @@ static void emit_store_stack_imm64(u8 **pprog, int reg, int stack_off, u64 imm64
        emit_stx(pprog, BPF_DW, BPF_REG_FP, reg, stack_off);
 }
 
+/*
+ * emit_kasan_check - emit a KASAN check ahead of a BPF memory access
+ * @pprog: JIT output cursor; advanced past the emitted code on success
+ * @addr_reg: BPF register holding the base address of the access
+ * @insn: instruction being instrumented; its BPF_SIZE selects the
+ *        check function and its off is added to @addr_reg
+ * @ip: image address corresponding to *@pprog on entry, used to
+ *      compute the rel32 target of the emitted call
+ * @is_write: true selects __asan_storeX, false selects __asan_loadX
+ * @accesses_stack_only: when true, nothing is emitted
+ *
+ * The emitted sequence pushes the caller-saved registers, computes
+ * addr_reg + insn->off into %rdi (BPF_REG_1, the first argument
+ * register), calls the matching __asan_{load,store}{1,2,4,8} helper and
+ * pops the registers back. When CONFIG_BPF_JIT_KASAN is disabled this
+ * emits nothing and returns 0.
+ *
+ * Return: 0 on success or when nothing is emitted, -EINVAL for an
+ * unsupported access size, -ERANGE if emit_call() fails.
+ */
+static int emit_kasan_check(u8 **pprog, u32 addr_reg, struct bpf_insn *insn,
+                           u8 *ip, bool is_write, bool accesses_stack_only)
+{
+#if IS_ENABLED(CONFIG_BPF_JIT_KASAN)
+       u32 bpf_size = BPF_SIZE(insn->code);
+       s32 off = insn->off;
+       u8 *prog = *pprog;
+       void *kasan_func;
+
+       if (accesses_stack_only)
+               return 0;
+
+       /* Derive KASAN check function from access type and size */
+       switch (bpf_size) {
+       case BPF_B:
+               kasan_func = is_write ? __asan_store1 : __asan_load1;
+               break;
+       case BPF_H:
+               kasan_func = is_write ? __asan_store2 : __asan_load2;
+               break;
+       case BPF_W:
+               kasan_func = is_write ? __asan_store4 : __asan_load4;
+               break;
+       case BPF_DW:
+               kasan_func = is_write ? __asan_store8 : __asan_load8;
+               break;
+       default:
+               return -EINVAL;
+       }
+
+       /*
+        * kasan_func is an ordinary C function, so it may clobber every
+        * caller-saved register: rax, rcx, rdx, rsi, rdi and r8-r11.
+        * r10 and r11 must be preserved too: the JIT maps BPF_REG_AX and
+        * AUX_REG onto them (see reg2hex[]), and constant blinding turns
+        * BPF_ST stores into BPF_STX with BPF_REG_AX as the source, so
+        * the value about to be stored can live in r10 at this point.
+        */
+       EMIT1(0x50);                    /* push %rax */
+       EMIT1(0x51);                    /* push %rcx */
+       EMIT1(0x52);                    /* push %rdx */
+       EMIT1(0x56);                    /* push %rsi */
+       EMIT1(0x57);                    /* push %rdi */
+       EMIT2(0x41, 0x50);              /* push %r8 */
+       EMIT2(0x41, 0x51);              /* push %r9 */
+       EMIT2(0x41, 0x52);              /* push %r10 */
+       EMIT2(0x41, 0x53);              /* push %r11 */
+
+       /* mov rdi, addr_reg (BPF_REG_1 is %rdi, the first argument) */
+       EMIT_mov(BPF_REG_1, addr_reg);
+
+       /* add rdi, off (if offset is non-zero) */
+       if (off) {
+               if (is_imm8(off)) {
+                       /* add rdi, imm8 */
+                       EMIT4(0x48, 0x83, 0xC7, (u8)off);
+               } else {
+                       /* add rdi, imm32 */
+                       EMIT3_off32(0x48, 0x81, 0xC7, off);
+               }
+       }
+
+       /* Adjust ip to account for the instrumentation generated so far */
+       ip += (prog - *pprog);
+       /* We emit a call, so update call depth counting */
+       ip += x86_call_depth_emit_accounting(&prog, kasan_func, ip);
+       /* call kasan_func */
+       if (emit_call(&prog, kasan_func, ip))
+               return -ERANGE;
+
+       /* Restore in exact reverse order of the pushes above */
+       EMIT2(0x41, 0x5B);              /* pop %r11 */
+       EMIT2(0x41, 0x5A);              /* pop %r10 */
+       EMIT2(0x41, 0x59);              /* pop %r9 */
+       EMIT2(0x41, 0x58);              /* pop %r8 */
+       EMIT1(0x5F);                    /* pop %rdi */
+       EMIT1(0x5E);                    /* pop %rsi */
+       EMIT1(0x5A);                    /* pop %rdx */
+       EMIT1(0x59);                    /* pop %rcx */
+       EMIT1(0x58);                    /* pop %rax */
+
+       *pprog = prog;
+#endif /* CONFIG_BPF_JIT_KASAN */
+       return 0;
+}
+
 static int emit_atomic_rmw(u8 **pprog, u32 atomic_op,
                           u32 dst_reg, u32 src_reg, s16 off, u8 bpf_size)
 {

-- 
2.54.0


Reply via email to