In the case where we need a specific number of bytes before a
verdict can be assigned, even if the data spans multiple sendmsg
or sendfile calls, the BPF program may use msg_cork_bytes().

The extreme case is a user calling sendmsg repeatedly with 1-byte
msg segments. Obviously, this is bad for performance but is still
valid. If the BPF program needs N bytes to validate a header, it
can use msg_cork_bytes() to specify N bytes and the BPF program
will not be called again until N bytes have been accumulated.
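
As an illustration (not part of this patch), a verdict program that
wants a complete 4-byte header before deciding might look roughly
like the sketch below. The header size, the magic-byte check, the
program name, and the bpf_helpers.h include are all assumptions of
the example:

#include <linux/bpf.h>
#include "bpf_helpers.h"

#define HDR_LEN 4	/* assumed application header size */

SEC("sk_msg")
int verdict_prog(struct sk_msg_md *msg)
{
	void *data = msg->data;
	void *data_end = msg->data_end;

	/* Not enough bytes for a verdict yet: cork until HDR_LEN bytes
	 * have accumulated, even across 1-byte sendmsg calls.
	 */
	if (data + HDR_LEN > data_end) {
		bpf_msg_cork_bytes(msg, HDR_LEN);
		return SK_PASS;
	}

	/* Full header present: validate the (assumed) magic byte */
	if (*(unsigned char *)data != 0xAA)
		return SK_DROP;

	return SK_PASS;
}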

Signed-off-by: John Fastabend <john.fastab...@gmail.com>
---
 include/linux/filter.h   |    2 
 include/uapi/linux/bpf.h |    3 
 kernel/bpf/sockmap.c     |  334 ++++++++++++++++++++++++++++++++++++++++------
 net/core/filter.c        |   16 ++
 4 files changed, 310 insertions(+), 45 deletions(-)

diff --git a/include/linux/filter.h b/include/linux/filter.h
index 805a566..6058a1b 100644
--- a/include/linux/filter.h
+++ b/include/linux/filter.h
@@ -511,6 +511,8 @@ struct sk_msg_buff {
        void *data;
        void *data_end;
        int apply_bytes;
+       int cork_bytes;
+       int sg_copybreak;
        int sg_start;
        int sg_curr;
        int sg_end;
diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index e50c61f..cfcc002 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -770,7 +770,8 @@ enum bpf_attach_type {
        FN(override_return),            \
        FN(sock_ops_cb_flags_set),      \
        FN(msg_redirect_map),           \
-       FN(msg_apply_bytes),
+       FN(msg_apply_bytes),            \
+       FN(msg_cork_bytes),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
diff --git a/kernel/bpf/sockmap.c b/kernel/bpf/sockmap.c
index 98c6a3b..f637a83 100644
--- a/kernel/bpf/sockmap.c
+++ b/kernel/bpf/sockmap.c
@@ -78,8 +78,10 @@ struct smap_psock {
        /* datapath variables for tx_msg ULP */
        struct sock *sk_redir;
        int apply_bytes;
+       int cork_bytes;
        int sg_size;
        int eval;
+       struct sk_msg_buff *cork;
 
        struct strparser strp;
        struct bpf_prog *bpf_tx_msg;
@@ -140,22 +142,30 @@ static int bpf_tcp_init(struct sock *sk)
        return 0;
 }
 
+static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
+static int free_start_sg(struct sock *sk, struct sk_msg_buff *md);
+
 static void bpf_tcp_release(struct sock *sk)
 {
        struct smap_psock *psock;
 
        rcu_read_lock();
        psock = smap_psock_sk(sk);
+       if (unlikely(!psock))
+               goto out;
 
-       if (likely(psock)) {
-               sk->sk_prot = psock->sk_proto;
-               psock->sk_proto = NULL;
+       if (psock->cork) {
+               free_start_sg(psock->sock, psock->cork);
+               kfree(psock->cork);
+               psock->cork = NULL;
        }
+
+       sk->sk_prot = psock->sk_proto;
+       psock->sk_proto = NULL;
+out:
        rcu_read_unlock();
 }
 
-static void smap_release_sock(struct smap_psock *psock, struct sock *sock);
-
 static void bpf_tcp_close(struct sock *sk, long timeout)
 {
        void (*close_fun)(struct sock *sk, long timeout);
@@ -211,14 +221,25 @@ static int memcopy_from_iter(struct sock *sk,
                             struct iov_iter *from, int bytes)
 {
        struct scatterlist *sg = md->sg_data;
-       int i = md->sg_curr, rc = 0;
+       int i = md->sg_curr, rc = -ENOSPC;
 
        do {
                int copy;
                char *to;
 
-               copy = sg[i].length;
-               to = sg_virt(&sg[i]);
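+               /* sg_copybreak tracks how far into the current sg entry a
+                * previous partial copy got; skip entries that are already
+                * fully consumed.
+                */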
+               if (md->sg_copybreak >= sg[i].length) {
+                       md->sg_copybreak = 0;
+
+                       if (++i == MAX_SKB_FRAGS)
+                               i = 0;
+
+                       if (i == md->sg_end)
+                               break;
+               }
+
+               copy = sg[i].length - md->sg_copybreak;
+               to = sg_virt(&sg[i]) + md->sg_copybreak;
+               md->sg_copybreak += copy;
 
                if (sk->sk_route_caps & NETIF_F_NOCACHE_COPY)
                        rc = copy_from_iter_nocache(to, copy, from);
@@ -234,6 +255,7 @@ static int memcopy_from_iter(struct sock *sk,
                if (!bytes)
                        break;
 
+               md->sg_copybreak = 0;
                if (++i == MAX_SKB_FRAGS)
                        i = 0;
        } while (i != md->sg_end);
@@ -328,6 +350,33 @@ static void return_mem_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
        } while (i != md->sg_end);
 }
 
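+/* Free up to @bytes from the front of the scatterlist: uncharge the
+ * socket memory and drop page references for fully consumed entries.
+ */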
+static void free_bytes_sg(struct sock *sk, int bytes, struct sk_msg_buff *md)
+{
+       struct scatterlist *sg = md->sg_data;
+       int i = md->sg_start, free;
+
+       while (bytes && sg[i].length) {
+               free = sg[i].length;
+               if (bytes < free) {
+                       sg[i].length -= bytes;
+                       sg[i].offset += bytes;
+                       sk_mem_uncharge(sk, bytes);
+                       break;
+               }
+
+               sk_mem_uncharge(sk, sg[i].length);
+               put_page(sg_page(&sg[i]));
+               bytes -= sg[i].length;
+               sg[i].length = 0;
+               sg[i].page_link = 0;
+               sg[i].offset = 0;
+               i++;
+
+               if (i == MAX_SKB_FRAGS)
+                       i = 0;
+       }
+}
+
 static int free_sg(struct sock *sk, int start, struct sk_msg_buff *md)
 {
        struct scatterlist *sg = md->sg_data;
@@ -510,6 +559,9 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
 
        while (msg_data_left(msg)) {
+               bool cork = false, enospc = false;
+               struct sk_msg_buff *m;
+
                if (sk->sk_err) {
                        err = sk->sk_err;
                        goto out_err;
@@ -519,32 +571,76 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
                if (!sk_stream_memory_free(sk))
                        goto wait_for_sndbuf;
 
-               md.sg_curr = md.sg_end;
-               err = sk_alloc_sg(sk, copy, sg,
-                                 md.sg_start, &md.sg_end, &sg_copy,
-                                 md.sg_end);
+               m = psock->cork_bytes ? psock->cork : &md;
+               m->sg_curr = m->sg_copybreak ? m->sg_curr : m->sg_end;
+               err = sk_alloc_sg(sk, copy, m->sg_data,
+                                 m->sg_start, &m->sg_end, &sg_copy,
+                                 m->sg_end - 1);
                if (err) {
                        if (err != -ENOSPC)
                                goto wait_for_memory;
+                       enospc = true;
                        copy = sg_copy;
                }
 
-               err = memcopy_from_iter(sk, &md, &msg->msg_iter, copy);
+               err = memcopy_from_iter(sk, m, &msg->msg_iter, copy);
                if (err < 0) {
-                       free_curr_sg(sk, &md);
+                       free_curr_sg(sk, m);
                        goto out_err;
                }
 
                psock->sg_size += copy;
                copied += copy;
                sg_copy = 0;
+
+               /* When bytes are being corked, skip running the BPF program
+                * and applying the verdict unless there is no more buffer
+                * space. In the ENOSPC case simply run the BPF program with
+                * the currently accumulated data. We don't have much choice
+                * at this point; we could try extending the page frags or
+                * chaining complex frags, but even in these cases
+                * _eventually_ we will hit an OOM scenario. More complex
+                * recovery schemes may be implemented in the future, but BPF
+                * programs must handle the case where apply/cork requests
+                * are not honored. The canonical method to verify this is to
+                * check the data length.
+                */
+               if (psock->cork_bytes) {
+                       if (copy > psock->cork_bytes)
+                               psock->cork_bytes = 0;
+                       else
+                               psock->cork_bytes -= copy;
+
+                       if (psock->cork_bytes && !enospc)
+                               goto out_cork;
+
+                       /* All cork bytes accounted for re-run filter */
+                       psock->eval = __SK_NONE;
+                       psock->cork_bytes = 0;
+               }
 more_data:
                /* If msg is larger than MAX_SKB_FRAGS we can send multiple
                 * scatterlists per msg. However BPF decisions apply to the
                 * entire msg.
                 */
                if (psock->eval == __SK_NONE)
-                       psock->eval = smap_do_tx_msg(sk, psock, &md);
+                       psock->eval = smap_do_tx_msg(sk, psock, m);
+
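+               /* The program asked for more bytes than we have: stash the
+                * in-progress msg in psock->cork and wait for more data.
+                */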
+               if (m->cork_bytes &&
+                   m->cork_bytes > psock->sg_size && !enospc) {
+                       psock->cork_bytes = m->cork_bytes - psock->sg_size;
+                       if (!psock->cork) {
+                               psock->cork = kcalloc(1,
+                                               sizeof(struct sk_msg_buff),
+                                               GFP_ATOMIC | __GFP_NOWARN);
+
+                               if (!psock->cork) {
+                                       err = -ENOMEM;
+                                       goto out_err;
+                               }
+                       }
+                       memcpy(psock->cork, m, sizeof(*m));
+                       goto out_cork;
+               }
 
                send = psock->sg_size;
                if (psock->apply_bytes && psock->apply_bytes < send)
@@ -552,9 +648,9 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
 
                switch (psock->eval) {
                case __SK_PASS:
-                       err = bpf_tcp_push(sk, send, &md, flags, true);
+                       err = bpf_tcp_push(sk, send, m, flags, true);
                        if (unlikely(err)) {
-                               copied -= free_start_sg(sk, &md);
+                               copied -= free_start_sg(sk, m);
                                goto out_err;
                        }
 
@@ -576,13 +672,23 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
                                        psock->apply_bytes -= send;
                        }
 
-                       return_mem_sg(sk, send, &md);
+                       if (psock->cork) {
+                               cork = true;
+                               psock->cork = NULL;
+                       }
+
+                       return_mem_sg(sk, send, m);
                        release_sock(sk);
 
                        err = bpf_tcp_sendmsg_do_redirect(redir, send,
-                                                         &md, flags);
+                                                         m, flags);
                        lock_sock(sk);
 
+                       if (cork) {
+                               free_start_sg(sk, m);
+                               kfree(m);
+                               m = NULL;
+                       }
                        if (unlikely(err)) {
                                copied -= err;
                                goto out_redir;
@@ -592,21 +698,23 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
                        break;
                case __SK_DROP:
                default:
-                       copied -= free_start_sg(sk, &md);
-
+                       free_bytes_sg(sk, send, m);
                        if (psock->apply_bytes) {
                                if (psock->apply_bytes < send)
                                        psock->apply_bytes = 0;
                                else
                                        psock->apply_bytes -= send;
                        }
-                       psock->sg_size -= copied;
+                       copied -= send;
+                       psock->sg_size -= send;
                        err = -EACCES;
                        break;
                }
 
                bpf_md_init(psock);
-               if (sg[md.sg_start].page_link && sg[md.sg_start].length)
+               if (m &&
+                   m->sg_data[m->sg_start].page_link &&
+                   m->sg_data[m->sg_start].length)
                        goto more_data;
                continue;
 wait_for_sndbuf:
@@ -623,6 +731,47 @@ static int bpf_tcp_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
        release_sock(sk);
        smap_release_sock(psock, sk);
        return copied ? copied : err;
+out_cork:
+       release_sock(sk);
+       smap_release_sock(psock, sk);
+       return copied;
+}
+
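+/* Send previously corked scatterlist data using tcp_sendpage_locked(),
+ * one sg entry at a time, consuming up to @send bytes.
+ */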
+static int bpf_tcp_sendpage_sg_locked(struct sock *sk,
+                                     struct sk_msg_buff *m,
+                                     int send,
+                                     int flags)
+{
+       int copied = 0;
+
+       do {
+               struct scatterlist *sg = &m->sg_data[m->sg_start];
+               struct page *p = sg_page(sg);
+               int off = sg->offset;
+               int len = sg->length;
+               int err;
+
+               if (len > send)
+                       len = send;
+
+               err = tcp_sendpage_locked(sk, p, off, len, flags);
+               if (err < 0)
+                       break;
+
+               sg->length -= len;
+               sg->offset += len;
+               copied += len;
+               send -= len;
+               if (!sg->length) {
+                       sg->page_link = 0;
+                       put_page(p);
+                       m->sg_start++;
+                       if (m->sg_start == MAX_SKB_FRAGS)
+                               m->sg_start = 0;
+               }
+       } while (send && m->sg_start != m->sg_end);
+
+       return copied;
 }
 
 static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
@@ -644,7 +793,10 @@ static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
        rcu_read_unlock();
 
        lock_sock(sk);
-       rc = tcp_sendpage_locked(sk, page, offset, size, flags);
+       if (md)
+               rc = bpf_tcp_sendpage_sg_locked(sk, md, size, flags);
+       else
+               rc = tcp_sendpage_locked(sk, page, offset, size, flags);
        release_sock(sk);
 
        smap_release_sock(psock, sk);
@@ -657,10 +809,10 @@ static int bpf_tcp_sendpage_do_redirect(struct sock *sk,
 static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                            int offset, size_t size, int flags)
 {
-       struct sk_msg_buff md = {0};
+       struct sk_msg_buff md = {0}, *m = NULL;
+       bool cork = false, enospc = false;
        struct smap_psock *psock;
-       int send, total = 0, rc = __SK_NONE;
-       int orig_size = size;
+       int send, total = 0, rc;
        struct bpf_prog *prog;
        struct sock *redir;
 
@@ -686,19 +838,90 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
        preempt_enable();
 
        lock_sock(sk);
+
+       psock->sg_size += size;
+       send = size;
+do_cork:
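+       /* A previous program run requested corking: append this page to
+        * the corked scatterlist and defer the verdict until the cork
+        * byte limit is reached or we run out of sg slots.
+        */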
+       if (psock->cork_bytes) {
+               struct scatterlist *sg;
+
+               m = psock->cork;
+               sg = &m->sg_data[m->sg_end];
+               sg_set_page(sg, page, send, offset);
+               get_page(page);
+               sk_mem_charge(sk, send);
+               m->sg_end++;
+               cork = true;
+
+               if (send > psock->cork_bytes)
+                       psock->cork_bytes = 0;
+               else
+                       psock->cork_bytes -= send;
+
+               if (m->sg_end == MAX_SKB_FRAGS)
+                       m->sg_end = 0;
+
+               if (m->sg_end == m->sg_start) {
+                       enospc = true;
+                       psock->cork_bytes = 0;
+               }
+
+               if (!psock->cork_bytes)
+                       psock->eval = __SK_NONE;
+
+               if (!enospc && psock->cork_bytes) {
+                       total = send;
+                       goto out_err;
+               }
+       }
 more_sendpage_data:
        if (psock->eval == __SK_NONE)
                psock->eval = smap_do_tx_msg(sk, psock, &md);
 
+       if (md.cork_bytes && !enospc && md.cork_bytes > psock->sg_size) {
+               psock->cork_bytes = md.cork_bytes;
+               if (!psock->cork) {
+                       psock->cork = kzalloc(sizeof(struct sk_msg_buff),
+                                       GFP_ATOMIC | __GFP_NOWARN);
+
+                       if (!psock->cork) {
+                               psock->sg_size -= size;
+                               total = -ENOMEM;
+                               goto out_err;
+                       }
+               }
+
+               if (!cork) {
+                       send = psock->sg_size;
+                       goto do_cork;
+               }
+       }
+
+       send = psock->sg_size;
        if (psock->apply_bytes && psock->apply_bytes < send)
                send = psock->apply_bytes;
 
-       switch (rc) {
+       switch (psock->eval) {
        case __SK_PASS:
-               rc = tcp_sendpage_locked(sk, page, offset, send, flags);
-               if (rc < 0) {
-                       total = total ? : rc;
-                       goto out_err;
+               /* When data is corked, once the cork bytes limit is reached
+                * we may send more data than the current sendfile call
+                * is expecting. To handle this we have to fix up the return
+                * codes. However, if there is an error there is nothing
+                * to do but continue. We cannot go back in time and
+                * return errors for data we have already consumed.
+                */
+               if (m) {
+                       rc = bpf_tcp_sendpage_sg_locked(sk, m, send, flags);
+                       if (rc < 0) {
+                               total = total ? : rc;
+                               goto out_err;
+                       }
+                       sk_mem_uncharge(sk, rc);
+               } else {
+                       rc = tcp_sendpage_locked(sk, page, offset, send, flags);
+                       if (rc < 0) {
+                               total = total ? : rc;
+                               goto out_err;
+                       }
                }
 
                if (psock->apply_bytes) {
@@ -711,7 +934,7 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                total += rc;
                psock->sg_size -= rc;
                offset += rc;
-               size -= rc;
+               send -= rc;
                break;
        case __SK_REDIRECT:
                redir = psock->sk_redir;
@@ -728,12 +951,30 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                /* sock lock dropped must not dereference psock below */
                rc = bpf_tcp_sendpage_do_redirect(redir,
                                                  page, offset, send,
-                                                 flags, &md);
+                                                 flags, m);
                lock_sock(sk);
-               if (rc > 0) {
-                       offset += rc;
-                       psock->sg_size -= rc;
-                       send -= rc;
+               if (m) {
+                       int free = free_start_sg(sk, m);
+
+                       if (rc > 0) {
+                               sk_mem_uncharge(sk, rc);
+                               free = rc + free;
+                       }
+                       psock->sg_size -= free;
+                       psock->cork_bytes = 0;
+                       send = 0;
+                       if (psock->apply_bytes) {
+                               if (psock->apply_bytes > free)
+                                       psock->apply_bytes -= free;
+                               else
+                                       psock->apply_bytes = 0;
+                       }
+               } else {
+                       if (rc > 0) {
+                               offset += rc;
+                               psock->sg_size -= rc;
+                               send -= rc;
+                       }
                }
 
                if ((total && rc > 0) || (!total && rc < 0))
@@ -741,7 +982,8 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                break;
        case __SK_DROP:
        default:
-               return_mem_sg(sk, send, &md);
+               if (m)
+                       free_bytes_sg(sk, send, m);
                if (psock->apply_bytes) {
                        if (psock->apply_bytes > send)
                                psock->apply_bytes -= send;
@@ -749,18 +991,17 @@ static int bpf_tcp_sendpage(struct sock *sk, struct page *page,
                                psock->apply_bytes -= 0;
                }
                psock->sg_size -= send;
-               size -= send;
-               total += send;
-               rc = -EACCES;
+               total = total ? : -EACCES;
+               goto out_err;
        }
 
        bpf_md_init(psock);
-       if (size)
+       if (psock->sg_size)
                goto more_sendpage_data;
 out_err:
        release_sock(sk);
        smap_release_sock(psock, sk);
-       return total <= orig_size ? total : orig_size;
+       return total <= size ? total : size;
 }
 
 static void bpf_tcp_msg_add(struct smap_psock *psock,
@@ -1077,6 +1318,11 @@ static void smap_gc_work(struct work_struct *w)
        if (psock->bpf_tx_msg)
                bpf_prog_put(psock->bpf_tx_msg);
 
+       if (psock->cork) {
+               free_start_sg(psock->sock, psock->cork);
+               kfree(psock->cork);
+       }
+
        list_for_each_entry_safe(e, tmp, &psock->maps, list) {
                list_del(&e->list);
                kfree(e);
diff --git a/net/core/filter.c b/net/core/filter.c
index df2a8f4..2c73af0 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -1942,6 +1942,20 @@ struct sock *do_msg_redirect_map(struct sk_msg_buff *msg)
        .arg2_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_2(bpf_msg_cork_bytes, struct sk_msg_buff *, msg, u64, bytes)
+{
+       msg->cork_bytes = bytes;
+       return 0;
+}
+
+static const struct bpf_func_proto bpf_msg_cork_bytes_proto = {
+       .func           = bpf_msg_cork_bytes,
+       .gpl_only       = false,
+       .ret_type       = RET_INTEGER,
+       .arg1_type      = ARG_PTR_TO_CTX,
+       .arg2_type      = ARG_ANYTHING,
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
        return task_get_classid(skb);
@@ -3650,6 +3664,8 @@ static const struct bpf_func_proto *sk_msg_func_proto(enum bpf_func_id func_id)
                return &bpf_msg_redirect_map_proto;
        case BPF_FUNC_msg_apply_bytes:
                return &bpf_msg_apply_bytes_proto;
+       case BPF_FUNC_msg_cork_bytes:
+               return &bpf_msg_cork_bytes_proto;
        default:
                return bpf_base_func_proto(func_id);
        }
