On 22/11/21(Mon) 14:42, Vitaliy Makkoveev wrote:
> On Sat, Nov 20, 2021 at 03:12:31AM +0300, Vitaliy Makkoveev wrote:
> > Updated diff. Re-lock dances were simplified in the unix(4) sockets
> > layer.
> > 
> > Reference counters added to unix(4) sockets layer too. This makes 
> > pointer dereference of peer's control block always safe after re-lock.
> > 
> > The `unp_refs' list cleanup done in the unp_detach(). This removes the
> > case where the socket connected to our dying socket could be passed to
> > unp_disconnect() and the check of it's connection state became much
> > easier.
> >
> 
> Another re-lock simplification. We could enforce the lock order between
> the listening socket `head' and the socket `so' linked to it's `so_q0'
> or `so_q' to solock(head) -> solock(so).
> 
> This removes re-lock from accept(2) and the accepting socket couldn't be
> stolen by concurrent accept(2) thread. This removes re-lock from `so_q0'
> and `so_q' cleanup on dying listening socket.
> 
> The previous incarnation of this diff does re-lock in a half of
> doaccept(), soclose(), sofree() and soisconnected() calls. The current
> diff does not re-lock in doaccept() and soclose() and always so re-lock
> in sofree() and soisconnected().
> 
> I guess this is the latest simplification and this diff could be pushed
> forward.

This diff is really interesting.  It shows that the current locking
design needs to be reworked.

I don't think we should expose the locking strategy through a `persocket'
variable and then use if/else dances to decide whether one of two locks
needs to be taken/released.  Instead, could we fold the TCP/UDP locking into
more generic functions?  For example, soconnect2() could be:

int
soconnect2(struct socket *so1, struct socket *so2)
{
        int s, error;

        s = solock_pair(so1, so2);
        error = (*so1->so_proto->pr_usrreq)(so1, PRU_CONNECT2, NULL,
            (struct mbuf *)so2, NULL, curproc);
        sounlock_pair(so1, so2, s);
        return (error);
}

And solock_pair() would do the right thing(tm) based on the socket type.

This matters because, in the end, we want to prepare this layer to use
per-socket locks with TCP/UDP sockets as well.

Could something similar be done for doaccept()?

I'm wary of introducing reference counting.  Once reference counting is
available it tends to be abused.  It's not clear to me why it is being
added.  It looks like it is there to work around lock ordering issues; could
you talk a bit about this?  Is there any alternative?

I also don't understand the problem behind:

> +             unp_ref(unp2);
> +             sounlock(so, SL_LOCKED);
> +             solock(so2);
> +             solock(so);
> +
> +             /* Datagram socket could be reconnected due to re-lock. */
> +             if (unp->unp_conn != unp2) {
> +                     sounlock(so2, SL_LOCKED);
> +                     unp_rele(unp2);
> +                     goto again;
> +             }
> +
> +             unp_rele(unp2);


It seems that doing an unlock/relock dance requires a lot of added
complexity.  Why is it done this way?

Thanks for dealing with this!

