By reference-counting how many children of an explored_state are still
 being walked, we can avoid pruning based on a state that's in our own
 history (and thus hasn't reached an exit yet).  Unlike a persistent mark,
 the count falls back to zero once every child has reached an exit, so it
 doesn't prevent other, later branches from being pruned against that
 state when it _has_ reached an exit.
Includes a check at free_states() time (in free_verifier_state()) that
 warns if any reference count has not fallen to zero.

Signed-off-by: Edward Cree <ec...@solarflare.com>
---
 include/linux/bpf_verifier.h |   3 +-
 kernel/bpf/verifier.c        | 109 ++++++++++++++++++++++++++++---------------
 2 files changed, 74 insertions(+), 38 deletions(-)

diff --git a/include/linux/bpf_verifier.h b/include/linux/bpf_verifier.h
index 6abd484391f4..ee034232fbd6 100644
--- a/include/linux/bpf_verifier.h
+++ b/include/linux/bpf_verifier.h
@@ -124,7 +124,8 @@ struct bpf_func_state {
        /* loop detection; points into an explored_state */
        struct bpf_func_state *parent;
        /* These flags are only meaningful in an explored_state, not cur_state */
-       bool in_loop, bounded_loop, conditional;
+       bool bounded_loop, conditional;
+       int live_children;
 
        /* should be second to last. See copy_func_state() */
        int allocated_stack;
diff --git a/kernel/bpf/verifier.c b/kernel/bpf/verifier.c
index 9828cb0cde73..cc8aaf73b4a2 100644
--- a/kernel/bpf/verifier.c
+++ b/kernel/bpf/verifier.c
@@ -398,6 +398,12 @@ static void free_verifier_state(struct bpf_verifier_state *state,
                free_func_state(state->frame[i]);
                state->frame[i] = NULL;
        }
+       /* Check that the live_children accounting is correct */
+       if (state->live_children)
+               pr_warn("Leaked live_children=%d at insn %d, frame %d\n",
+                       state->live_children,
+                       state->frame[state->curframe]->insn_idx,
+                       state->curframe);
        if (free_self)
                kfree(state);
 }
@@ -429,6 +435,7 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
                dst_state->frame[i] = NULL;
        }
        dst_state->curframe = src->curframe;
