Route RX packets through the userspace filter socket when one is configured.

Key points:
- Add VHOST_NET_FILTER_MAX_LEN as an upper bound on the filter payload size
- Introduce vhost_net_filter_request() to send a REQUEST message to userspace
- Add handle_rx_filter(), the RX fast path used while a filter is active
- Hook the filter path into handle_rx() when filter_sock is set
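
For illustration only, the sketch below shows what a minimal userspace consumer
of REQUEST messages could look like. It is not part of this patch: the socket
type (a datagram-style socketpair whose peer end was handed to the kernel via
the earlier set-filter call) and the exact header layout (16-bit type, 16-bit
direction, 32-bit length) are assumptions; the series' uapi header is
authoritative for both.

  #include <stdint.h>
  #include <stdio.h>
  #include <string.h>
  #include <sys/socket.h>

  /* Assumed header layout; see the series' uapi header for the real one. */
  struct filter_msg_hdr {
          uint16_t type;          /* e.g. VHOST_NET_FILTER_MSG_REQUEST */
          uint16_t direction;     /* RX vs. TX */
          uint32_t len;           /* payload length in bytes */
  };

  /* Mirrors VHOST_NET_FILTER_MAX_LEN on the kernel side. */
  #define FILTER_MAX_LEN (4096 + 65536)

  static void drain_filter_requests(int filter_fd)
  {
          unsigned char buf[sizeof(struct filter_msg_hdr) + FILTER_MAX_LEN];
          struct filter_msg_hdr hdr;

          for (;;) {
                  ssize_t n = recv(filter_fd, buf, sizeof(buf), 0);

                  if (n < (ssize_t)sizeof(hdr))
                          break;  /* error, EOF or short message */

                  memcpy(&hdr, buf, sizeof(hdr));

                  /* Payload follows the header; inspect or modify in place. */
                  unsigned char *payload = buf + sizeof(hdr);

                  printf("filter request: direction=%u len=%u\n",
                         (unsigned)hdr.direction, (unsigned)hdr.len);
                  (void)payload;
          }
  }

Note that the kernel code in this patch only sends REQUEST messages; whether
and how userspace replies is outside the scope shown here, and the RX path
simply drops a packet whose length it considers changed.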
Signed-off-by: Cindy Lu <[email protected]>
---
 drivers/vhost/net.c | 229 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 229 insertions(+)

diff --git a/drivers/vhost/net.c b/drivers/vhost/net.c
index f02deff0e53c..aa9a5ed43eae 100644
--- a/drivers/vhost/net.c
+++ b/drivers/vhost/net.c
@@ -161,6 +161,13 @@ struct vhost_net {
 
 static unsigned vhost_net_zcopy_mask __read_mostly;
 
+/*
+ * Upper bound for a single packet payload on the filter path.
+ * Keep this large enough for the largest expected frame plus vnet headers,
+ * but still bounded to avoid unbounded allocations.
+ */
+#define VHOST_NET_FILTER_MAX_LEN (4096 + 65536)
+
 static void *vhost_net_buf_get_ptr(struct vhost_net_buf *rxq)
 {
 	if (rxq->tail != rxq->head)
@@ -1227,6 +1234,222 @@ static long vhost_net_set_filter(struct vhost_net *n, int fd)
 	return r;
 }
 
+/*
+ * Send a filter REQUEST message to userspace for a single packet.
+ *
+ * The caller provides a writable buffer; userspace may inspect the content and
+ * optionally modify it in place. We only accept the packet if the returned
+ * length matches the original length, otherwise the packet is dropped.
+ */
+static int vhost_net_filter_request(struct vhost_net *n, u16 direction,
+				    void *buf, u32 *len)
+{
+	struct vhost_net_filter_msg msg = {
+		.type = VHOST_NET_FILTER_MSG_REQUEST,
+		.direction = direction,
+		.len = *len,
+	};
+	struct msghdr msghdr = { 0 };
+	struct kvec iov[2] = {
+		{ .iov_base = &msg, .iov_len = sizeof(msg) },
+		{ .iov_base = buf, .iov_len = *len },
+	};
+	struct socket *sock;
+	struct file *sock_file = NULL;
+	int ret;
+
+	/*
+	 * Take a temporary file reference to guard against concurrent
+	 * filter socket replacement while we send the message.
+	 */
+	spin_lock(&n->filter_lock);
+	sock = n->filter_sock;
+	if (sock)
+		sock_file = get_file(sock->file);
+	spin_unlock(&n->filter_lock);
+
+	if (!sock) {
+		ret = -ENOTCONN;
+		goto out_put;
+	}
+
+	ret = kernel_sendmsg(sock, &msghdr, iov,
+			     *len ? 2 : 1, sizeof(msg) + *len);
+
+out_put:
+	if (sock_file)
+		fput(sock_file);
+
+	if (ret < 0)
+		return ret;
+	return 0;
+}
+
+/*
+ * RX fast path when filter offload is active.
+ *
+ * This mirrors handle_rx() but routes each RX packet through the userspace
+ * filter socket. Packets are copied into a temporary buffer, sent to the
+ * filter socket as a REQUEST, and only delivered to the guest if userspace
+ * keeps the length unchanged. Any truncation or mismatch drops the packet.
+ */
+static void handle_rx_filter(struct vhost_net *net, struct socket *sock)
+{
+	struct vhost_net_virtqueue *nvq = &net->vqs[VHOST_NET_VQ_RX];
+	struct vhost_virtqueue *vq = &nvq->vq;
+	bool in_order = vhost_has_feature(vq, VIRTIO_F_IN_ORDER);
+	unsigned int count = 0;
+	unsigned int in, log;
+	struct vhost_log *vq_log;
+	struct virtio_net_hdr hdr = {
+		.flags = 0,
+		.gso_type = VIRTIO_NET_HDR_GSO_NONE
+	};
+	struct msghdr msg = {
+		.msg_name = NULL,
+		.msg_namelen = 0,
+		.msg_control = NULL,
+		.msg_controllen = 0,
+		.msg_flags = MSG_DONTWAIT,
+	};
+	size_t total_len = 0;
+	int mergeable;
+	bool set_num_buffers;
+	size_t vhost_hlen, sock_hlen;
+	size_t vhost_len, sock_len;
+	bool busyloop_intr = false;
+	struct iov_iter fixup;
+	__virtio16 num_buffers;
+	int recv_pkts = 0;
+	unsigned int ndesc;
+	void *pkt;
+
+	pkt = kvmalloc(VHOST_NET_FILTER_MAX_LEN, GFP_KERNEL | __GFP_NOWARN);
+	if (!pkt) {
+		vhost_net_enable_vq(net, vq);
+		return;
+	}
+
+	vhost_hlen = nvq->vhost_hlen;
+	sock_hlen = nvq->sock_hlen;
+
+	vq_log = unlikely(vhost_has_feature(vq, VHOST_F_LOG_ALL)) ?
+		vq->log : NULL;
+	mergeable = vhost_has_feature(vq, VIRTIO_NET_F_MRG_RXBUF);
+	set_num_buffers = mergeable ||
+			  vhost_has_feature(vq, VIRTIO_F_VERSION_1);
+
+	do {
+		u32 pkt_len;
+		int err;
+		s16 headcount;
+		struct kvec iov;
+
+		sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
+						      &busyloop_intr, &count);
+		if (!sock_len)
+			break;
+		sock_len += sock_hlen;
+		if (sock_len > VHOST_NET_FILTER_MAX_LEN) {
+			/* Consume and drop the oversized packet. */
+			iov.iov_base = pkt;
+			iov.iov_len = 1;
+			kernel_recvmsg(sock, &msg, &iov, 1, 1,
+				       MSG_DONTWAIT | MSG_TRUNC);
+			continue;
+		}
+
+		vhost_len = sock_len + vhost_hlen;
+		headcount = get_rx_bufs(nvq, vq->heads + count,
+					vq->nheads + count, vhost_len, &in,
+					vq_log, &log,
+					likely(mergeable) ? UIO_MAXIOV : 1,
+					&ndesc);
+		if (unlikely(headcount < 0))
+			goto out;
+
+		if (!headcount) {
+			if (unlikely(busyloop_intr)) {
+				vhost_poll_queue(&vq->poll);
+			} else if (unlikely(vhost_enable_notify(&net->dev, vq))) {
+				vhost_disable_notify(&net->dev, vq);
+				continue;
+			}
+			goto out;
+		}
+
+		busyloop_intr = false;
+
+		if (nvq->rx_ring)
+			msg.msg_control = vhost_net_buf_consume(&nvq->rxq);
+
+		iov.iov_base = pkt;
+		iov.iov_len = sock_len;
+		err = kernel_recvmsg(sock, &msg, &iov, 1, sock_len,
+				     MSG_DONTWAIT | MSG_TRUNC);
+		if (unlikely(err != sock_len)) {
+			vhost_discard_vq_desc(vq, headcount, ndesc);
+			continue;
+		}
+
+		pkt_len = sock_len;
+		err = vhost_net_filter_request(net, VHOST_NET_FILTER_DIRECTION_TX,
+					       pkt, &pkt_len);
+		if (err < 0)
+			pkt_len = sock_len;
+		if (pkt_len != sock_len) {
+			vhost_discard_vq_desc(vq, headcount, ndesc);
+			continue;
+		}
+
+		iov_iter_init(&msg.msg_iter, ITER_DEST, vq->iov, in, vhost_len);
+		fixup = msg.msg_iter;
+		if (unlikely(vhost_hlen))
+			iov_iter_advance(&msg.msg_iter, vhost_hlen);
+
+		if (copy_to_iter(pkt, sock_len, &msg.msg_iter) != sock_len) {
+			vhost_discard_vq_desc(vq, headcount, ndesc);
+			goto out;
+		}
+
+		if (unlikely(vhost_hlen)) {
+			if (copy_to_iter(&hdr, sizeof(hdr),
+					 &fixup) != sizeof(hdr)) {
+				vhost_discard_vq_desc(vq, headcount, ndesc);
+				goto out;
+			}
+		} else {
+			iov_iter_advance(&fixup, sizeof(hdr));
+		}
+
+		num_buffers = cpu_to_vhost16(vq, headcount);
+		if (likely(set_num_buffers) &&
+		    copy_to_iter(&num_buffers, sizeof(num_buffers), &fixup) !=
+		    sizeof(num_buffers)) {
+			vhost_discard_vq_desc(vq, headcount, ndesc);
+			goto out;
+		}
+
+		nvq->done_idx += headcount;
+		count += in_order ? 1 : headcount;
+		if (nvq->done_idx > VHOST_NET_BATCH) {
+			vhost_net_signal_used(nvq, count);
+			count = 0;
+		}
+
+		if (unlikely(vq_log))
+			vhost_log_write(vq, vq_log, log, vhost_len, vq->iov, in);
+
+		total_len += vhost_len;
+	} while (likely(!vhost_exceeds_weight(vq, ++recv_pkts, total_len)));
+
+	if (unlikely(busyloop_intr))
+		vhost_poll_queue(&vq->poll);
+	else if (!sock_len)
+		vhost_net_enable_vq(net, vq);
+
+out:
+	vhost_net_signal_used(nvq, count);
+	kvfree(pkt);
+}
+
 /* Expects to be always run from workqueue - which acts as
  * read-size critical section for our kind of RCU. */
 static void handle_rx(struct vhost_net *net)
@@ -1281,6 +1504,11 @@ static void handle_rx(struct vhost_net *net)
 	set_num_buffers = mergeable ||
 			  vhost_has_feature(vq, VIRTIO_F_VERSION_1);
 
+	if (READ_ONCE(net->filter_sock)) {
+		handle_rx_filter(net, sock);
+		goto out_unlock;
+	}
+
 	do {
 		sock_len = vhost_net_rx_peek_head_len(net, sock->sk,
 						      &busyloop_intr, &count);
@@ -1383,6 +1611,7 @@ static void handle_rx(struct vhost_net *net)
 		vhost_net_enable_vq(net, vq);
 out:
 	vhost_net_signal_used(nvq, count);
+out_unlock:
 	mutex_unlock(&vq->mutex);
 }
 
-- 
2.52.0

