Insert KASAN shadow memory checks before memory load and store
operations in JIT-compiled BPF programs. This helps detect memory safety
bugs such as use-after-free and out-of-bounds accesses at runtime.

The targeted instructions are BPF_ST, BPF_STX (including BPF_ATOMIC)
and BPF_LDX, but not all of them are instrumented:
- if the load/store instruction only accesses the program stack,
  emit_kasan_check() silently skips the instrumentation, as stack
  accesses are already monitored by page guards.
- if the load/store instruction is a BPF_PROBE_MEM or BPF_PROBE_ATOMIC
  instruction, it is not instrumented: the accessed address may fault
  (hence the dedicated fault handling for BPF_PROBE_* instructions),
  so the corresponding KASAN shadow check could fault as well.

Signed-off-by: Alexis Lothoré (eBPF Foundation) <[email protected]>
---
Changes in v2:
- support BPF_ATOMIC instructions
- support BPF_ST
- make sure to always pass the correct instruction to the KASAN check
---
 arch/x86/net/bpf_jit_comp.c | 63 ++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 53 insertions(+), 10 deletions(-)

diff --git a/arch/x86/net/bpf_jit_comp.c b/arch/x86/net/bpf_jit_comp.c
index 943a0f315cf2..cb3c03edc4bd 100644
--- a/arch/x86/net/bpf_jit_comp.c
+++ b/arch/x86/net/bpf_jit_comp.c
@@ -1516,17 +1516,30 @@ static int emit_atomic_rmw_index(u8 **pprog, u32 
atomic_op, u32 size,
        return 0;
 }
 