+       dst_state->parent = src->parent;
 
        for (i = 0; i <= src->curframe; i++) {
                dst = dst_state->frame[i];
@@ -445,6 +452,15 @@ static int copy_verifier_state(struct bpf_verifier_state *dst_state,
        return 0;
 }
 
+/* Mark this thread as having reached an exit */
+static void kill_thread(struct bpf_verifier_state *state)
+{
+       struct bpf_verifier_state *cur = state->parent;
+
+       while (cur && !--cur->live_children)
+               cur = cur->parent;
+}
+
 static int pop_stack(struct bpf_verifier_env *env, int *insn_idx)
 {
        struct bpf_verifier_state *cur = env->cur_state;
@@ -458,6 +474,8 @@ static int pop_stack(struct bpf_verifier_env *env, int *insn_idx)
                err = copy_verifier_state(cur, &head->st);
                if (err)
                        return err;
+       } else {
+               kill_thread(&head->st);
        }
        if (insn_idx)
                *insn_idx = head->insn_idx;
@@ -479,6 +497,7 @@ static int unpush_stack(struct bpf_verifier_env *env)
                return -ENOENT;
 
        elem = head->next;
+       kill_thread(&head->st);
        free_verifier_state(&head->st, false);
        kfree(head);
        env->head = elem;
@@ -509,6 +528,8 @@ static struct bpf_verifier_state *push_stack(struct bpf_verifier_env *env,
                verbose(env, "BPF program is too complex\n");
                goto err;
        }
+       if (elem->st.parent)
+               elem->st.parent->live_children++;
        return &elem->st;
 err:
        free_verifier_state(env->cur_state, true);
@@ -728,11 +749,9 @@ static void init_reg_state(struct bpf_verifier_env *env,
 
 static void init_func_state(struct bpf_verifier_env *env,
                            struct bpf_func_state *state,
-                           struct bpf_func_state *parent,
                            int entry, int frameno, int subprogno)
 {
        state->insn_idx = entry;
-       state->parent = parent;
        state->frameno = frameno;
        state->subprogno = subprogno;
        init_reg_state(env, state);
@@ -2111,7 +2130,6 @@ static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
         * callee can read/write into caller's stack
         */
        init_func_state(env, callee,
-                       caller->parent /* parent state for loop detection */,
                        target /* entry point */,
                        state->curframe + 1 /* frameno within this callchain */,
                        subprog /* subprog number within this prog */);
@@ -4207,14 +4225,20 @@ static int propagate_liveness(struct bpf_verifier_env *env,
        return err;
 }
 
-static struct bpf_func_state *find_loop(struct bpf_verifier_env *env,
-                                       bool *saw_cond, bool *saw_bound)
+static struct bpf_verifier_state *find_loop(struct bpf_verifier_env *env,
+                                           bool *saw_cond, bool *saw_bound,
+                                           int *min_frame)
 {
-       struct bpf_func_state *cur = cur_frame(env);
-       int insn_idx = cur->insn_idx;
+       struct bpf_verifier_state *cur = env->cur_state;
+       int insn_idx = cur_frame(env)->insn_idx;
+
+       if (min_frame)
+               *min_frame = cur->curframe;
 
        while ((cur = cur->parent) != NULL) {
-               if (cur->insn_idx == insn_idx)
+               if (min_frame)
+                       *min_frame = min_t(int, cur->curframe, *min_frame);
+               if (cur->frame[cur->curframe]->insn_idx == insn_idx)
                        return cur;
                if (cur->conditional && saw_cond)
                        *saw_cond = true;
@@ -4273,9 +4297,10 @@ static bool is_conditional_jump(struct bpf_verifier_env *env)
 }
 
 static bool is_loop_bounded(struct bpf_verifier_env *env, int insn_idx,
-                           struct bpf_func_state *old)
+                           struct bpf_verifier_state *vold)
 {
        struct bpf_insn *insn = env->prog->insnsi + env->prev_insn_idx;
+       struct bpf_func_state *old = vold->frame[vold->curframe];
        struct bpf_func_state *new = cur_frame(env);
        struct bpf_reg_state *oldreg, *newreg;
        u8 opcode = BPF_OP(insn->code);
@@ -4339,11 +4364,10 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 {
        struct bpf_verifier_state_list *new_sl;
        struct bpf_verifier_state_list *sl;
-       struct bpf_verifier_state *cur = env->cur_state, *new;
+       struct bpf_verifier_state *cur = env->cur_state, *new, *old;
        bool saw_cond = false, saw_bound = false;
        bool cond = false, bounded = false;
-       struct bpf_func_state *old;
-       int i, j, err;
+       int i, j, err, min_frame;
 
        sl = env->explored_states[insn_idx];
        if (!sl)
@@ -4354,25 +4378,31 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
 
        cond = is_conditional_jump(env);
        /* Check our parentage chain: have we looped? */
-       old = find_loop(env, &saw_cond, &saw_bound);
-       if (old != NULL) {
-               if (old->frameno != cur->curframe) {
-                       /* if it's in our parentage chain, then it called us;
-                        * but we're the same insn, so in the same subprog, so
-                        * recursion has occurred.
-                        * The loop detection could handle recursion fine (it
-                        * distinguishes between bounded and unbounded
-                        * recursion, and the latter would quickly run out of
-                        * call stack anyway), but the stack max depth
-                        * calculation can't deal with it (because it doesn't
-                        * know how deeply we might recurse).
+       old = find_loop(env, &saw_cond, &saw_bound, &min_frame);
+       /* If old->curframe != min_frame, then there is a return (BPF_EXIT) from
+        * old's frame somewhere in the "loop", so it's not a real loop, just
+        * two calls to the same function.  (Those calls might come from a loop
+        * in the outer frame, but we'll deal with that when we walk the outer
+        * frame.)
+        */
+       if (old != NULL && old->curframe == min_frame) {
+               if (old->curframe != cur->curframe) {
+                       /* since it's in our parentage chain and its
+                        * frame is the minimum in the loop body, it
+                        * called us; but we're the same insn, so in the
+                        * same subprog, so recursion has occurred.
+                        * The loop detection could handle recursion
+                        * fine (it distinguishes between bounded and
+                        * unbounded recursion, and the latter would
+                        * quickly run out of call stack anyway), but
+                        * the stack max depth calculation can't deal
+                        * with it (because it doesn't know how deeply
+                        * we might recurse).
                         */
                        verbose(env, "recursive call from insn %d to %d\n",
                                env->prev_insn_idx, insn_idx);
                        return -EINVAL;
                }
-               /* Mark old state as not prunable */
-               old->in_loop = true;
                if (cond)
                        bounded = is_loop_bounded(env, insn_idx, old);
                if (bounded) {
@@ -4394,11 +4424,12 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                        verbose(env, "following loop from insn %d to %d, 
bounded elsewhere\n",
                                env->prev_insn_idx, insn_idx);
                } else if (saw_cond && !cond) {
-                       /* We're not a conditional, but there's a conditional
-                        * somewhere else in the loop.  So they will be
-                        * responsible for ensuring the loop is bounded (it's
-                        * possible we've been revisited but they haven't, which
-                        * is why they might not have bounded_loop set).
+                       /* We're not a conditional, but there's a
+                        * conditional somewhere else in the loop.  So
+                        * they will be responsible for ensuring the
+                        * loop is bounded (it's possible we've been
+                        * revisited but they haven't, which is why they
+                        * might not have bounded_loop set).
                         */
                        verbose(env, "following loop from insn %d to %d for 
now, condition is elsewhere\n",
                                env->prev_insn_idx, insn_idx);
@@ -4410,7 +4441,7 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
        }
 
        while (sl != STATE_LIST_MARK) {
-               if (!sl->state.frame[sl->state.curframe]->in_loop &&
+               if (!sl->state.live_children &&
                    states_equal(env, &sl->state, cur)) {
                        /* reached equivalent register/stack state,
                         * prune the search.
@@ -4450,11 +4481,13 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                kfree(new_sl);
                return err;
        }
-       new->frame[new->curframe]->conditional = cond;
-       new->frame[new->curframe]->bounded_loop = bounded;
+       new->conditional = cond;
+       new->bounded_loop = bounded;
        new_sl->next = env->explored_states[insn_idx];
        env->explored_states[insn_idx] = new_sl;
-       /* connect new state's regs to parentage chain */
+       /* connect new state and its regs to parentage chain */
+       cur->parent = new;
+       new->live_children = 1;
        for (i = 0; i < BPF_REG_FP; i++)
                cur_regs(env)[i].parent = &new->frame[new->curframe]->regs[i];
        /* clear write marks in current state: the writes we did are not writes
@@ -4471,7 +4504,6 @@ static int is_state_visited(struct bpf_verifier_env *env, int insn_idx)
                struct bpf_func_state *frame = cur->frame[j];
                struct bpf_func_state *newframe = new->frame[j];
 
-               frame->parent = newframe;
                for (i = 0; i < frame->allocated_stack / BPF_REG_SIZE; i++) {
                        frame->stack[i].spilled_ptr.live = REG_LIVE_NONE;
                        frame->stack[i].spilled_ptr.parent =
@@ -4501,6 +4533,7 @@ static int do_check(struct bpf_verifier_env *env)
                kfree(state);
                return -ENOMEM;
        }
+       state->parent = NULL;
        env->cur_state = state;
        mainprogno = add_subprog(env, 0);
        if (mainprogno < 0)
@@ -4508,7 +4541,6 @@ static int do_check(struct bpf_verifier_env *env)
        insn_idx = 0;
        env->prev_insn_idx = -1;
        init_func_state(env, state->frame[0],
-                       NULL /* parent state for loop detection */,
                        insn_idx /* entry point */,
                        0 /* frameno */,
                        mainprogno /* subprogno */);
@@ -4781,6 +4813,7 @@ static int do_check(struct bpf_verifier_env *env)
                                if (err)
                                        return err;
 process_bpf_exit:
+                               kill_thread(env->cur_state);
                                err = pop_stack(env, &insn_idx);
                                /* We are following a path that was pushed to
                                 * the stack, thus was a jump-taken path.
@@ -5719,6 +5752,8 @@ int bpf_check(struct bpf_prog **prog, union bpf_attr *attr)
 
        ret = do_check(env);
        if (env->cur_state) {
+               if (ret)
+                       kill_thread(env->cur_state);
                free_verifier_state(env->cur_state, true);
                env->cur_state = NULL;
        }

Reply via email to