Route RX packets through the userspace filter socket when one is configured.
Key points:
- Add VHOST_NET_FILTER_MAX_LEN upper bound for filter payload size
- Introduce vhost_net_filter_request() to send REQUEST to userspace
- Add handle_rx_filter(), a dedicated RX path used while a filter is active
- Hook filter path in handle_rx() when filter_sock is set
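
For reference, a userspace consumer of the REQUEST messages could be
structured as in the sketch below. This is an illustration only: the
struct layout, field widths, and socket semantics are assumptions
drawn from this series rather than a definitive ABI description, and
how userspace signals its verdict back to the kernel is not part of
this patch and is not shown.

    #include <linux/types.h>
    #include <sys/socket.h>

    /* Assumed layout; the real definition lives in the uapi header. */
    struct vhost_net_filter_msg {
            __u16 type;        /* VHOST_NET_FILTER_MSG_REQUEST */
            __u16 direction;   /* RX or TX */
            __u32 len;         /* payload bytes that follow */
    };

    static void filter_loop(int fd)
    {
            /* Header plus the largest payload the kernel sends,
             * i.e. VHOST_NET_FILTER_MAX_LEN (4096 + 65536). */
            static char buf[sizeof(struct vhost_net_filter_msg)
                            + 4096 + 65536];

            for (;;) {
                    ssize_t n = recv(fd, buf, sizeof(buf), 0);

                    if (n < (ssize_t)sizeof(struct vhost_net_filter_msg))
                            continue;
                    /* Inspect the payload here; the kernel drops the
                     * packet if its length is seen to change. */
            }
    }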

Signed-off-by: Cindy Lu <[email protected]>
---
 drivers/vhost/net.c | 243 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 243 insertions(+)

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f02deff0e53c..aa9a5ed43eae 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -161,6 +161,13 @@ struct vhost_net {
 
 static unsigned vhost_net_zcopy_mask __read_mostly;
 
+/*
+ * Upper bound for a single packet payload on the filter path.
+ * Keep this large enough for the largest expected frame plus vnet headers,
+ * but still bounded to avoid unbounded allocations.
+ */
+#define VHOST_NET_FILTER_MAX_LEN (4096 + 65536)
+
 static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
 {
        if (rxq->tail != rxq->head)
@@ -1227,6 +1234,236 @@ static long vhost_net_set_filter(struct vhost_net *n, int fd)
        return r;
 }
 
+/*
+ * Send a filter REQUEST message to userspace for a single packet.
+ *
+ * The caller provides a writable buffer; userspace may inspect the content and
+ * optionally modify it in place. We only accept the packet if the returned
+ * length matches the original length, otherwise the packet is dropped.
+ */
+static int vhost_net_filter_request(struct vhost_net *n, u16 direction,
+                                   void *buf, u32 *len)
+{
+       struct vhost_net_filter_msg msg = {
+               .type = VHOST_NET_FILTER_MSG_REQUEST,
+               .direction = direction,
+               .len = *len,
+       };
+       struct msghdr msghdr = { 0 };
+       struct kvec iov[2] = {
+               { .iov_base = &msg, .iov_len = sizeof(msg) },
+               { .iov_base = buf, .iov_len = *len },
+       };
+       struct socket *sock;
+       struct file *sock_file = NULL;
+       int ret;
+
+       /*
+        * Take a temporary file reference to guard against concurrent
+        * filter socket replacement while we send the message.
+        */
+       spin_lock(&n->filter_lock);
+       sock = n->filter_sock;
+       if (sock)
+               sock_file = get_file(sock->file);
+       spin_unlock(&n->filter_lock);
+
+       if (!sock) {
+               ret = -ENOTCONN;
+               goto out_put;
+       }
+
+       ret = kernel_sendmsg(sock, &msghdr, iov,
+                            *len ? 2 : 1, sizeof(msg) + *len);
+
+out_put:
+       if (sock_file)
+               fput(sock_file);
+
+       if (ret < 0)
+               return ret;
+       return 0;
+}
+
+/*
+ * RX path used while a userspace filter is active.
+ *
+ * This mirrors handle_rx() but routes each RX packet through the userspace
+ * filter. Packets are copied into a temporary buffer, sent to the filter
+ * socket as a REQUEST, and only delivered to the guest if userspace keeps the
+ * length unchanged. Any truncation or mismatch drops the packet.
+ */
+static void handle_rx_filter(struct vhost_net *net, struct socket *sock)
+{
+       struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+       struct vhost_virtqueue *vq = &nvq->vq;
+       bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+       unsigned int count = 0;
+       unsigned int in, log;
+       struct vhost_log *vq_log;
+       struct virtio_net_hdr hdr = {
+               .flags = 0,
+               .gso_type = VIRTIO_NET_HDR_GSO_NONE
+       };
+       struct msghdr msg = {
+               .msg_name = NULL,
+               .msg_namelen = 0,
+               .msg_control = NULL,
+               .msg_controllen = 0,
+               .msg_flags = MSG_DONTWAIT,
+       };
+       size_t total_len = 0;
+       int mergeable;
+       bool set_num_buffers;
+       size_t vhost_hlen, sock_hlen;
+       size_t vhost_len, sock_len;
+       bool busyloop_intr = false;
+       struct iov_iter fixup;
+       __virtio16 num_buffers;
+       int recv_pkts = 0;
+       unsigned int ndesc;
+       void *pkt;
+
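+       /*
+        * Bounce buffer: each packet is received into this buffer, handed
+        * to the userspace filter, then copied into the guest buffers.
+        */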
+       pkt = kvmalloc(VHOST_NET_FILTER_MAX_LEN, GFP_KERNEL | __GFP_NOWARN);
+       if (!pkt) {
+               vhost_net_enable_vq(net, vq);
+               return;
+       }
+
+       vhost_hlen = nvq->vhost_hlen;
+       sock_hlen = nvq->sock_hlen;
+
+       vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
+               vq->log : NULL;
+       mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
+       set_num_buffers = mergeable ||
+                         vhost_has_feature(vq, VIRTIO_F_VERSION_1);
+
+       do {
+               u32 pkt_len;
+               int err;
+               s16 headcount;
+               struct kvec iov;
+
+               sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+                                                     &busyloop_intr, &count);
+               if (!sock_len)
+                       break;
+               sock_len += sock_hlen;
+               if (sock_len > VHOST_NET_FILTER_MAX_LEN) {
+                       /* Consume and drop oversized packet. */
+                       iov.iov_base = pkt;
+                       iov.iov_len = 1;
+                       kernel_recvmsg(sock, &msg, &iov, 1, 1,
+                                      MSG_DONTWAIT | MSG_TRUNC);
+                       continue;
+               }
+
+               vhost_len = sock_len + vhost_hlen;
+               headcount = get_rx_bufs(nvq, vq->heads + count,
+                                       vq->nheads + count, vhost_len, &in,
+                                       vq_log, &log,
+                                       likely(mergeable) ? UIO_MAXIOV : 1,
+                                       &ndesc);
+               if (unlikely(headcount < 0))
+                       goto out;
+
+               if (!headcount) {
+                       if (unlikely(busyloop_intr)) {
+                               vhost_poll_queue(&vq->poll);
+                       } else if (unlikely(vhost_enable_notify(&net->dev,
+                                                               vq))) {
+                               vhost_disable_notify(&net->dev, vq);
+                               continue;
+                       }
+                       goto out;
+               }
+
+               busyloop_intr = false;
+
+               if (nvq->rx_ring)
+                       msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
+
+               iov.iov_base = pkt;
+               iov.iov_len = sock_len;
+               err = kernel_recvmsg(sock, &msg, &iov, 1, sock_len,
+                                    MSG_DONTWAIT | MSG_TRUNC);
+               if (unlikely(err != sock_len)) {
+                       vhost_discard_vq_desc(vq, headcount, ndesc);
+                       continue;
+               }
+
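+               /*
+                * Hand the packet to the userspace filter, tagged with the
+                * RX direction. If the request cannot be sent, deliver the
+                * packet unfiltered rather than drop it.
+                */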
+               pkt_len = sock_len;
+               err = vhost_net_filter_request(net,
+                                              VHOST_NET_FILTER_DIRECTION_RX,
+                                              pkt, &pkt_len);
+               if (err < 0)
+                       pkt_len = sock_len;
+               if (pkt_len != sock_len) {
+                       vhost_discard_vq_desc(vq, headcount, ndesc);
+                       continue;
+               }
+
+               iov_iter_init(&msg.msg_iter, ITER_DEST, vq->iov, in, vhost_len);
+               fixup = msg.msg_iter;
+               if (unlikely(vhost_hlen))
+                       iov_iter_advance(&msg.msg_iter, vhost_hlen);
+
+               if (copy_to_iter(pkt, sock_len, &msg.msg_iter) != sock_len) {
+                       vhost_discard_vq_desc(vq, headcount, ndesc);
+                       goto out;
+               }
+
+               if (unlikely(vhost_hlen)) {
+                       if (copy_to_iter(&hdr, sizeof(hdr),
+                                        &fixup) != sizeof(hdr)) {
+                               vhost_discard_vq_desc(vq, headcount, ndesc);
+                               goto out;
+                       }
+               } else {
+                       iov_iter_advance(&fixup, sizeof(hdr));
+               }
+
+               num_buffers = cpu_to_vhost16(vq, headcount);
+               if (likely(set_num_buffers) &&
+                   copy_to_iter(&num_buffers, sizeof(num_buffers), &fixup) !=
+                           sizeof(num_buffers)) {
+                       vhost_discard_vq_desc(vq, headcount, ndesc);
+                       goto out;
+               }
+
+               nvq->done_idx += headcount;
+               count += in_order ? 1 : headcount;
+               if (nvq->done_idx > VHOST_NET_BATCH) {
+                       vhost_net_signal_used(nvq, count);
+                       count = 0;
+               }
+
+               if (unlikely(vq_log))
+                       vhost_log_write(vq, vq_log, log, vhost_len,
+                                       vq->iov, in);
+
+               total_len += vhost_len;
+       } while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
+
+       if (unlikely(busyloop_intr))
+               vhost_poll_queue(&vq->poll);
+       else if (!sock_len)
+               vhost_net_enable_vq(net, vq);
+
+out:
+       vhost_net_signal_used(nvq, count);
+       kvfree(pkt);
+}
+
 /* Expects to be always run from workqueue - which acts as
  * read-size critical section for our kind of RCU. */
 static void handle_rx(struct vhost_net *net)
@@ -1281,6 +1518,11 @@ static void handle_rx(struct vhost_net *net)
        set_num_buffers = mergeable ||
                          vhost_has_feature(vq, VIRTIO_F_VERSION_1);
 
+       if (READ_ONCE(net->filter_sock)) {
+               handle_rx_filter(net, sock);
+               goto out_unlock;
+       }
+
        do {
                sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
                                                      &busyloop_intr, &count);
@@ -1383,6 +1625,7 @@ static void handle_rx(struct vhost_net *net)
                vhost_net_enable_vq(net, vq);
 out:
        vhost_net_signal_used(nvq, count);
+out_unlock:
        mutex_unlock(&vq->mutex);
 }
 
-- 
2.52.0