-static int emit_atomic_ld_st(u8 **pprog, u32 atomic_op, u32 dst_reg,
-                            u32 src_reg, s16 off, u8 bpf_size)
+static int emit_atomic_ld_st(u8 **pprog, struct bpf_insn *insn, u8 *ip,
+                            u32 dst_reg, u32 src_reg, bool accesses_stack_only)
 {
+       u32 atomic_op = insn->imm;
+       int err;
+
        switch (atomic_op) {
        case BPF_LOAD_ACQ:
+               err = emit_kasan_check(pprog, src_reg, insn, ip, false,
+                                      accesses_stack_only);
+               if (err)
+                       return err;
                /* dst_reg = smp_load_acquire(src_reg + off16) */
-               emit_ldx(pprog, bpf_size, dst_reg, src_reg, off);
+               emit_ldx(pprog, BPF_SIZE(insn->code), dst_reg, src_reg,
+                        insn->off);
                break;
        case BPF_STORE_REL:
+               err = emit_kasan_check(pprog, dst_reg, insn, ip, true,
+                                      accesses_stack_only);
+               if (err)
+                       return err;
                /* smp_store_release(dst_reg + off16, src_reg) */
-               emit_stx(pprog, bpf_size, dst_reg, src_reg, off);
+               emit_stx(pprog, BPF_SIZE(insn->code), dst_reg, src_reg,
+                        insn->off);
                break;
        default:
                pr_err("bpf_jit: unknown atomic load/store opcode %02x\n",
@@ -1904,6 +1917,7 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                const s32 imm32 = insn->imm;
                u32 dst_reg = insn->dst_reg;
                u32 src_reg = insn->src_reg;
+               bool accesses_stack_only;
                u8 b2 = 0, b3 = 0;
                u8 *start_of_ldx;
                s64 jmp_offset;
@@ -1924,6 +1938,8 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        EMIT_ENDBR();
 
                ip = image + addrs[i - 1] + (prog - temp);
+               accesses_stack_only =
+                       bpf_insn_accesses_stack_only(env, bpf_prog, i - 1);
 
                switch (insn->code) {
                        /* ALU */
@@ -2304,6 +2320,10 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                case BPF_ST | BPF_MEM | BPF_H:
                case BPF_ST | BPF_MEM | BPF_W:
                case BPF_ST | BPF_MEM | BPF_DW:
+                       err = emit_kasan_check(&prog, dst_reg, insn, ip, true,
+                                              accesses_stack_only);
+                       if (err)
+                               return err;
                        switch (BPF_SIZE(insn->code)) {
                        case BPF_B:
                                if (is_ereg(dst_reg))
@@ -2369,6 +2389,10 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                insn_off = outgoing_arg_base - outgoing_rsp - 
insn_off - 16;
                                dst_reg = BPF_REG_FP;
                        }
+                       err = emit_kasan_check(&prog, dst_reg, insn, ip, true,
+                                              accesses_stack_only);
+                       if (err)
+                               return err;
                        emit_stx(&prog, BPF_SIZE(insn->code), dst_reg, src_reg, 
insn_off);
                        break;
 
@@ -2530,6 +2554,12 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                /* populate jmp_offset for JAE above to jump to 
start_of_ldx */
                                start_of_ldx = prog;
                                end_of_jmp[-1] = start_of_ldx - end_of_jmp;
+                       } else {
+                               err = emit_kasan_check(&prog, src_reg, insn, ip,
+                                                      false,
+                                                      accesses_stack_only);
+                               if (err)
+                                       return err;
                        }
                        if (BPF_MODE(insn->code) == BPF_PROBE_MEMSX ||
                            BPF_MODE(insn->code) == BPF_MEMSX)
@@ -2592,13 +2622,13 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        fallthrough;
                case BPF_STX | BPF_ATOMIC | BPF_W:
                case BPF_STX | BPF_ATOMIC | BPF_DW:
+                       bool is64 = BPF_SIZE(insn->code) == BPF_DW;
+                       u32 real_src_reg = src_reg;
+                       u32 real_dst_reg = dst_reg;
+                       u8 *branch_target;
                        if (insn->imm == (BPF_AND | BPF_FETCH) ||
                            insn->imm == (BPF_OR | BPF_FETCH) ||
                            insn->imm == (BPF_XOR | BPF_FETCH)) {
-                               bool is64 = BPF_SIZE(insn->code) == BPF_DW;
-                               u32 real_src_reg = src_reg;
-                               u32 real_dst_reg = dst_reg;
-                               u8 *branch_target;
 
                                /*
                                 * Can't be implemented with a single x86 insn.
@@ -2612,7 +2642,19 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                                if (dst_reg == BPF_REG_0)
                                        real_dst_reg = BPF_REG_AX;
 
+                               ip += 3;
+                       }
+                       if (!bpf_atomic_is_load_store(insn)) {
+                               err = emit_kasan_check(&prog, real_dst_reg,
+                                                      insn, ip, false,
+                                                      accesses_stack_only);
+                               if (err)
+                                       return err;
                                branch_target = prog;
+                       }
+                       if (insn->imm == (BPF_AND | BPF_FETCH) ||
+                           insn->imm == (BPF_OR | BPF_FETCH) ||
+                           insn->imm == (BPF_XOR | BPF_FETCH)) {
                                /* Load old value */
                                emit_ldx(&prog, BPF_SIZE(insn->code),
                                         BPF_REG_0, real_dst_reg, insn->off);
@@ -2644,8 +2686,9 @@ static int do_jit(struct bpf_verifier_env *env, struct 
bpf_prog *bpf_prog, int *
                        }
 
                        if (bpf_atomic_is_load_store(insn))
-                               err = emit_atomic_ld_st(&prog, insn->imm, 
dst_reg, src_reg,
-                                                       insn->off, 
BPF_SIZE(insn->code));
+                               err = emit_atomic_ld_st(&prog, insn, ip,
+                                                       dst_reg, src_reg,
+                                                       accesses_stack_only);
                        else
                                err = emit_atomic_rmw(&prog, insn->imm, 
dst_reg, src_reg,
                                                      insn->off, 
BPF_SIZE(insn->code));

-- 
2.54.0


Reply via email to