On Tue, 2024-04-09 at 18:57 +0800, Geliang Tang wrote:
> From: Geliang Tang <[email protected]>
>
> This patch extracts the code to send and receive data into a new
> helper named send_recv_data() in network_helpers.c and exports it
> in network_helpers.h.
>
> This helper will be used for MPTCP BPF selftests.
>
> Signed-off-by: Geliang Tang <[email protected]>
> ---
> tools/testing/selftests/bpf/network_helpers.c | 96
> +++++++++++++++++++
> tools/testing/selftests/bpf/network_helpers.h | 1 +
> .../selftests/bpf/prog_tests/bpf_tcp_ca.c | 81 +---------------
> 3 files changed, 98 insertions(+), 80 deletions(-)
>
> diff --git a/tools/testing/selftests/bpf/network_helpers.c
> b/tools/testing/selftests/bpf/network_helpers.c
> index dbcbe2ac51ba..55d41508fe1f 100644
> --- a/tools/testing/selftests/bpf/network_helpers.c
> +++ b/tools/testing/selftests/bpf/network_helpers.c
> @@ -555,3 +555,99 @@ int set_hw_ring_size(char *ifname, struct
> ethtool_ringparam *ring_param)
> close(sockfd);
> return 0;
> }
> +
> +struct send_recv_arg {
> + int fd;
> + uint32_t bytes;
> + int stop;
> +};
> +
> +static void *send_recv_server(void *arg)
> +{
> + struct send_recv_arg *a = (struct send_recv_arg *)arg;
> + ssize_t nr_sent = 0, bytes = 0;
> + char batch[1500];
> + int err = 0, fd;
> +
> + fd = accept(a->fd, NULL, NULL);
> + while (fd == -1) {
> + if (errno == EINTR)
> + continue;
> + err = -errno;
> + goto done;
> + }
> +
> + if (settimeo(fd, 0)) {
> + err = -errno;
> + goto done;
> + }
> +
> + while (bytes < a->bytes && !READ_ONCE(a->stop)) {
> + nr_sent = send(fd, &batch,
> + MIN(a->bytes - bytes, sizeof(batch)),
> 0);
> + if (nr_sent == -1 && errno == EINTR)
> + continue;
> + if (nr_sent == -1) {
> + err = -errno;
> + break;
> + }
> + bytes += nr_sent;
> + }
> +
> + if (bytes != a->bytes)
> + log_err("send");
> +
> +done:
> + if (fd >= 0)
> + close(fd);
> + if (err) {
> + WRITE_ONCE(a->stop, 1);
> + return ERR_PTR(err);
> + }
> + return NULL;
> +}
> +
> +int send_recv_data(int lfd, int fd, uint32_t total_bytes)
> +{
> + ssize_t nr_recv = 0, bytes = 0;
> + struct send_recv_arg arg = {
> + .fd = lfd,
> + .bytes = total_bytes,
> + .stop = 0,
> + };
> + pthread_t srv_thread;
> + void *thread_ret;
> + char batch[1500];
> + int err;
> +
> + err = pthread_create(&srv_thread, NULL, send_recv_server,
> (void *)&arg);
> + if (!err) {
Sorry, this check is inverted: pthread_create() returns 0 on success,
so the condition here should be 'if (err)'.
Changes Requested.
-Geliang
> + log_err("pthread_create");
> + return err;
> + }
> +
> + /* recv total_bytes */
> + while (bytes < total_bytes && !READ_ONCE(arg.stop)) {
> + nr_recv = recv(fd, &batch,
> + MIN(total_bytes - bytes,
> sizeof(batch)), 0);
> + if (nr_recv == -1 && errno == EINTR)
> + continue;
> + if (nr_recv == -1)
> + break;
> + bytes += nr_recv;
> + }
> +
> + if (bytes != total_bytes) {
> + log_err("recv");
> + return -1;
> + }
> +
> + WRITE_ONCE(arg.stop, 1);
> + pthread_join(srv_thread, &thread_ret);
> + if (IS_ERR(thread_ret)) {
> + log_err("thread_ret");
> + return -1;
> + }
> +
> + return 0;
> +}
> diff --git a/tools/testing/selftests/bpf/network_helpers.h
> b/tools/testing/selftests/bpf/network_helpers.h
> index 6457445cc6e2..70f4e4c92733 100644
> --- a/tools/testing/selftests/bpf/network_helpers.h
> +++ b/tools/testing/selftests/bpf/network_helpers.h
> @@ -76,6 +76,7 @@ struct nstoken;
> */
> struct nstoken *open_netns(const char *name);
> void close_netns(struct nstoken *token);
> +int send_recv_data(int lfd, int fd, uint32_t total_bytes);
>
> static __u16 csum_fold(__u32 csum)
> {
> diff --git a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c
> b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c
> index 64f172f02a9a..907bac46c774 100644
> --- a/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c
> +++ b/tools/testing/selftests/bpf/prog_tests/bpf_tcp_ca.c
> @@ -33,75 +33,15 @@ static int settcpca(int fd, const char *tcp_ca)
> return 0;
> }
>
> -struct send_recv_arg {
> - int fd;
> - uint32_t bytes;
> - int stop;
> -};
> -
> -static void *server(void *arg)
> -{
> - struct send_recv_arg *a = (struct send_recv_arg *)arg;
> - ssize_t nr_sent = 0, bytes = 0;
> - char batch[1500];
> - int err = 0, fd;
> -
> - fd = accept(a->fd, NULL, NULL);
> - while (fd == -1) {
> - if (errno == EINTR)
> - continue;
> - err = -errno;
> - goto done;
> - }
> -
> - if (settimeo(fd, 0)) {
> - err = -errno;
> - goto done;
> - }
> -
> - while (bytes < a->bytes && !READ_ONCE(a->stop)) {
> - nr_sent = send(fd, &batch,
> - MIN(a->bytes - bytes, sizeof(batch)),
> 0);
> - if (nr_sent == -1 && errno == EINTR)
> - continue;
> - if (nr_sent == -1) {
> - err = -errno;
> - break;
> - }
> - bytes += nr_sent;
> - }
> -
> - ASSERT_EQ(bytes, a->bytes, "send");
> -
> -done:
> - if (fd >= 0)
> - close(fd);
> - if (err) {
> - WRITE_ONCE(a->stop, 1);
> - return ERR_PTR(err);
> - }
> - return NULL;
> -}
> -
> static void do_test(const char *tcp_ca, const struct bpf_map
> *sk_stg_map)
> {
> - ssize_t nr_recv = 0, bytes = 0;
> - struct send_recv_arg arg = {
> - .bytes = total_bytes,
> - .stop = 0,
> - };
> int lfd = -1, fd = -1;
> - pthread_t srv_thread;
> - void *thread_ret;
> - char batch[1500];
> int err;
>
> lfd = start_server(AF_INET6, SOCK_STREAM, NULL, 0, 0);
> if (!ASSERT_NEQ(lfd, -1, "socket"))
> return;
>
> - arg.fd = lfd;
> -
> fd = socket(AF_INET6, SOCK_STREAM, 0);
> if (!ASSERT_NEQ(fd, -1, "socket")) {
> close(lfd);
> @@ -133,26 +73,7 @@ static void do_test(const char *tcp_ca, const
> struct bpf_map *sk_stg_map)
> goto done;
> }
>
> - err = pthread_create(&srv_thread, NULL, server, (void
> *)&arg);
> - if (!ASSERT_OK(err, "pthread_create"))
> - goto done;
> -
> - /* recv total_bytes */
> - while (bytes < total_bytes && !READ_ONCE(arg.stop)) {
> - nr_recv = recv(fd, &batch,
> - MIN(total_bytes - bytes,
> sizeof(batch)), 0);
> - if (nr_recv == -1 && errno == EINTR)
> - continue;
> - if (nr_recv == -1)
> - break;
> - bytes += nr_recv;
> - }
> -
> - ASSERT_EQ(bytes, total_bytes, "recv");
> -
> - WRITE_ONCE(arg.stop, 1);
> - pthread_join(srv_thread, &thread_ret);
> - ASSERT_OK(IS_ERR(thread_ret), "thread_ret");
> + ASSERT_OK(send_recv_data(lfd, fd, total_bytes),
> "send_recv_data");
>
> done:
> close(lfd);