Add the emit_kasan_check() function that emits KASAN shadow memory
checks before memory accesses in JIT-compiled BPF programs. The
implementation relies on the existing __asan_{load,store}X functions
from the KASAN subsystem. The helper:
- ensures that the KASAN instrumentation is actually needed: if the
instruction being processed only accesses the program stack, the
instrumentation is skipped, as those accesses are already protected
by guard pages
- saves registers. This includes the caller-saved registers, as well as
temporary registers, since those may have been used by the
affected program. Theoretically, r10 and r11 should be saved as well,
but since the number of called functions and their scope are limited,
they are skipped to reduce the overhead
- computes the accessed address and stores it in %rdi
- calls the relevant function, depending on the instruction being a load
or a store, and the size of the access.
- restores registers
The special care needed when inserting this instrumentation comes at the
cost of a non-negligible increase in JITed code size. For example, a
bare
mov 0x0(%rsi),%rbx # Load into rbx the content at the address stored in rsi
becomes
push %rax
push %rcx
push %rdx
push %rsi
push %rdi
push %r8
push %r9
mov %rsi,%rdi
call 0xffffffff81da0a60 <__asan_load8>
pop %r9
pop %r8
pop %rdi
pop %rsi
pop %rdx
pop %rcx
pop %rax
mov 0x0(%rsi),%rbx
Signed-off-by: Alexis Lothoré (eBPF Foundation) <[email protected]>
---
Changes in v3:
- skip kasan instrumentation if there is no verifier env (cBPF)
- move helper up in the file
Changes in v2:
- move asan functions declaration directly into jit compiler, and guard
them with IS_ENABLED
- remove faulty stack alignment, no arg is passed to kasan funcs on the
stack anyway
- make sure to emit call depth accounting code
- do not save unneeded registers
- update helper signature to let caller configure some values (eg:
is_write)
---
arch/x86/net/bpf_jit_comp.c | 95 +++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 95 insertions(+)
diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 054e043ffcd2..68c5f9f94e5e 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -21,6 +21,17 @@
#include <asm/unwind.h>
#include <asm/cfi.h>
+#if IS_ENABLED(CONFIG_BPF_JIT_KASAN)
+/*
+ * KASAN shadow-check entry points that the JIT emits calls to, one per
+ * access size and direction. They are declared locally in the JIT
+ * compiler rather than pulled in from a header.
+ */
+void __asan_load1(void *p);
+void __asan_load2(void *p);
+void __asan_load4(void *p);
+void __asan_load8(void *p);
+void __asan_store1(void *p);
+void __asan_store2(void *p);
+void __asan_store4(void *p);
+void __asan_store8(void *p);
+#endif /* CONFIG_BPF_JIT_KASAN */
+
static bool all_callee_regs_used[4] = {true, true, true, true};
static u8 *emit_code(u8 *ptr, u32 bytes, unsigned int len)
@@ -1110,6 +1121,90 @@ static void maybe_emit_1mod(u8 **pprog, u32 reg, bool is64)
*pprog = prog;
}
+/*
+ * emit_kasan_check - emit a KASAN shadow check before a BPF memory access
+ * @env: verifier environment; NULL for cBPF programs, which are not
+ *       instrumented
+ * @pprog: JIT output cursor, advanced past the emitted code on success
+ * @addr_reg: BPF register holding the base address of the access
+ * @insn: load/store instruction being instrumented; its size and offset
+ *        select the check function and the checked address
+ * @ip: image address corresponding to *@pprog, used to compute the call
+ *      displacement
+ * @is_write: call __asan_storeX when true, __asan_loadX otherwise
+ * @accesses_stack_only: the access only touches the program stack, which is
+ *                       already protected, so no check is emitted
+ *
+ * The emitted sequence pushes every caller-saved register (rax, rcx, rdx,
+ * rsi, rdi, r8-r11), loads addr_reg + off into rdi, calls the check
+ * function and pops the registers back, so the surrounding JITed code sees
+ * no register change.
+ *
+ * Return: 0 on success or when no check is needed, -EINVAL for an unknown
+ * access size, -ERANGE if the check function cannot be reached by a call.
+ */
+static int emit_kasan_check(struct bpf_verifier_env *env, u8 **pprog,
+			    u32 addr_reg, struct bpf_insn *insn, u8 *ip,
+			    bool is_write, bool accesses_stack_only)
+{
+#if IS_ENABLED(CONFIG_BPF_JIT_KASAN)
+	u32 bpf_size = BPF_SIZE(insn->code);
+	s32 off = insn->off;
+	u8 *prog = *pprog;
+	void *kasan_func;
+
+	/* cBPF (no verifier env) and stack-only accesses are not checked */
+	if (!env || accesses_stack_only)
+		return 0;
+
+	/* Derive KASAN check function from access type and size */
+	switch (bpf_size) {
+	case BPF_B:
+		kasan_func = is_write ? __asan_store1 : __asan_load1;
+		break;
+	case BPF_H:
+		kasan_func = is_write ? __asan_store2 : __asan_load2;
+		break;
+	case BPF_W:
+		kasan_func = is_write ? __asan_store4 : __asan_load4;
+		break;
+	case BPF_DW:
+		kasan_func = is_write ? __asan_store8 : __asan_load8;
+		break;
+	default:
+		return -EINVAL;
+	}
+
+	/*
+	 * Save all caller-saved registers, since the C check function may
+	 * clobber any of them. This includes r10 and r11: they are not BPF
+	 * program registers, but the JIT uses them as temporaries (r10 is
+	 * BPF_REG_AX), and they may hold a value the instrumented access
+	 * still needs, e.g. the source operand of a store.
+	 */
+	EMIT1(0x50);		/* push rax */
+	EMIT1(0x51);		/* push rcx */
+	EMIT1(0x52);		/* push rdx */
+	EMIT1(0x56);		/* push rsi */
+	EMIT1(0x57);		/* push rdi */
+	EMIT2(0x41, 0x50);	/* push r8 */
+	EMIT2(0x41, 0x51);	/* push r9 */
+	EMIT2(0x41, 0x52);	/* push r10 */
+	EMIT2(0x41, 0x53);	/* push r11 */
+
+	/* mov rdi, addr_reg (BPF_REG_1 is rdi) */
+	EMIT_mov(BPF_REG_1, addr_reg);
+
+	/* add rdi, off (only if offset is non-zero) */
+	if (off) {
+		if (is_imm8(off)) {
+			/* add rdi, imm8 */
+			EMIT4(0x48, 0x83, 0xC7, (u8)off);
+		} else {
+			/* add rdi, imm32 */
+			EMIT3_off32(0x48, 0x81, 0xC7, off);
+		}
+	}
+
+	/* Adjust ip to account for the instrumentation generated so far */
+	ip += prog - *pprog;
+	/* We emit a call, so update call depth counting */
+	ip += x86_call_depth_emit_accounting(&prog, kasan_func, ip);
+	/* call kasan_func */
+	if (emit_call(&prog, kasan_func, ip))
+		return -ERANGE;
+
+	/* Restore registers in reverse push order */
+	EMIT2(0x41, 0x5B);	/* pop r11 */
+	EMIT2(0x41, 0x5A);	/* pop r10 */
+	EMIT2(0x41, 0x59);	/* pop r9 */
+	EMIT2(0x41, 0x58);	/* pop r8 */
+	EMIT1(0x5F);		/* pop rdi */
+	EMIT1(0x5E);		/* pop rsi */
+	EMIT1(0x5A);		/* pop rdx */
+	EMIT1(0x59);		/* pop rcx */
+	EMIT1(0x58);		/* pop rax */
+
+	*pprog = prog;
+#endif /* CONFIG_BPF_JIT_KASAN */
+	return 0;
+}
+
/* LDX: dst_reg = *(u8*)(src_reg + off) */
static void emit_ldx(u8 **pprog, u32 size, u32 dst_reg, u32 src_reg, int off)
{
--
2.54.0