> Index: sys/kern/uipc_socket.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_socket.c,v
> retrieving revision 1.269
> diff -u -p -r1.269 uipc_socket.c
> --- sys/kern/uipc_socket.c    11 Nov 2021 16:35:09 -0000      1.269
> +++ sys/kern/uipc_socket.c    22 Nov 2021 11:36:40 -0000
> @@ -52,6 +52,7 @@
>  #include <sys/atomic.h>
>  #include <sys/rwlock.h>
>  #include <sys/time.h>
> +#include <sys/refcnt.h>
>  
>  #ifdef DDB
>  #include <machine/db_machdep.h>
> @@ -156,7 +157,9 @@ soalloc(int prflags)
>       so = pool_get(&socket_pool, prflags);
>       if (so == NULL)
>               return (NULL);
> -     rw_init(&so->so_lock, "solock");
> +     rw_init_flags(&so->so_lock, "solock", RWL_DUPOK);
> +     refcnt_init(&so->so_refcnt);
> +
>       return (so);
>  }
>  
> @@ -257,6 +260,8 @@ solisten(struct socket *so, int backlog)
>  void
>  sofree(struct socket *so, int s)
>  {
> +     int persocket = solock_persocket(so);
> +
>       soassertlocked(so);
>  
>       if (so->so_pcb || (so->so_state & SS_NOFDREF) == 0) {
> @@ -264,16 +269,53 @@ sofree(struct socket *so, int s)
>               return;
>       }
>       if (so->so_head) {
> +             struct socket *head = so->so_head;
> +
>               /*
>                * We must not decommission a socket that's on the accept(2)
>                * queue.  If we do, then accept(2) may hang after select(2)
>                * indicated that the listening socket was ready.
>                */
> -             if (!soqremque(so, 0)) {
> +             if (so->so_onq == &head->so_q) {
>                       sounlock(so, s);
>                       return;
>               }
> +
> +             if (persocket) {
> +                     /*
> +                      * Concurrent close of `head' could
> +                      * abort `so' due to re-lock.
> +                      */
> +                     soref(so);
> +                     soref(head);
> +                     sounlock(so, SL_LOCKED);
> +                     solock(head);
> +                     solock(so);
> +
> +                     if (so->so_onq != &head->so_q0) {
> +                             sounlock(head, SL_LOCKED);
> +                             sounlock(so, SL_LOCKED);
> +                             sorele(head);
> +                             sorele(so);
> +                             return;
> +                     }
> +
> +                     sorele(head);
> +                     sorele(so);
> +             }
> +
> +             soqremque(so, 0);
> +
> +             if (persocket)
> +                     sounlock(head, SL_LOCKED);
>       }
> +
> +     if (persocket) {
> +             sounlock(so, SL_LOCKED);
> +             refcnt_finalize(&so->so_refcnt, "sofinal");
> +             solock(so);
> +     }
> +
>       sigio_free(&so->so_sigio);
>       klist_free(&so->so_rcv.sb_sel.si_note);
>       klist_free(&so->so_snd.sb_sel.si_note);
> @@ -363,13 +405,36 @@ drop:
>                       error = error2;
>       }
>       if (so->so_options & SO_ACCEPTCONN) {
> +             int persocket = solock_persocket(so);
> +
> +             if (persocket) {
> +                     /* Wait concurrent sonewconn() threads. */
> +                     while (so->so_newconn > 0) {
> +                             so->so_state |= SS_NEWCONN_WAIT;
> +                             sosleep_nsec(so, &so->so_newconn, PSOCK,
> +                                     "netlck", INFSLP);
> +                     }
> +             }
> +
>               while ((so2 = TAILQ_FIRST(&so->so_q0)) != NULL) {
> +                     if (persocket)
> +                             solock(so2);
>                       (void) soqremque(so2, 0);
> +                     if (persocket)
> +                             sounlock(so, SL_LOCKED);
>                       (void) soabort(so2);
> +                     if (persocket)
> +                             solock(so);
>               }
>               while ((so2 = TAILQ_FIRST(&so->so_q)) != NULL) {
> +                     if (persocket)
> +                             solock(so2);
>                       (void) soqremque(so2, 1);
> +                     if (persocket)
> +                             sounlock(so, SL_LOCKED);
>                       (void) soabort(so2);
> +                     if (persocket)
> +                             solock(so);
>               }
>       }
>  discard:
> @@ -437,11 +502,19 @@ soconnect(struct socket *so, struct mbuf
>  int
>  soconnect2(struct socket *so1, struct socket *so2)
>  {
> -     int s, error;
> +     int persocket, s, error;
> +
> +     if ((persocket = solock_persocket(so1))) {
> +             solock_pair(so1, so2);
> +             s = SL_LOCKED;
> +     } else
> +             s = solock(so1);
>  
> -     s = solock(so1);
>       error = (*so1->so_proto->pr_usrreq)(so1, PRU_CONNECT2, NULL,
>           (struct mbuf *)so2, NULL, curproc);
> +
> +     if (persocket)
> +             sounlock(so2, s);
>       sounlock(so1, s);
>       return (error);
>  }
> Index: sys/kern/uipc_socket2.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
> retrieving revision 1.116
> diff -u -p -r1.116 uipc_socket2.c
> --- sys/kern/uipc_socket2.c   6 Nov 2021 05:26:33 -0000       1.116
> +++ sys/kern/uipc_socket2.c   22 Nov 2021 11:36:40 -0000
> @@ -53,8 +53,6 @@ u_long      sb_max = SB_MAX;                /* patchable */
>  extern struct pool mclpools[];
>  extern struct pool mbpool;
>  
> -extern struct rwlock unp_lock;
> -
>  /*
>   * Procedures to manipulate state flags of socket
>   * and do appropriate wakeups.  Normal sequence from the
> @@ -101,10 +99,37 @@ soisconnected(struct socket *so)
>       soassertlocked(so);
>       so->so_state &= ~(SS_ISCONNECTING|SS_ISDISCONNECTING);
>       so->so_state |= SS_ISCONNECTED;
> -     if (head && soqremque(so, 0)) {
> +
> +     if (head != NULL && so->so_onq == &head->so_q0) {
> +             int persocket = solock_persocket(so);
> +
> +             if (persocket) {
> +                     soref(so);
> +                     soref(head);
> +
> +                     sounlock(so, SL_LOCKED);
> +                     solock(head);
> +                     solock(so);
> +
> +                     if (so->so_onq != &head->so_q0) {
> +                             sounlock(head, SL_LOCKED);
> +                             sorele(head);
> +                             sorele(so);
> +
> +                             return;
> +                     }
> +
> +                     sorele(head);
> +                     sorele(so);
> +             }
> +
> +             soqremque(so, 0);
>               soqinsque(head, so, 1);
>               sorwakeup(head);
>               wakeup_one(&head->so_timeo);
> +
> +             if (persocket)
> +                     sounlock(head, SL_LOCKED);
>       } else {
>               wakeup(&so->so_timeo);
>               sorwakeup(so);
> @@ -146,7 +171,8 @@ struct socket *
>  sonewconn(struct socket *head, int connstatus)
>  {
>       struct socket *so;
> -     int soqueue = connstatus ? 1 : 0;
> +     int persocket = solock_persocket(head);
> +     int error;
>  
>       /*
>        * XXXSMP as long as `so' and `head' share the same lock, we
> @@ -175,9 +201,17 @@ sonewconn(struct socket *head, int conns
>       so->so_cpid = head->so_cpid;
>  
>       /*
> +      * Lock order will be `head' -> `so' while these sockets are linked.
> +      */
> +     if (persocket)
> +             solock(so);
> +
> +     /*
>        * Inherit watermarks but those may get clamped in low mem situations.
>        */
>       if (soreserve(so, head->so_snd.sb_hiwat, head->so_rcv.sb_hiwat)) {
> +             if (persocket)
> +                     sounlock(so, SL_LOCKED);
>               pool_put(&socket_pool, so);
>               return (NULL);
>       }
> @@ -193,20 +227,54 @@ sonewconn(struct socket *head, int conns
>       sigio_init(&so->so_sigio);
>       sigio_copy(&so->so_sigio, &head->so_sigio);
>  
> -     soqinsque(head, so, soqueue);
> -     if ((*so->so_proto->pr_attach)(so, 0)) {
> -             (void) soqremque(so, soqueue);
> +     soqinsque(head, so, 0);
> +
> +     /*
> +      * We need to unlock `head' because PCB layer could release
> +      * solock() to enforce desired lock order.
> +      */
> +     if (persocket) {
> +             head->so_newconn++;
> +             sounlock(head, SL_LOCKED);
> +     }
> +
> +     error = (*so->so_proto->pr_attach)(so, 0);
> +
> +     if (persocket) {
> +             sounlock(so, SL_LOCKED);
> +             solock(head);
> +             solock(so);
> +
> +             if ((head->so_newconn--) == 0) {
> +                     if ((head->so_state & SS_NEWCONN_WAIT) != 0) {
> +                             head->so_state &= ~SS_NEWCONN_WAIT;
> +                             wakeup(&head->so_newconn);
> +                     }
> +             }
> +     }
> +
> +     if (error) {
> +             soqremque(so, 0);
> +             if (persocket)
> +                     sounlock(so, SL_LOCKED);
>               sigio_free(&so->so_sigio);
>               klist_free(&so->so_rcv.sb_sel.si_note);
>               klist_free(&so->so_snd.sb_sel.si_note);
>               pool_put(&socket_pool, so);
>               return (NULL);
>       }
> +
>       if (connstatus) {
> +             so->so_state |= connstatus;
> +             soqremque(so, 0);
> +             soqinsque(head, so, 1);
>               sorwakeup(head);
>               wakeup(&head->so_timeo);
> -             so->so_state |= connstatus;
>       }
> +
> +     if (persocket)
> +             sounlock(so, SL_LOCKED);
> +
>       return (so);
>  }
>  
> @@ -214,6 +282,7 @@ void
>  soqinsque(struct socket *head, struct socket *so, int q)
>  {
>       soassertlocked(head);
> +     soassertlocked(so);
>  
>       KASSERT(so->so_onq == NULL);
>  
> @@ -233,6 +302,7 @@ soqremque(struct socket *so, int q)
>  {
>       struct socket *head = so->so_head;
>  
> +     soassertlocked(so);
>       soassertlocked(head);
>  
>       if (q == 0) {
> @@ -284,9 +354,6 @@ solock(struct socket *so)
>       case PF_INET6:
>               NET_LOCK();
>               break;
> -     case PF_UNIX:
> -             rw_enter_write(&unp_lock);
> -             break;
>       default:
>               rw_enter_write(&so->so_lock);
>               break;
> @@ -295,6 +362,34 @@ solock(struct socket *so)
>       return (SL_LOCKED);
>  }
>  
> +int
> +solock_persocket(struct socket *so)
> +{
> +     switch (so->so_proto->pr_domain->dom_family) {
> +     case PF_INET:
> +     case PF_INET6:
> +             return 0;
> +     default:
> +             return 1;
> +     }
> +}
> +
> +void
> +solock_pair(struct socket *so1, struct socket *so2)
> +{
> +     KASSERT(so1 != so2);
> +     KASSERT(so1->so_type == so2->so_type);
> +     KASSERT(solock_persocket(so1));
> +
> +     if (so1 < so2) {
> +             solock(so1);
> +             solock(so2);
> +     } else {
> +             solock(so2);
> +             solock(so1);
> +     }
> +}
> +
>  void
>  sounlock(struct socket *so, int s)
>  {
> @@ -308,9 +403,6 @@ sounlock(struct socket *so, int s)
>       case PF_INET6:
>               NET_UNLOCK();
>               break;
> -     case PF_UNIX:
> -             rw_exit_write(&unp_lock);
> -             break;
>       default:
>               rw_exit_write(&so->so_lock);
>               break;
> @@ -325,9 +417,6 @@ soassertlocked(struct socket *so)
>       case PF_INET6:
>               NET_ASSERT_LOCKED();
>               break;
> -     case PF_UNIX:
> -             rw_assert_wrlock(&unp_lock);
> -             break;
>       default:
>               rw_assert_wrlock(&so->so_lock);
>               break;
> @@ -344,9 +433,6 @@ sosleep_nsec(struct socket *so, void *id
>       case PF_INET:
>       case PF_INET6:
>               ret = rwsleep_nsec(ident, &netlock, prio, wmesg, nsecs);
> -             break;
> -     case PF_UNIX:
> -             ret = rwsleep_nsec(ident, &unp_lock, prio, wmesg, nsecs);
>               break;
>       default:
>               ret = rwsleep_nsec(ident, &so->so_lock, prio, wmesg, nsecs);
> Index: sys/kern/uipc_syscalls.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_syscalls.c,v
> retrieving revision 1.194
> diff -u -p -r1.194 uipc_syscalls.c
> --- sys/kern/uipc_syscalls.c  24 Oct 2021 00:02:25 -0000      1.194
> +++ sys/kern/uipc_syscalls.c  22 Nov 2021 11:36:40 -0000
> @@ -246,7 +246,7 @@ doaccept(struct proc *p, int sock, struc
>       socklen_t namelen;
>       int error, s, tmpfd;
>       struct socket *head, *so;
> -     int cloexec, nflag;
> +     int cloexec, nflag, persocket;
>  
>       cloexec = (flags & SOCK_CLOEXEC) ? UF_EXCLOSE : 0;
>  
> @@ -269,16 +269,19 @@ doaccept(struct proc *p, int sock, struc
>  
>       head = headfp->f_data;
>       s = solock(head);
> +
> +     persocket = solock_persocket(head);
> +
>       if (isdnssocket(head) || (head->so_options & SO_ACCEPTCONN) == 0) {
>               error = EINVAL;
> -             goto out;
> +             goto out_unlock;
>       }
>       if ((headfp->f_flag & FNONBLOCK) && head->so_qlen == 0) {
>               if (head->so_state & SS_CANTRCVMORE)
>                       error = ECONNABORTED;
>               else
>                       error = EWOULDBLOCK;
> -             goto out;
> +             goto out_unlock;
>       }
>       while (head->so_qlen == 0 && head->so_error == 0) {
>               if (head->so_state & SS_CANTRCVMORE) {
> @@ -288,18 +291,23 @@ doaccept(struct proc *p, int sock, struc
>               error = sosleep_nsec(head, &head->so_timeo, PSOCK | PCATCH,
>                   "netcon", INFSLP);
>               if (error)
> -                     goto out;
> +                     goto out_unlock;
>       }
>       if (head->so_error) {
>               error = head->so_error;
>               head->so_error = 0;
> -             goto out;
> +             goto out_unlock;
>       }
>  
>       /*
>        * Do not sleep after we have taken the socket out of the queue.
>        */
> +
>       so = TAILQ_FIRST(&head->so_q);
> +
> +     if (persocket)
> +             solock(so);
> +
>       if (soqremque(so, 1) == 0)
>               panic("accept");
>  
> @@ -310,30 +318,53 @@ doaccept(struct proc *p, int sock, struc
>       /* connection has been removed from the listen queue */
>       KNOTE(&head->so_rcv.sb_sel.si_note, 0);
>  
> +     if (persocket)
> +             sounlock(head, s);
> +
>       fp->f_type = DTYPE_SOCKET;
>       fp->f_flag = FREAD | FWRITE | nflag;
>       fp->f_ops = &socketops;
>       fp->f_data = so;
> +
>       error = soaccept(so, nam);
> -out:
> -     sounlock(head, s);
> -     if (!error && name != NULL)
> +
> +     /*
> +      * It doesn't matter which socket to unlock when we
> +      * locked the whole layer.
> +      */
> +     sounlock(so, s);
> +
> +     if (error)
> +             goto out;
> +
> +     if (name != NULL) {
>               error = copyaddrout(p, nam, name, namelen, anamelen);
> -     if (!error) {
> -             fdplock(fdp);
> -             fdinsert(fdp, tmpfd, cloexec, fp);
> -             fdpunlock(fdp);
> -             FRELE(fp, p);
> -             *retval = tmpfd;
> -     } else {
> -             fdplock(fdp);
> -             fdremove(fdp, tmpfd);
> -             fdpunlock(fdp);
> -             closef(fp, p);
> +             if (error)
> +                     goto out;
>       }
>  
> +     fdplock(fdp);
> +     fdinsert(fdp, tmpfd, cloexec, fp);
> +     fdpunlock(fdp);
> +     FRELE(fp, p);
> +     *retval = tmpfd;
> +
>       m_freem(nam);
>       FRELE(headfp, p);
> +
> +     return 0;
> +
> +out_unlock:
> +     sounlock(head, s);
> +out:
> +     fdplock(fdp);
> +     fdremove(fdp, tmpfd);
> +     fdpunlock(fdp);
> +     closef(fp, p);
> +
> +     m_freem(nam);
> +     FRELE(headfp, p);
> +
>       return (error);
>  }
>  
> Index: sys/kern/uipc_usrreq.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_usrreq.c,v
> retrieving revision 1.158
> diff -u -p -r1.158 uipc_usrreq.c
> --- sys/kern/uipc_usrreq.c    17 Nov 2021 22:56:19 -0000      1.158
> +++ sys/kern/uipc_usrreq.c    22 Nov 2021 11:36:40 -0000
> @@ -52,25 +52,26 @@
>  #include <sys/pledge.h>
>  #include <sys/pool.h>
>  #include <sys/rwlock.h>
> -#include <sys/mutex.h>
>  #include <sys/sysctl.h>
>  #include <sys/lock.h>
> +#include <sys/refcnt.h>
>  
>  /*
>   * Locks used to protect global data and struct members:
>   *      I       immutable after creation
>   *      D       unp_df_lock
>   *      G       unp_gc_lock
> - *      U       unp_lock
> + *      M       unp_ino_mtx
>   *      R       unp_rights_mtx
>   *      a       atomic
> + *      s       socket lock
>   */
>  
> -struct rwlock unp_lock = RWLOCK_INITIALIZER("unplock");
>  struct rwlock unp_df_lock = RWLOCK_INITIALIZER("unpdflk");
>  struct rwlock unp_gc_lock = RWLOCK_INITIALIZER("unpgclk");
>  
>  struct mutex unp_rights_mtx = MUTEX_INITIALIZER(IPL_SOFTNET);
> +struct mutex unp_ino_mtx = MUTEX_INITIALIZER(IPL_SOFTNET);
>  
>  /*
>   * Stack of sets of files that were passed over a socket but were
> @@ -88,6 +89,9 @@ void        unp_discard(struct fdpass *, int);
>  void unp_mark(struct fdpass *, int);
>  void unp_scan(struct mbuf *, void (*)(struct fdpass *, int));
>  int  unp_nam2sun(struct mbuf *, struct sockaddr_un **, size_t *);
> +static inline void unp_ref(struct unpcb *);
> +static inline void unp_rele(struct unpcb *);
> +struct socket *unp_solock_peer(struct socket *);
>  
>  struct pool unpcb_pool;
>  struct task unp_gc_task = TASK_INITIALIZER(unp_gc, NULL);
> @@ -121,6 +125,53 @@ unp_init(void)
>           IPL_SOFTNET, 0, "unpcb", NULL);
>  }
>  
> +static inline void
> +unp_ref(struct unpcb *unp)
> +{
> +     refcnt_take(&unp->unp_refcnt);
> +}
> +
> +static inline void
> +unp_rele(struct unpcb *unp)
> +{
> +     refcnt_rele_wake(&unp->unp_refcnt);
> +}
> +
> +struct socket *
> +unp_solock_peer(struct socket *so)
> +{
> +     struct unpcb *unp, *unp2;
> +     struct socket *so2;
> +
> +     unp = so->so_pcb;
> +
> +again:
> +     if ((unp2 = unp->unp_conn) == NULL)
> +             return NULL;
> +
> +     so2 = unp2->unp_socket;
> +
> +     if (so < so2)
> +             solock(so2);
> +     else if (so > so2){
> +             unp_ref(unp2);
> +             sounlock(so, SL_LOCKED);
> +             solock(so2);
> +             solock(so);
> +
> +             /* Datagram socket could be reconnected due to re-lock. */
> +             if (unp->unp_conn != unp2) {
> +                     sounlock(so2, SL_LOCKED);
> +                     unp_rele(unp2);
> +                     goto again;
> +             }
> +
> +             unp_rele(unp2);
> +     }
> +
> +     return so2;
> +}
> +
>  void
>  uipc_setaddr(const struct unpcb *unp, struct mbuf *nam)
>  {
> @@ -195,6 +246,12 @@ uipc_usrreq(struct socket *so, int req, 
>                * if it was bound and we are still connected
>                * (our peer may have closed already!).
>                */
> +             /*
> +              * Don't need to lock `unp_conn'. The `unp_addr' is
> +              * immutable since we set it within unp_connect().
> +              * Both sockets are locked while we connecting them
> +              * so it's enough to hold lock on `unp'.
> +              */
>               uipc_setaddr(unp->unp_conn, nam);
>               break;
>  
> @@ -212,9 +269,8 @@ uipc_usrreq(struct socket *so, int req, 
>  
>               case SOCK_STREAM:
>               case SOCK_SEQPACKET:
> -                     if (unp->unp_conn == NULL)
> +                     if ((so2 = unp_solock_peer(so)) == NULL)
>                               break;
> -                     so2 = unp->unp_conn->unp_socket;
>                       /*
>                        * Adjust backpressure on sender
>                        * and wakeup any waiting to write.
> @@ -222,6 +278,7 @@ uipc_usrreq(struct socket *so, int req, 
>                       so2->so_snd.sb_mbcnt = so->so_rcv.sb_mbcnt;
>                       so2->so_snd.sb_cc = so->so_rcv.sb_cc;
>                       sowwakeup(so2);
> +                     sounlock(so2, SL_LOCKED);
>                       break;
>  
>               default:
> @@ -250,13 +307,16 @@ uipc_usrreq(struct socket *so, int req, 
>                               error = unp_connect(so, nam, p);
>                               if (error)
>                                       break;
> -                     } else {
> -                             if (unp->unp_conn == NULL) {
> +                     }
> +
> +                     if ((so2 = unp_solock_peer(so)) == NULL) {
> +                             if (nam != NULL)
> +                                     error = ECONNREFUSED;
> +                             else
>                                       error = ENOTCONN;
> -                                     break;
> -                             }
> +                             break;
>                       }
> -                     so2 = unp->unp_conn->unp_socket;
> +
>                       if (unp->unp_addr)
>                               from = mtod(unp->unp_addr, struct sockaddr *);
>                       else
> @@ -267,6 +327,10 @@ uipc_usrreq(struct socket *so, int req, 
>                               control = NULL;
>                       } else
>                               error = ENOBUFS;
> +
> +                     if (so2 != so)
> +                             sounlock(so2, SL_LOCKED);
> +
>                       if (nam)
>                               unp_disconnect(unp);
>                       break;
> @@ -278,11 +342,11 @@ uipc_usrreq(struct socket *so, int req, 
>                               error = EPIPE;
>                               break;
>                       }
> -                     if (unp->unp_conn == NULL) {
> +                     if ((so2 = unp_solock_peer(so)) == NULL) {
>                               error = ENOTCONN;
>                               break;
>                       }
> -                     so2 = unp->unp_conn->unp_socket;
> +
>                       /*
>                        * Send to paired receive port, and then raise
>                        * send buffer counts to maintain backpressure.
> @@ -304,6 +368,8 @@ uipc_usrreq(struct socket *so, int req, 
>                       so->so_snd.sb_cc = so2->so_rcv.sb_cc;
>                       if (so2->so_rcv.sb_cc > 0)
>                               sorwakeup(so2);
> +
> +                     sounlock(so2, SL_LOCKED);
>                       m = NULL;
>                       break;
>  
> @@ -317,12 +383,7 @@ uipc_usrreq(struct socket *so, int req, 
>  
>       case PRU_ABORT:
>               unp_detach(unp);
> -             /*
> -              * As long as `unp_lock' is taken before entering
> -              * uipc_usrreq() releasing it here would lead to a
> -              * double unlock.
> -              */
> -             sofree(so, SL_NOUNLOCK);
> +             sofree(so, SL_LOCKED);
>               break;
>  
>       case PRU_SENSE: {
> @@ -330,8 +391,10 @@ uipc_usrreq(struct socket *so, int req, 
>  
>               sb->st_blksize = so->so_snd.sb_hiwat;
>               sb->st_dev = NODEV;
> +             mtx_enter(&unp_ino_mtx);
>               if (unp->unp_ino == 0)
>                       unp->unp_ino = unp_ino++;
> +             mtx_leave(&unp_ino_mtx);
>               sb->st_atim.tv_sec =
>                   sb->st_mtim.tv_sec =
>                   sb->st_ctim.tv_sec = unp->unp_ctime.tv_sec;
> @@ -352,6 +415,12 @@ uipc_usrreq(struct socket *so, int req, 
>               break;
>  
>       case PRU_PEERADDR:
> +             /*
> +              * Don't need to lock `unp_conn'. The `unp_addr' is
> +              * immutable since we set it within unp_connect().
> +              * Both sockets are locked while we connecting them
> +              * so it's enough to hold lock on `unp'.
> +              */
>               uipc_setaddr(unp->unp_conn, nam);
>               break;
>  
> @@ -404,8 +473,6 @@ uipc_attach(struct socket *so, int proto
>       struct unpcb *unp;
>       int error;
>  
> -     rw_assert_wrlock(&unp_lock);
> -
>       if (so->so_pcb)
>               return EISCONN;
>       if (so->so_snd.sb_hiwat == 0 || so->so_rcv.sb_hiwat == 0) {
> @@ -432,6 +499,7 @@ uipc_attach(struct socket *so, int proto
>       unp = pool_get(&unpcb_pool, PR_NOWAIT|PR_ZERO);
>       if (unp == NULL)
>               return (ENOBUFS);
> +     refcnt_init(&unp->unp_refcnt);
>       unp->unp_socket = so;
>       so->so_pcb = unp;
>       getnanotime(&unp->unp_ctime);
> @@ -439,12 +507,6 @@ uipc_attach(struct socket *so, int proto
>       /*
>        * Enforce `unp_gc_lock' -> `solock()' lock order.
>        */
> -     /*
> -      * We also release the lock on listening socket and on our peer
> -      * socket when called from unp_connect(). This is safe. The
> -      * listening socket protected by vnode(9) lock. The peer socket
> -      * has 'UNP_CONNECTING' flag set.
> -      */
>       sounlock(so, SL_LOCKED);
>       rw_enter_write(&unp_gc_lock);
>       LIST_INSERT_HEAD(&unp_head, unp, unp_link);
> @@ -506,14 +568,13 @@ unp_detach(struct unpcb *unp)
>  {
>       struct socket *so = unp->unp_socket;
>       struct vnode *vp = unp->unp_vnode;
> -
> -     rw_assert_wrlock(&unp_lock);
> +     struct unpcb *unp2;
>  
>       unp->unp_vnode = NULL;
>  
>       /*
>        * Enforce `unp_gc_lock' -> `solock()' lock order.
> -      * Enforce `i_lock' -> `unp_lock' lock order.
> +      * Enforce `i_lock' -> `solock()' lock order.
>        */
>       sounlock(so, SL_LOCKED);
>  
> @@ -532,10 +593,47 @@ unp_detach(struct unpcb *unp)
>  
>       solock(so);
>  
> -     if (unp->unp_conn)
> +     if (unp->unp_conn != NULL) {
> +             /*
> +              * Datagram socket could be connected to itself.
> +              * Such socket will be disconnected here.
> +              */
>               unp_disconnect(unp);
> -     while (!SLIST_EMPTY(&unp->unp_refs))
> -             unp_drop(SLIST_FIRST(&unp->unp_refs), ECONNRESET);
> +     }
> +
> +     while ((unp2 = SLIST_FIRST(&unp->unp_refs)) != NULL) {
> +             struct socket *so2 = unp2->unp_socket;
> +
> +             if (so < so2)
> +                     solock(so2);
> +             else {
> +                     unp_ref(unp2);
> +                     sounlock(so, SL_LOCKED);
> +                     solock(so2);
> +                     solock(so);
> +
> +                     if (unp2->unp_conn != unp) {
> +                             /* `unp2' was disconnected due to re-lock. */
> +                             sounlock(so2, SL_LOCKED);
> +                             unp_rele(unp2);
> +                             continue;
> +                     }
> +
> +                     unp_rele(unp2);
> +             }
> +
> +             unp2->unp_conn = NULL;
> +             SLIST_REMOVE(&unp->unp_refs, unp2, unpcb, unp_nextref);
> +             so2->so_error = ECONNRESET;
> +             so2->so_state &= ~SS_ISCONNECTED;
> +
> +             sounlock(so2, SL_LOCKED);
> +     }
> +
> +     sounlock(so, SL_LOCKED);
> +     refcnt_finalize(&unp->unp_refcnt, "unpfinal");
> +     solock(so);
> +
>       soisdisconnected(so);
>       so->so_pcb = NULL;
>       m_freem(unp->unp_addr);
> @@ -675,24 +773,42 @@ unp_connect(struct socket *so, struct mb
>       }
>       if ((error = VOP_ACCESS(vp, VWRITE, p->p_ucred, p)) != 0)
>               goto put;
> -     solock(so);
>       so2 = vp->v_socket;
>       if (so2 == NULL) {
>               error = ECONNREFUSED;
> -             goto put_locked;
> +             goto put;
>       }
>       if (so->so_type != so2->so_type) {
>               error = EPROTOTYPE;
> -             goto put_locked;
> +             goto put;
>       }
> +
>       if (so->so_proto->pr_flags & PR_CONNREQUIRED) {
> +             solock(so2);
> +
>               if ((so2->so_options & SO_ACCEPTCONN) == 0 ||
>                   (so3 = sonewconn(so2, 0)) == NULL) {
>                       error = ECONNREFUSED;
> -                     goto put_locked;
>               }
> +
> +             sounlock(so2, SL_LOCKED);
> +
> +             if (error != 0)
> +                     goto put;
> +
> +             /*
> +              * Since `so2' is protected by vnode(9) lock, `so3'
> +              * can't be PRU_ABORT'ed here.
> +              */
> +             solock_pair(so, so3);
> +
>               unp2 = sotounpcb(so2);
>               unp3 = sotounpcb(so3);
> +
> +             /*
> +              * `unp_addr', `unp_connid' and 'UNP_FEIDSBIND' flag
> +              * are immutable since we set them in unp_bind().
> +              */
>               if (unp2->unp_addr)
>                       unp3->unp_addr =
>                           m_copym(unp2->unp_addr, 0, M_COPYALL, M_NOWAIT);
> @@ -700,15 +816,29 @@ unp_connect(struct socket *so, struct mb
>               unp3->unp_connid.gid = p->p_ucred->cr_gid;
>               unp3->unp_connid.pid = p->p_p->ps_pid;
>               unp3->unp_flags |= UNP_FEIDS;
> -             so2 = so3;
> +
>               if (unp2->unp_flags & UNP_FEIDSBIND) {
>                       unp->unp_connid = unp2->unp_connid;
>                       unp->unp_flags |= UNP_FEIDS;
>               }
> +
> +             so2 = so3;
> +     } else {
> +             if (so2 != so)
> +                     solock_pair(so, so2);
> +             else
> +                     solock(so);
>       }
> +
>       error = unp_connect2(so, so2);
> -put_locked:
> +
>       sounlock(so, SL_LOCKED);
> +
> +     /*
> +      * `so2' can't be PRU_ABORT'ed concurrently
> +      */
> +     if (so2 != so)
> +             sounlock(so2, SL_LOCKED);
>  put:
>       vput(vp);
>  unlock:
> @@ -732,7 +862,8 @@ unp_connect2(struct socket *so, struct s
>       struct unpcb *unp = sotounpcb(so);
>       struct unpcb *unp2;
>  
> -     rw_assert_wrlock(&unp_lock);
> +     soassertlocked(so);
> +     soassertlocked(so2);
>  
>       if (so2->so_type != so->so_type)
>               return (EPROTOTYPE);
> @@ -761,11 +892,15 @@ unp_connect2(struct socket *so, struct s
>  void
>  unp_disconnect(struct unpcb *unp)
>  {
> -     struct unpcb *unp2 = unp->unp_conn;
> +     struct socket *so2;
> +     struct unpcb *unp2;
>  
> -     if (unp2 == NULL)
> +     if ((so2 = unp_solock_peer(unp->unp_socket)) == NULL)
>               return;
> +
> +     unp2 = unp->unp_conn;
>       unp->unp_conn = NULL;
> +
>       switch (unp->unp_socket->so_type) {
>  
>       case SOCK_DGRAM:
> @@ -784,33 +919,29 @@ unp_disconnect(struct unpcb *unp)
>               soisdisconnected(unp2->unp_socket);
>               break;
>       }
> +
> +     if (so2 != unp->unp_socket)
> +             sounlock(so2, SL_LOCKED);
>  }
>  
>  void
>  unp_shutdown(struct unpcb *unp)
>  {
> -     struct socket *so;
> +     struct socket *so2;
>  
>       switch (unp->unp_socket->so_type) {
>       case SOCK_STREAM:
>       case SOCK_SEQPACKET:
> -             if (unp->unp_conn && (so = unp->unp_conn->unp_socket))
> -                     socantrcvmore(so);
> +             if ((so2 = unp_solock_peer(unp->unp_socket)) == NULL)
> +                     break;
> +             
> +             socantrcvmore(so2);
> +             sounlock(so2, SL_LOCKED);
> +
>               break;
>       default:
>               break;
>       }
> -}
> -
> -void
> -unp_drop(struct unpcb *unp, int errno)
> -{
> -     struct socket *so = unp->unp_socket;
> -
> -     rw_assert_wrlock(&unp_lock);
> -
> -     so->so_error = errno;
> -     unp_disconnect(unp);
>  }
>  
>  #ifdef notdef
> Index: sys/miscfs/fifofs/fifo_vnops.c
> ===================================================================
> RCS file: /cvs/src/sys/miscfs/fifofs/fifo_vnops.c,v
> retrieving revision 1.85
> diff -u -p -r1.85 fifo_vnops.c
> --- sys/miscfs/fifofs/fifo_vnops.c    24 Oct 2021 11:23:22 -0000      1.85
> +++ sys/miscfs/fifofs/fifo_vnops.c    22 Nov 2021 11:36:40 -0000
> @@ -156,7 +156,7 @@ fifo_open(void *v)
>       struct vnode *vp = ap->a_vp;
>       struct fifoinfo *fip;
>       struct socket *rso, *wso;
> -     int s, error;
> +     int error;
>  
>       if ((fip = vp->v_fifoinfo) == NULL) {
>               fip = malloc(sizeof(*fip), M_VNODE, M_WAITOK);
> @@ -182,18 +182,20 @@ fifo_open(void *v)
>                       return (error);
>               }
>               fip->fi_readers = fip->fi_writers = 0;
> -             s = solock(wso);
> +             solock(wso);
>               wso->so_state |= SS_CANTSENDMORE;
>               wso->so_snd.sb_lowat = PIPE_BUF;
> +             sounlock(wso, SL_LOCKED);
>       } else {
>               rso = fip->fi_readsock;
>               wso = fip->fi_writesock;
> -             s = solock(wso);
>       }
>       if (ap->a_mode & FREAD) {
>               fip->fi_readers++;
>               if (fip->fi_readers == 1) {
> +                     solock(wso);
>                       wso->so_state &= ~SS_CANTSENDMORE;
> +                     sounlock(wso, SL_LOCKED);
>                       if (fip->fi_writers > 0)
>                               wakeup(&fip->fi_writers);
>               }
> @@ -202,16 +204,16 @@ fifo_open(void *v)
>               fip->fi_writers++;
>               if ((ap->a_mode & O_NONBLOCK) && fip->fi_readers == 0) {
>                       error = ENXIO;
> -                     sounlock(wso, s);
>                       goto bad;
>               }
>               if (fip->fi_writers == 1) {
> +                     solock(rso);
>                       rso->so_state &= ~(SS_CANTRCVMORE|SS_ISDISCONNECTED);
> +                     sounlock(rso, SL_LOCKED);
>                       if (fip->fi_readers > 0)
>                               wakeup(&fip->fi_readers);
>               }
>       }
> -     sounlock(wso, s);
>       if ((ap->a_mode & O_NONBLOCK) == 0) {
>               if ((ap->a_mode & FREAD) && fip->fi_writers == 0) {
>                       VOP_UNLOCK(vp);
> @@ -327,17 +329,16 @@ fifo_poll(void *v)
>       struct socket *wso = ap->a_vp->v_fifoinfo->fi_writesock;
>       int events = 0;
>       int revents = 0;
> -     int s;
>  
>       /*
>        * FIFOs don't support out-of-band or high priority data.
>        */
> -     s = solock(rso);
>       if (ap->a_fflag & FREAD)
>               events |= ap->a_events & (POLLIN | POLLRDNORM);
>       if (ap->a_fflag & FWRITE)
>               events |= ap->a_events & (POLLOUT | POLLWRNORM);
>  
> +     solock_pair(rso, wso);
>       if (events & (POLLIN | POLLRDNORM)) {
>               if (soreadable(rso))
>                       revents |= events & (POLLIN | POLLRDNORM);
> @@ -362,7 +363,8 @@ fifo_poll(void *v)
>                       wso->so_snd.sb_flags |= SB_SEL;
>               }
>       }
> -     sounlock(rso, s);
> +     sounlock(rso, SL_LOCKED);
> +     sounlock(wso, SL_LOCKED);
>       return (revents);
>  }
>  
> Index: sys/sys/socketvar.h
> ===================================================================
> RCS file: /cvs/src/sys/sys/socketvar.h,v
> retrieving revision 1.101
> diff -u -p -r1.101 socketvar.h
> --- sys/sys/socketvar.h       6 Nov 2021 05:26:33 -0000       1.101
> +++ sys/sys/socketvar.h       22 Nov 2021 11:36:40 -0000
> @@ -38,6 +38,7 @@
>  #include <sys/task.h>
>  #include <sys/timeout.h>
>  #include <sys/rwlock.h>
> +#include <sys/refcnt.h>
>  
>  #ifndef      _SOCKLEN_T_DEFINED_
>  #define      _SOCKLEN_T_DEFINED_
> @@ -55,6 +56,7 @@ TAILQ_HEAD(soqhead, socket);
>  struct socket {
>       const struct protosw *so_proto; /* protocol handle */
>       struct rwlock so_lock;          /* this socket lock */
> +     struct refcnt so_refcnt;        /* references to this socket */
>       void    *so_pcb;                /* protocol control block */
>       u_int   so_state;               /* internal state flags SS_*, below */
>       short   so_type;                /* generic type, see socket.h */
> @@ -80,6 +82,7 @@ struct socket {
>       short   so_q0len;               /* partials on so_q0 */
>       short   so_qlen;                /* number of connections on so_q */
>       short   so_qlimit;              /* max number queued connections */
> +     u_long  so_newconn;             /* # of pending sonewconn() threads */
>       short   so_timeo;               /* connection timeout */
>       u_long  so_oobmark;             /* chars to oob mark */
>       u_int   so_error;               /* error affecting connection */
> @@ -150,6 +153,7 @@ struct socket {
>  #define      SS_CONNECTOUT           0x1000  /* connect, not accept, at this end */
>  #define      SS_ISSENDING            0x2000  /* hint for lower layer */
>  #define      SS_DNS                  0x4000  /* created using SOCK_DNS socket(2) */
> +#define      SS_NEWCONN_WAIT         0x8000  /* waiting sonewconn() relock */
>  
>  #ifdef _KERNEL
>  
> @@ -163,6 +167,18 @@ struct socket {
>  
>  void soassertlocked(struct socket *);
>  
> +static inline void
> +soref(struct socket *so)
> +{
> +     refcnt_take(&so->so_refcnt);
> +}
> +
> +static inline void
> +sorele(struct socket *so)
> +{
> +     refcnt_rele_wake(&so->so_refcnt);
> +}
> +
>  /*
>   * Macros for sockets and socket buffering.
>   */
> @@ -337,6 +353,8 @@ int       sockargs(struct mbuf **, const void 
>  
>  int  sosleep_nsec(struct socket *, void *, int, const char *, uint64_t);
>  int  solock(struct socket *);
> +int  solock_persocket(struct socket *);
> +void solock_pair(struct socket *, struct socket *);
>  void sounlock(struct socket *, int);
>  
>  int  sendit(struct proc *, int, struct msghdr *, int, register_t *);
> Index: sys/sys/unpcb.h
> ===================================================================
> RCS file: /cvs/src/sys/sys/unpcb.h,v
> retrieving revision 1.20
> diff -u -p -r1.20 unpcb.h
> --- sys/sys/unpcb.h   16 Nov 2021 08:56:20 -0000      1.20
> +++ sys/sys/unpcb.h   22 Nov 2021 11:36:40 -0000
> @@ -60,24 +60,26 @@
>   * Locks used to protect struct members:
>   *      I       immutable after creation
>   *      G       unp_gc_lock
> - *      U       unp_lock
>   *      a       atomic
> + *      s       socket lock
>   */
>  
>  
>  struct       unpcb {
> +     struct  refcnt unp_refcnt;      /* references to this pcb */
>       struct  socket *unp_socket;     /* [I] pointer back to socket */
> -     struct  vnode *unp_vnode;       /* [U] if associated with file */
> +     struct  vnode *unp_vnode;       /* [s] if associated with file */
>       struct  file *unp_file;         /* [a] backpointer for unp_gc() */
> -     struct  unpcb *unp_conn;        /* [U] control block of connected socket */
> -     ino_t   unp_ino;                /* [U] fake inode number */
> -     SLIST_HEAD(,unpcb) unp_refs;    /* [U] referencing socket linked list */
> -     SLIST_ENTRY(unpcb) unp_nextref; /* [U] link in unp_refs list */
> -     struct  mbuf *unp_addr;         /* [U] bound address of socket */
> +     struct  unpcb *unp_conn;        /* [s] control block of connected
> +                                             socket */
> +     ino_t   unp_ino;                /* [s] fake inode number */
> +     SLIST_HEAD(,unpcb) unp_refs;    /* [s] referencing socket linked list */
> +     SLIST_ENTRY(unpcb) unp_nextref; /* [s] link in unp_refs list */
> +     struct  mbuf *unp_addr;         /* [s] bound address of socket */
>       long    unp_msgcount;           /* [a] references from socket rcv buf */
> -     int     unp_flags;              /* [U] this unpcb contains peer eids */
> -     int     unp_gcflags;            /* [G] garbge collector flags */
> -     struct  sockpeercred unp_connid;/* [U] id of peer process */
> +     int     unp_flags;              /* [s] this unpcb contains peer eids */
> +     int     unp_gcflags;            /* [G] garbage collector flags */
> +     struct  sockpeercred unp_connid;/* [s] id of peer process */
>       struct  timespec unp_ctime;     /* [I] holds creation time */
>       LIST_ENTRY(unpcb) unp_link;     /* [G] link in per-AF list of sockets */
>  };
> @@ -116,7 +118,6 @@ int       unp_connect(struct socket *, struct 
>  int  unp_connect2(struct socket *, struct socket *);
>  void unp_detach(struct unpcb *);
>  void unp_disconnect(struct unpcb *);
> -void unp_drop(struct unpcb *, int);
>  void unp_gc(void *);
>  void unp_shutdown(struct unpcb *);
>  int  unp_externalize(struct mbuf *, socklen_t, int);
> 

Reply via email to