This is part of my standalone sblock() work. I need to move solock()
down into sosetopt() because modification of the buffer-related SO_SND*
and SO_RCV* socket options should be protected by sblock(). However,
standalone sblock() has a different lock order relative to solock() for
the receive and send buffers. At the very least, sblock() for the
`so_snd' buffer will always be taken before solock() in the sosend()
path.

The switch() block was split in two. The SO_DONTROUTE, SO_SPLICE, SO_SND*
and SO_RCV* cases do not need to call (*pr_ctloutput)(), so they were
moved to the first switch() block, and solock() was pushed into each case
individually. For the SO_SND* and SO_RCV* cases, solock() will be replaced
by sblock() in the future. The SO_RTABLE case does call (*pr_ctloutput)(),
but in a special way, so it was placed in the first switch() block too.

The second switch() block contains the cases that need to call
(*pr_ctloutput)(). solock() is held around this whole block, together
with the (*pr_ctloutput)() call, to keep the update atomic.

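sys_setsockopt() is not the only sosetopt() caller. In the other callers,
the solock() around the sosetopt() call can simply be dropped. Where the
lock was held across sosetopt() and other socket operations
(vxlan_tep_add_addr(), wg_socket_open(), nfs_connect(), nfssvc_addsock()),
it is now taken separately around those remaining operations. Please
note that solock() only protects socket consistency, so this does not
cause any loss of atomicity.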

I would like to receive feedback and polish the diff if required. Then
I'll ask for the final version to be tested with bulk builds and in
snapshots.

Index: sys/kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
retrieving revision 1.305
diff -u -p -r1.305 uipc_socket.c
--- sys/kern/uipc_socket.c      4 Jul 2023 22:28:24 -0000       1.305
+++ sys/kern/uipc_socket.c      12 Jul 2023 23:08:02 -0000
@@ -1789,57 +1789,23 @@ sosetopt(struct socket *so, int level, i
 {
        int error = 0;
 
-       soassertlocked(so);
-
        if (level != SOL_SOCKET) {
                if (so->so_proto->pr_ctloutput) {
+                       solock(so);
                        error = (*so->so_proto->pr_ctloutput)(PRCO_SETOPT, so,
                            level, optname, m);
+                       sounlock(so);
                        return (error);
                }
                error = ENOPROTOOPT;
        } else {
                switch (optname) {
-               case SO_BINDANY:
-                       if ((error = suser(curproc)) != 0)      /* XXX */
-                               return (error);
-                       break;
-               }
-
-               switch (optname) {
-
-               case SO_LINGER:
-                       if (m == NULL || m->m_len != sizeof (struct linger) ||
-                           mtod(m, struct linger *)->l_linger < 0 ||
-                           mtod(m, struct linger *)->l_linger > SHRT_MAX)
-                               return (EINVAL);
-                       so->so_linger = mtod(m, struct linger *)->l_linger;
-                       /* FALLTHROUGH */
-
-               case SO_BINDANY:
-               case SO_DEBUG:
-               case SO_KEEPALIVE:
-               case SO_USELOOPBACK:
-               case SO_BROADCAST:
-               case SO_REUSEADDR:
-               case SO_REUSEPORT:
-               case SO_OOBINLINE:
-               case SO_TIMESTAMP:
-               case SO_ZEROIZE:
-                       if (m == NULL || m->m_len < sizeof (int))
-                               return (EINVAL);
-                       if (*mtod(m, int *))
-                               so->so_options |= optname;
-                       else
-                               so->so_options &= ~optname;
-                       break;
-
                case SO_DONTROUTE:
                        if (m == NULL || m->m_len < sizeof (int))
                                return (EINVAL);
                        if (*mtod(m, int *))
-                               error = EOPNOTSUPP;
-                       break;
+                               return (EOPNOTSUPP);
+                       return (0);
 
                case SO_SNDBUF:
                case SO_RCVBUF:
@@ -1853,23 +1819,32 @@ sosetopt(struct socket *so, int level, i
                        cnt = *mtod(m, int *);
                        if ((long)cnt <= 0)
                                cnt = 1;
-                       switch (optname) {
 
+                       solock(so);
+                       switch (optname) {
                        case SO_SNDBUF:
-                               if (so->so_snd.sb_state & SS_CANTSENDMORE)
-                                       return (EINVAL);
+                               if (so->so_snd.sb_state & SS_CANTSENDMORE) {
+                                       error = EINVAL;
+                                       break;
+                               }
                                if (sbcheckreserve(cnt, so->so_snd.sb_wat) ||
-                                   sbreserve(so, &so->so_snd, cnt))
-                                       return (ENOBUFS);
+                                   sbreserve(so, &so->so_snd, cnt)) {
+                                       error = ENOBUFS;
+                                       break;
+                               }
                                so->so_snd.sb_wat = cnt;
                                break;
 
                        case SO_RCVBUF:
-                               if (so->so_rcv.sb_state & SS_CANTRCVMORE)
-                                       return (EINVAL);
+                               if (so->so_rcv.sb_state & SS_CANTRCVMORE) {
+                                       error = EINVAL;
+                                       break;
+                               }
                                if (sbcheckreserve(cnt, so->so_rcv.sb_wat) ||
-                                   sbreserve(so, &so->so_rcv, cnt))
-                                       return (ENOBUFS);
+                                   sbreserve(so, &so->so_rcv, cnt)) {
+                                       error = ENOBUFS;
+                                       break;
+                               }
                                so->so_rcv.sb_wat = cnt;
                                break;
 
@@ -1884,7 +1859,8 @@ sosetopt(struct socket *so, int level, i
                                    so->so_rcv.sb_hiwat : cnt;
                                break;
                        }
-                       break;
+                       sounlock(so);
+                       return (error);
                    }
 
                case SO_SNDTIMEO:
@@ -1903,8 +1879,9 @@ sosetopt(struct socket *so, int level, i
                                return (EDOM);
                        if (nsecs == 0)
                                nsecs = INFSLP;
-                       switch (optname) {
 
+                       solock(so);
+                       switch (optname) {
                        case SO_SNDTIMEO:
                                so->so_snd.sb_timeo_nsecs = nsecs;
                                break;
@@ -1912,7 +1889,8 @@ sosetopt(struct socket *so, int level, i
                                so->so_rcv.sb_timeo_nsecs = nsecs;
                                break;
                        }
-                       break;
+                       sounlock(so);
+                       return (0);
                    }
 
                case SO_RTABLE:
@@ -1923,19 +1901,21 @@ sosetopt(struct socket *so, int level, i
                                    so->so_proto->pr_domain;
 
                                level = dom->dom_protosw->pr_protocol;
+                               solock(so);
                                error = (*so->so_proto->pr_ctloutput)
                                    (PRCO_SETOPT, so, level, optname, m);
+                               sounlock(so);
                                return (error);
                        }
-                       error = ENOPROTOOPT;
-                       break;
+                       return (ENOPROTOOPT);
 
 #ifdef SOCKET_SPLICE
                case SO_SPLICE:
+                       solock(so);
                        if (m == NULL) {
                                error = sosplice(so, -1, 0, NULL);
                        } else if (m->m_len < sizeof(int)) {
-                               return (EINVAL);
+                               error = EINVAL;
                        } else if (m->m_len < sizeof(struct splice)) {
                                error = sosplice(so, *mtod(m, int *), 0, NULL);
                        } else {
@@ -1944,17 +1924,59 @@ sosetopt(struct socket *so, int level, i
                                    mtod(m, struct splice *)->sp_max,
                                   &mtod(m, struct splice *)->sp_idle);
                        }
-                       break;
+                       sounlock(so);
+                       return (error);
 #endif /* SOCKET_SPLICE */
+               }
+
+               switch (optname) {
+               case SO_BINDANY:
+                       if ((error = suser(curproc)) != 0)      /* XXX */
+                               return (error);
+                       break;
+               }
+
+               solock(so);
+               switch (optname) {
+               case SO_LINGER:
+                       if (m == NULL || m->m_len != sizeof (struct linger) ||
+                           mtod(m, struct linger *)->l_linger < 0 ||
+                           mtod(m, struct linger *)->l_linger > SHRT_MAX) {
+                               error = EINVAL;
+                               break;
+                       }
+                       so->so_linger = mtod(m, struct linger *)->l_linger;
+                       /* FALLTHROUGH */
+
+               case SO_BINDANY:
+               case SO_DEBUG:
+               case SO_KEEPALIVE:
+               case SO_USELOOPBACK:
+               case SO_BROADCAST:
+               case SO_REUSEADDR:
+               case SO_REUSEPORT:
+               case SO_OOBINLINE:
+               case SO_TIMESTAMP:
+               case SO_ZEROIZE:
+                       if (m == NULL || m->m_len < sizeof (int)) {
+                               error = EINVAL;
+                               break;
+                       }
+                       if (*mtod(m, int *))
+                               so->so_options |= optname;
+                       else
+                               so->so_options &= ~optname;
+
+                       if (so->so_proto->pr_ctloutput)
+                               (*so->so_proto->pr_ctloutput)(PRCO_SETOPT,
+                                   so, level, optname, m);
+                       break;
 
                default:
                        error = ENOPROTOOPT;
                        break;
                }
-               if (error == 0 && so->so_proto->pr_ctloutput) {
-                       (*so->so_proto->pr_ctloutput)(PRCO_SETOPT, so,
-                           level, optname, m);
-               }
+               sounlock(so);
        }
 
        return (error);
Index: sys/kern/uipc_syscalls.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_syscalls.c,v
retrieving revision 1.212
diff -u -p -r1.212 uipc_syscalls.c
--- sys/kern/uipc_syscalls.c    10 Feb 2023 14:34:17 -0000      1.212
+++ sys/kern/uipc_syscalls.c    12 Jul 2023 23:08:02 -0000
@@ -1232,9 +1232,7 @@ sys_setsockopt(struct proc *p, void *v, 
                m->m_len = SCARG(uap, valsize);
        }
        so = fp->f_data;
-       solock(so);
        error = sosetopt(so, SCARG(uap, level), SCARG(uap, name), m);
-       sounlock(so);
 bad:
        m_freem(m);
        FRELE(fp, p);
Index: sys/net/bfd.c
===================================================================
RCS file: /cvs/src/sys/net/bfd.c,v
retrieving revision 1.79
diff -u -p -r1.79 bfd.c
--- sys/net/bfd.c       12 Jul 2023 16:10:45 -0000      1.79
+++ sys/net/bfd.c       12 Jul 2023 23:08:02 -0000
@@ -452,9 +452,7 @@ bfd_listener(struct bfd_config *bfd, uns
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = MAXTTL;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_MINTTL, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error) {
                printf("%s: sosetopt error %d\n",
@@ -531,9 +529,7 @@ bfd_sender(struct bfd_config *bfd, unsig
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = IP_PORTRANGE_HIGH;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_PORTRANGE, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error) {
                printf("%s: sosetopt error %d\n",
@@ -545,9 +541,7 @@ bfd_sender(struct bfd_config *bfd, unsig
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = MAXTTL;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_TTL, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error) {
                printf("%s: sosetopt error %d\n",
@@ -559,9 +553,7 @@ bfd_sender(struct bfd_config *bfd, unsig
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = IPTOS_PREC_INTERNETCONTROL;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_TOS, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error) {
                printf("%s: sosetopt error %d\n",
Index: sys/net/if_vxlan.c
===================================================================
RCS file: /cvs/src/sys/net/if_vxlan.c,v
retrieving revision 1.92
diff -u -p -r1.92 if_vxlan.c
--- sys/net/if_vxlan.c  13 Apr 2023 02:19:05 -0000      1.92
+++ sys/net/if_vxlan.c  12 Jul 2023 23:08:02 -0000
@@ -934,9 +934,9 @@ vxlan_tep_add_addr(struct vxlan_softc *s
                goto free;
 
        solock(so);
-
        sotoinpcb(so)->inp_upcall = vxlan_input;
        sotoinpcb(so)->inp_upcall_arg = vt;
+       sounlock(so);
 
        m_inithdr(&m);
        m.m_len = sizeof(vt->vt_rdomain);
@@ -973,12 +973,12 @@ vxlan_tep_add_addr(struct vxlan_softc *s
                unhandled_af(vt->vt_af);
        }
 
+       solock(so);
        error = sobind(so, &m, curproc);
+       sounlock(so);
        if (error != 0)
                goto close;
 
-       sounlock(so);
-
        rw_assert_wrlock(&vxlan_lock);
        TAILQ_INSERT_TAIL(&vxlan_teps, vt, vt_entry);
 
@@ -987,7 +987,6 @@ vxlan_tep_add_addr(struct vxlan_softc *s
        return (0);
 
 close:
-       sounlock(so);
        soclose(so, MSG_DONTWAIT);
 free:
        free(vt, M_DEVBUF, sizeof(*vt));
Index: sys/net/if_wg.c
===================================================================
RCS file: /cvs/src/sys/net/if_wg.c,v
retrieving revision 1.28
diff -u -p -r1.28 if_wg.c
--- sys/net/if_wg.c     1 Jun 2023 18:57:53 -0000       1.28
+++ sys/net/if_wg.c     12 Jul 2023 23:08:02 -0000
@@ -720,14 +720,16 @@ wg_socket_open(struct socket **so, int a
        solock(*so);
        sotoinpcb(*so)->inp_upcall = wg_input;
        sotoinpcb(*so)->inp_upcall_arg = upcall_arg;
+       sounlock(*so);
 
        if ((ret = sosetopt(*so, SOL_SOCKET, SO_RTABLE, &mrtable)) == 0) {
+               solock(*so);
                if ((ret = sobind(*so, &mhostnam, curproc)) == 0) {
                        *port = sotoinpcb(*so)->inp_lport;
                        *rtable = sotoinpcb(*so)->inp_rtableid;
                }
+               sounlock(*so);
        }
-       sounlock(*so);
 
        if (ret != 0)
                wg_socket_close(so);
Index: sys/nfs/krpc_subr.c
===================================================================
RCS file: /cvs/src/sys/nfs/krpc_subr.c,v
retrieving revision 1.37
diff -u -p -r1.37 krpc_subr.c
--- sys/nfs/krpc_subr.c 6 Jun 2022 14:45:41 -0000       1.37
+++ sys/nfs/krpc_subr.c 12 Jul 2023 23:08:02 -0000
@@ -239,9 +239,7 @@ krpc_call(struct sockaddr_in *sa, u_int 
        tv.tv_usec = 0;
        memcpy(mtod(m, struct timeval *), &tv, sizeof tv);
        m->m_len = sizeof(tv);
-       solock(so);
        error = sosetopt(so, SOL_SOCKET, SO_RCVTIMEO, m);
-       sounlock(so);
        m_freem(m);
        if (error)
                goto out;
@@ -255,9 +253,7 @@ krpc_call(struct sockaddr_in *sa, u_int 
                on = mtod(m, int32_t *);
                m->m_len = sizeof(*on);
                *on = 1;
-               solock(so);
                error = sosetopt(so, SOL_SOCKET, SO_BROADCAST, m);
-               sounlock(so);
                m_freem(m);
                if (error)
                        goto out;
@@ -272,9 +268,7 @@ krpc_call(struct sockaddr_in *sa, u_int 
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = IP_PORTRANGE_LOW;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_PORTRANGE, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error)
                goto out;
@@ -299,9 +293,7 @@ krpc_call(struct sockaddr_in *sa, u_int 
        mopt->m_len = sizeof(int);
        ip = mtod(mopt, int *);
        *ip = IP_PORTRANGE_DEFAULT;
-       solock(so);
        error = sosetopt(so, IPPROTO_IP, IP_PORTRANGE, mopt);
-       sounlock(so);
        m_freem(mopt);
        if (error)
                goto out;
Index: sys/nfs/nfs_socket.c
===================================================================
RCS file: /cvs/src/sys/nfs/nfs_socket.c,v
retrieving revision 1.143
diff -u -p -r1.143 nfs_socket.c
--- sys/nfs/nfs_socket.c        13 Aug 2022 21:01:46 -0000      1.143
+++ sys/nfs/nfs_socket.c        12 Jul 2023 23:08:02 -0000
@@ -258,7 +258,6 @@ nfs_connect(struct nfsmount *nmp, struct
                MGET(nam, M_WAIT, MT_SONAME);
 
        so = nmp->nm_so;
-       solock(so);
        nmp->nm_soflags = so->so_proto->pr_flags;
 
        /*
@@ -282,7 +281,9 @@ nfs_connect(struct nfsmount *nmp, struct
                sin->sin_family = AF_INET;
                sin->sin_addr.s_addr = INADDR_ANY;
                sin->sin_port = htons(0);
+               solock(so);
                error = sobind(so, nam, &proc0);
+               sounlock(so);
                if (error)
                        goto bad;
 
@@ -294,6 +295,7 @@ nfs_connect(struct nfsmount *nmp, struct
                        goto bad;
        }
 
+       solock(so);
        /*
         * Protocols that do not require connections may be optionally left
         * unconnected for servers that reply from a port other than NFS_PORT.
@@ -301,12 +303,12 @@ nfs_connect(struct nfsmount *nmp, struct
        if (nmp->nm_flag & NFSMNT_NOCONN) {
                if (nmp->nm_soflags & PR_CONNREQUIRED) {
                        error = ENOTCONN;
-                       goto bad;
+                       goto bad_locked;
                }
        } else {
                error = soconnect(so, nmp->nm_nam);
                if (error)
-                       goto bad;
+                       goto bad_locked;
 
                /*
                 * Wait for the connection to complete. Cribbed from the
@@ -320,13 +322,13 @@ nfs_connect(struct nfsmount *nmp, struct
                            so->so_error == 0 && rep &&
                            (error = nfs_sigintr(nmp, rep, rep->r_procp)) != 0){
                                so->so_state &= ~SS_ISCONNECTING;
-                               goto bad;
+                               goto bad_locked;
                        }
                }
                if (so->so_error) {
                        error = so->so_error;
                        so->so_error = 0;
-                       goto bad;
+                       goto bad_locked;
                }
        }
        /*
@@ -338,6 +340,7 @@ nfs_connect(struct nfsmount *nmp, struct
                so->so_snd.sb_timeo_nsecs = SEC_TO_NSEC(5);
        else
                so->so_snd.sb_timeo_nsecs = INFSLP;
+       sounlock(so);
        if (nmp->nm_sotype == SOCK_DGRAM) {
                sndreserve = nmp->nm_wsize + NFS_MAXPKTHDR;
                rcvreserve = (max(nmp->nm_rsize, nmp->nm_readdirsize) +
@@ -360,9 +363,10 @@ nfs_connect(struct nfsmount *nmp, struct
        } else {
                panic("%s: nm_sotype %d", __func__, nmp->nm_sotype);
        }
+       solock(so);
        error = soreserve(so, sndreserve, rcvreserve);
        if (error)
-               goto bad;
+               goto bad_locked;
        so->so_rcv.sb_flags |= SB_NOINTR;
        so->so_snd.sb_flags |= SB_NOINTR;
        sounlock(so);
@@ -377,8 +381,9 @@ nfs_connect(struct nfsmount *nmp, struct
        nmp->nm_timeouts = 0;
        return (0);
 
-bad:
+bad_locked:
        sounlock(so);
+bad:
 
        m_freem(mopt);
        m_freem(nam);
Index: sys/nfs/nfs_syscalls.c
===================================================================
RCS file: /cvs/src/sys/nfs/nfs_syscalls.c,v
retrieving revision 1.118
diff -u -p -r1.118 nfs_syscalls.c
--- sys/nfs/nfs_syscalls.c      6 Jun 2022 14:45:41 -0000       1.118
+++ sys/nfs/nfs_syscalls.c      12 Jul 2023 23:08:02 -0000
@@ -249,8 +249,8 @@ nfssvc_addsock(struct file *fp, struct m
                siz = NFS_MAXPACKET;
        solock(so);
        error = soreserve(so, siz, siz); 
+       sounlock(so);
        if (error) {
-               sounlock(so);
                m_freem(mynam);
                return (error);
        }
@@ -275,6 +275,7 @@ nfssvc_addsock(struct file *fp, struct m
                sosetopt(so, IPPROTO_TCP, TCP_NODELAY, m);
                m_freem(m);
        }
+       solock(so);
        so->so_rcv.sb_flags &= ~SB_NOINTR;
        so->so_rcv.sb_timeo_nsecs = INFSLP;
        so->so_snd.sb_flags &= ~SB_NOINTR;

Reply via email to