This adds the rest of the logic for SEQPACKET: 1) SEQPACKET-specific functions which send SEQ_BEGIN/SEQ_END. Note that both functions may sleep while waiting for enough space for the SEQPACKET header. 2) Handle SEQ_BEGIN/SEQ_END in TAP packet capture. 3) Send SHUTDOWN on socket close for the SEQPACKET type. 4) Set the SEQPACKET packet type during send. 5) Set the EOR flag (VIRTIO_VSOCK_RW_EOR) in the packet flags for SEQPACKET when MSG_EOR is passed during send.
Signed-off-by: Arseny Krasnov <arseny.kras...@kaspersky.com> --- include/linux/virtio_vsock.h | 3 ++ net/vmw_vsock/virtio_transport_common.c | 67 ++++++++++++++++++++++++- 2 files changed, 68 insertions(+), 2 deletions(-) diff --git a/include/linux/virtio_vsock.h b/include/linux/virtio_vsock.h index 022667d57884..bf09d9aafa20 100644 --- a/include/linux/virtio_vsock.h +++ b/include/linux/virtio_vsock.h @@ -41,6 +41,7 @@ struct virtio_vsock_sock { u32 user_read_seq_len; u32 user_read_copied; u32 curr_rx_msg_cnt; + u32 next_tx_msg_cnt; }; struct virtio_vsock_pkt { @@ -85,6 +86,8 @@ virtio_transport_dgram_dequeue(struct vsock_sock *vsk, struct msghdr *msg, size_t len, int flags); +int virtio_transport_seqpacket_seq_send_len(struct vsock_sock *vsk, size_t len, int flags); +int virtio_transport_seqpacket_seq_send_eor(struct vsock_sock *vsk, int flags); size_t virtio_transport_seqpacket_seq_get_len(struct vsock_sock *vsk); int virtio_transport_seqpacket_dequeue(struct vsock_sock *vsk, diff --git a/net/vmw_vsock/virtio_transport_common.c b/net/vmw_vsock/virtio_transport_common.c index 3ca0009c553e..8431d0a891ed 100644 --- a/net/vmw_vsock/virtio_transport_common.c +++ b/net/vmw_vsock/virtio_transport_common.c @@ -139,6 +139,8 @@ static struct sk_buff *virtio_transport_build_skb(void *opaque) break; case VIRTIO_VSOCK_OP_CREDIT_UPDATE: case VIRTIO_VSOCK_OP_CREDIT_REQUEST: + case VIRTIO_VSOCK_OP_SEQ_BEGIN: + case VIRTIO_VSOCK_OP_SEQ_END: hdr->op = cpu_to_le16(AF_VSOCK_OP_CONTROL); break; default: @@ -187,7 +189,12 @@ static int virtio_transport_send_pkt_info(struct vsock_sock *vsk, struct virtio_vsock_pkt *pkt; u32 pkt_len = info->pkt_len; - info->type = VIRTIO_VSOCK_TYPE_STREAM; + info->type = virtio_transport_get_type(sk_vsock(vsk)); + + if (info->type == VIRTIO_VSOCK_TYPE_SEQPACKET && + info->msg && + info->msg->msg_flags & MSG_EOR) + info->flags |= VIRTIO_VSOCK_RW_EOR; t_ops = virtio_transport_get_ops(vsk); if (unlikely(!t_ops)) @@ -401,6 +408,62 @@ 
virtio_transport_stream_do_dequeue(struct vsock_sock *vsk, return err; } +static int virtio_transport_seqpacket_send_ctrl(struct vsock_sock *vsk, + int type, + size_t len, + int flags) +{ + struct virtio_vsock_sock *vvs = vsk->trans; + struct virtio_vsock_pkt_info info = { + .op = type, + .vsk = vsk, + .pkt_len = sizeof(struct virtio_vsock_seq_hdr) + }; + + struct virtio_vsock_seq_hdr seq_hdr = { + .msg_cnt = cpu_to_le32(vvs->next_tx_msg_cnt), + .msg_len = cpu_to_le32(len) + }; + + struct kvec seq_hdr_kiov = { + .iov_base = (void *)&seq_hdr, + .iov_len = sizeof(struct virtio_vsock_seq_hdr) + }; + + struct msghdr msg = {0}; + + //XXX: do we need 'vsock_transport_send_notify_data' pointer? + if (vsock_wait_space(sk_vsock(vsk), + sizeof(struct virtio_vsock_seq_hdr), + flags, NULL)) + return -1; + + iov_iter_kvec(&msg.msg_iter, WRITE, &seq_hdr_kiov, 1, sizeof(seq_hdr)); + + info.msg = &msg; + vvs->next_tx_msg_cnt++; + + return virtio_transport_send_pkt_info(vsk, &info); +} + +int virtio_transport_seqpacket_seq_send_len(struct vsock_sock *vsk, size_t len, int flags) +{ + return virtio_transport_seqpacket_send_ctrl(vsk, + VIRTIO_VSOCK_OP_SEQ_BEGIN, + len, + flags); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_seq_send_len); + +int virtio_transport_seqpacket_seq_send_eor(struct vsock_sock *vsk, int flags) +{ + return virtio_transport_seqpacket_send_ctrl(vsk, + VIRTIO_VSOCK_OP_SEQ_END, + 0, + flags); +} +EXPORT_SYMBOL_GPL(virtio_transport_seqpacket_seq_send_eor); + static inline void virtio_transport_remove_pkt(struct virtio_vsock_pkt *pkt) { list_del(&pkt->list); @@ -999,7 +1062,7 @@ void virtio_transport_release(struct vsock_sock *vsk) struct sock *sk = &vsk->sk; bool remove_sock = true; - if (sk->sk_type == SOCK_STREAM) + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) remove_sock = virtio_transport_close(vsk); list_for_each_entry_safe(pkt, tmp, &vvs->rx_queue, list) { -- 2.25.1