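This adds a BPF SK_MSG program helper, bpf_msg_pop_data(), so that we
can pop data from a msg. We use this to pop metadata from a previous push
data call. The tcp_bpf and tls_sw send paths now also subtract any bytes
popped by the verdict program from the copied size on SK_DROP, so the
user does not see a positive return code.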

Signed-off-by: John Fastabend <john.fastab...@gmail.com>
---
 include/uapi/linux/bpf.h |  16 ++++-
 net/core/filter.c        | 171 +++++++++++++++++++++++++++++++++++++++++++++++
 net/ipv4/tcp_bpf.c       |  17 ++++-
 net/tls/tls_sw.c         |  11 ++-
 4 files changed, 209 insertions(+), 6 deletions(-)

diff --git a/include/uapi/linux/bpf.h b/include/uapi/linux/bpf.h
index 23e2031..597afdb 100644
--- a/include/uapi/linux/bpf.h
+++ b/include/uapi/linux/bpf.h
@@ -2268,6 +2268,19 @@ union bpf_attr {
  *
  *     Return
  *             0 on success, or a negative error in case of failure.
+ *
+ * int bpf_msg_pop_data(struct sk_msg_buff *msg, u32 start, u32 pop, u64 flags)
+ *      Description
+ *             Will remove *pop* bytes from a *msg* starting at byte *start*.
+ *             This may result in **ENOMEM** errors under certain situations if
+ *             an allocation and copy are required due to a full ring buffer.
+ *             However, the helper will try to avoid doing the allocation
+ *             if possible. Other errors can occur if input parameters are
+ *             invalid, either due to *start* byte not being a valid part of
+ *             the msg payload and/or the *pop* value being too large.
+ *
+ *     Return
+ *             0 on success, or a negative error in case of failure.
  */
 #define __BPF_FUNC_MAPPER(FN)          \
        FN(unspec),                     \
@@ -2360,7 +2373,8 @@ union bpf_attr {
        FN(map_push_elem),              \
        FN(map_pop_elem),               \
        FN(map_peek_elem),              \
-       FN(msg_push_data),
+       FN(msg_push_data),              \
+       FN(msg_pop_data),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
diff --git a/net/core/filter.c b/net/core/filter.c
index f50ea97..bd0df75 100644
--- a/net/core/filter.c
+++ b/net/core/filter.c
@@ -2425,6 +2425,174 @@ static const struct bpf_func_proto 
bpf_msg_push_data_proto = {
        .arg4_type      = ARG_ANYTHING,
 };
 
+/* Drop slot i: shift later elements back one slot and shrink the ring. */
+static void sk_msg_shift_left(struct sk_msg *msg, int i)
+{
+       int next = i;
+
+       do {
+               sk_msg_iter_var_next(next);
+               msg->sg.data[i] = msg->sg.data[next];
+               i = next;
+       } while (next != msg->sg.end);
+       sk_msg_iter_prev(msg, end);
+}
+
+/* Grow the ring by one, moving slots from i onward one slot right. */
+static void sk_msg_shift_right(struct sk_msg *msg, int i)
+{
+       struct scatterlist carry, saved;
+
+       sk_msg_iter_next(msg, end);
+       carry = sk_msg_elem_cpy(msg, i);
+       sk_msg_iter_var_next(i);
+       saved = sk_msg_elem_cpy(msg, i);
+       while (i != msg->sg.end) {
+               msg->sg.data[i] = carry;
+               carry = saved;
+               sk_msg_iter_var_next(i);
+               saved = sk_msg_elem_cpy(msg, i);
+       }
+}
+
+BPF_CALL_4(bpf_msg_pop_data, struct sk_msg *, msg, u32, start,
+          u32, len, u64, flags)
+{
+       u32 i = 0, l = 0, space, offset = 0;
+       u64 last = (u64)start + len;
+       int pop;
+
+       if (unlikely(flags))
+               return -EINVAL;
+
+       /* Nothing to pop; don't split or copy an element for it below */
+       if (unlikely(len == 0))
+               return 0;
+
+       /* Find element i holding byte 'start'; offset is where it begins */
+       i = msg->sg.start;
+       do {
+               offset += l;
+               l = sk_msg_elem(msg, i)->length;
+
+               if (start < offset + l)
+                       break;
+               sk_msg_iter_var_next(i);
+       } while (i != msg->sg.end);
+
+       /* Bounds: start must be inside msg, pop may run up to its end */
+       if (start >= offset + l || last > msg->sg.size)
+               return -EINVAL;
+
+       space = MAX_MSG_FRAGS - sk_msg_elem_used(msg);
+
+       pop = len;
+       /*  offset       start
+        *  |            |
+        *  |----- a ----|-------- pop -------|----- b ----|
+        *  |______________________________________________| length
+        *
+        * a:   region at front of scatter element to save (start - offset)
+        * b:   region at back of scatter element to save when length > a + pop
+        * pop: bytes still to remove; starts as 'len' and is decremented
+        *      as elements are trimmed or dropped.
+        *
+        * When start != offset there are two cases: b is non-zero (the pop
+        * lies entirely inside this element) or b is zero (the pop runs to
+        * the end of this element and possibly into later ones).
+        *
+        * If b is non-zero and the ring has a free slot, shift the ring right
+        * to free the next slot for b, leaving a in place with a reduced
+        * length. Otherwise compact a and b into a newly allocated page.
+        */
+       if (start != offset) {
+               struct scatterlist *nsge, *sge = sk_msg_elem(msg, i);
+               int a = start - offset;
+               int b = sge->length - pop - a;
+
+               sk_msg_iter_var_next(i);
+
+               if (pop < sge->length - a) {
+                       if (space) {
+                               sge->length = a;
+                               sk_msg_shift_right(msg, i);
+                               nsge = sk_msg_elem(msg, i);
+                               get_page(sg_page(sge));
+                               sg_set_page(nsge,
+                                           sg_page(sge),
+                                           b, sge->offset + pop + a);
+                       } else {
+                               struct page *page, *orig;
+                               u8 *to, *from;
+
+                               page = alloc_pages(__GFP_NOWARN |
+                                                  __GFP_COMP   | GFP_ATOMIC,
+                                                  get_order(a + b));
+                               if (unlikely(!page))
+                                       return -ENOMEM;
+
+                               orig = sg_page(sge);
+                               from = sg_virt(sge);
+                               to = page_address(page);
+                               memcpy(to, from, a);
+                               memcpy(to + a, from + a + pop, b);
+                               sg_set_page(sge, page, a + b, 0);
+                               put_page(orig);
+                       }
+                       pop = 0;
+               } else {
+                       pop -= (sge->length - a);
+                       sge->length = a;
+               }
+       }
+
+       /* From above the current layout _must_ be as follows,
+        *
+        * -| offset
+        * -| start
+        *
+        *  |---- pop ---|---------------- b ------------|
+        *  |____________________________________________| length
+        *
+        * Offset and start of the current msg elem are equal because in the
+        * previous case we handled offset != start and either consumed the
+        * entire element and advanced to the next element OR pop == 0.
+        *
+        * Two cases to handle here: if pop is less than the length, leaving
+        * some remainder b, trim the element in place. Otherwise (b = 0)
+        * release the element's page, drop it by shifting the ring left and
+        * loop again at the same index, which now holds the next element.
+        */
+       while (pop) {
+               struct scatterlist *sge = sk_msg_elem(msg, i);
+
+               if (pop < sge->length) {
+                       sge->length -= pop;
+                       sge->offset += pop;
+                       pop = 0;
+               } else {
+                       pop -= sge->length;
+                       put_page(sg_page(sge));
+                       sk_msg_shift_left(msg, i);
+               }
+       }
+
+       sk_mem_uncharge(msg->sk, len - pop);
+       msg->sg.size -= (len - pop);
+       sk_msg_compute_data_pointers(msg);
+       return 0;
+}
+
+static const struct bpf_func_proto bpf_msg_pop_data_proto = {
+       .func           = bpf_msg_pop_data,
+       .gpl_only       = false,                /* open to non-GPL programs */
+       .ret_type       = RET_INTEGER,          /* 0 or a negative errno */
+       .arg1_type      = ARG_PTR_TO_CTX,       /* msg: program context */
+       .arg2_type      = ARG_ANYTHING,         /* start: first byte to pop */
+       .arg3_type      = ARG_ANYTHING,         /* len: number of bytes */
+       .arg4_type      = ARG_ANYTHING,         /* flags: must be zero */
+};
+
 BPF_CALL_1(bpf_get_cgroup_classid, const struct sk_buff *, skb)
 {
        return task_get_classid(skb);
@@ -5098,6 +5266,7 @@ bool bpf_helper_changes_pkt_data(void *func)
            func == bpf_xdp_adjust_meta ||
            func == bpf_msg_pull_data ||
            func == bpf_msg_push_data ||
+           func == bpf_msg_pop_data ||
            func == bpf_xdp_adjust_tail ||
 #if IS_ENABLED(CONFIG_IPV6_SEG6_BPF)
            func == bpf_lwt_seg6_store_bytes ||
@@ -5394,6 +5563,8 @@ sk_msg_func_proto(enum bpf_func_id func_id, const struct 
bpf_prog *prog)
                return &bpf_msg_pull_data_proto;
        case BPF_FUNC_msg_push_data:
                return &bpf_msg_push_data_proto;
+       case BPF_FUNC_msg_pop_data:
+               return &bpf_msg_pop_data_proto;
        default:
                return bpf_base_func_proto(func_id);
        }
diff --git a/net/ipv4/tcp_bpf.c b/net/ipv4/tcp_bpf.c
index 3b45fe5..a47c1cd 100644
--- a/net/ipv4/tcp_bpf.c
+++ b/net/ipv4/tcp_bpf.c
@@ -289,12 +289,23 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct 
sk_psock *psock,
 {
        bool cork = false, enospc = msg->sg.start == msg->sg.end;
        struct sock *sk_redir;
-       u32 tosend;
+       u32 tosend, delta = 0;
        int ret;
 
 more_data:
-       if (psock->eval == __SK_NONE)
+       if (psock->eval == __SK_NONE) {
+               /* Track how many bytes the verdict program popped so that on
+                * SK_DROP they are also subtracted from the copied size we
+                * return; otherwise the user could get a positive return code
+                * after bpf_msg_pop_data() combined with an SK_DROP verdict.
+                */
+               delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+               if (msg->sg.size < delta)
+                       delta -= msg->sg.size;
+               else
+                       delta = 0;
+       }
 
        if (msg->cork_bytes &&
            msg->cork_bytes > msg->sg.size && !enospc) {
@@ -350,7 +361,7 @@ static int tcp_bpf_send_verdict(struct sock *sk, struct 
sk_psock *psock,
        default:
                sk_msg_free_partial(sk, msg, tosend);
                sk_msg_apply_bytes(psock, tosend);
-               *copied -= tosend;
+               *copied -= (tosend + delta);
                return -EACCES;
        }
 
diff --git a/net/tls/tls_sw.c b/net/tls/tls_sw.c
index 7b1af8b..d4ecc66 100644
--- a/net/tls/tls_sw.c
+++ b/net/tls/tls_sw.c
@@ -687,6 +687,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct 
sock *sk,
        struct sock *sk_redir;
        struct tls_rec *rec;
        int err = 0, send;
+       u32 delta = 0;
        bool enospc;
 
        psock = sk_psock_get(sk);
@@ -694,8 +695,14 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct 
sock *sk,
                return tls_push_record(sk, flags, record_type);
 more_data:
        enospc = sk_msg_full(msg);
-       if (psock->eval == __SK_NONE)
+       if (psock->eval == __SK_NONE) {
+               /* Track how many bytes the verdict program popped so the
+                * drop path can also subtract them from *copied (0 if grown).
+                */
+               delta = msg->sg.size;
                psock->eval = sk_psock_msg_verdict(sk, psock, msg);
+               delta = delta > msg->sg.size ? delta - msg->sg.size : 0;
+       }
        if (msg->cork_bytes && msg->cork_bytes > msg->sg.size &&
            !enospc && !full_record) {
                err = -ENOSPC;
@@ -743,7 +750,7 @@ static int bpf_exec_tx_verdict(struct sk_msg *msg, struct 
sock *sk,
                        msg->apply_bytes -= send;
                if (msg->sg.size == 0)
                        tls_free_open_rec(sk);
-               *copied -= send;
+               *copied -= (send + delta);
                err = -EACCES;
        }
 
-- 
2.7.4

Reply via email to