I'd like to enforce the following "lock" ordering: the socket lock must
always be held when calling sblock().

This would allow me to protect `so_state' in sosend() when setting the
SS_ISSENDING bit.

The diff below implements that.  Since sblock() now takes the socket as
an argument, it also gets rid of sbsleep() and its `lock' argument, and
uses sosleep() instead.

ok?

Index: sys/socketvar.h
===================================================================
RCS file: /cvs/src/sys/sys/socketvar.h,v
retrieving revision 1.70
diff -u -p -r1.70 socketvar.h
--- sys/socketvar.h     26 Jun 2017 09:32:32 -0000      1.70
+++ sys/socketvar.h     26 Jun 2017 14:01:31 -0000
@@ -239,7 +239,7 @@ struct rwlock;
  * Unless SB_NOINTR is set on sockbuf, sleep is interruptible.
  * Returns error without lock if sleep is interrupted.
  */
-int sblock(struct sockbuf *, int, struct rwlock *);
+int sblock(struct socket *, struct sockbuf *, int);
 
 /* release lock on sockbuf sb */
 void sbunlock(struct sockbuf *);
Index: kern/uipc_socket.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket.c,v
retrieving revision 1.189
diff -u -p -r1.189 uipc_socket.c
--- kern/uipc_socket.c  26 Jun 2017 09:32:31 -0000      1.189
+++ kern/uipc_socket.c  26 Jun 2017 14:10:03 -0000
@@ -418,14 +418,14 @@ sosend(struct socket *so, struct mbuf *a
                            (sizeof(struct fdpass) / sizeof(int)));
        }
 
-#define        snderr(errno)   { error = errno; sounlock(s); goto release; }
+#define        snderr(errno)   { error = errno; goto release; }
 
 restart:
-       if ((error = sblock(&so->so_snd, SBLOCKWAIT(flags), NULL)) != 0)
+       s = solock(so);
+       if ((error = sblock(so, &so->so_snd, SBLOCKWAIT(flags))) != 0)
                goto out;
        so->so_state |= SS_ISSENDING;
        do {
-               s = solock(so);
                if (so->so_state & SS_CANTSENDMORE)
                        snderr(EPIPE);
                if (so->so_error) {
@@ -455,12 +455,10 @@ restart:
                        sbunlock(&so->so_snd);
                        error = sbwait(so, &so->so_snd);
                        so->so_state &= ~SS_ISSENDING;
-                       sounlock(s);
                        if (error)
                                goto out;
                        goto restart;
                }
-               sounlock(s);
                space -= clen;
                do {
                        if (uio == NULL) {
@@ -471,8 +469,9 @@ restart:
                                if (flags & MSG_EOR)
                                        top->m_flags |= M_EOR;
                        } else {
-                               error = m_getuio(&top, atomic,
-                                   space, uio);
+                               sounlock(s);
+                               error = m_getuio(&top, atomic, space, uio);
+                               s = solock(so);
                                if (error)
                                        goto release;
                                space -= top->m_pkthdr.len;
@@ -480,7 +479,6 @@ restart:
                                if (flags & MSG_EOR)
                                        top->m_flags |= M_EOR;
                        }
-                       s = solock(so);
                        if (resid == 0)
                                so->so_state &= ~SS_ISSENDING;
                        if (top && so->so_options & SO_ZEROIZE)
@@ -488,7 +486,6 @@ restart:
                        error = (*so->so_proto->pr_usrreq)(so,
                            (flags & MSG_OOB) ? PRU_SENDOOB : PRU_SEND,
                            top, addr, control, curproc);
-                       sounlock(s);
                        clen = 0;
                        control = NULL;
                        top = NULL;
@@ -501,6 +498,7 @@ release:
        so->so_state &= ~SS_ISSENDING;
        sbunlock(&so->so_snd);
 out:
+       sounlock(s);
        m_freem(top);
        m_freem(control);
        return (error);
@@ -670,9 +668,11 @@ bad:
                *mp = NULL;
 
 restart:
-       if ((error = sblock(&so->so_rcv, SBLOCKWAIT(flags), NULL)) != 0)
-               return (error);
        s = solock(so);
+       if ((error = sblock(so, &so->so_rcv, SBLOCKWAIT(flags))) != 0) {
+               sounlock(s);
+               return (error);
+       }
 
        m = so->so_rcv.sb_mb;
 #ifdef SOCKET_SPLICE
@@ -1040,13 +1040,10 @@ sorflush(struct socket *so)
 {
        struct sockbuf *sb = &so->so_rcv;
        struct protosw *pr = so->so_proto;
-       sa_family_t af = pr->pr_domain->dom_family;
        struct socket aso;
 
        sb->sb_flags |= SB_NOINTR;
-       sblock(sb, M_WAITOK,
-           (af != PF_LOCAL && af != PF_ROUTE && af != PF_KEY) ?
-           &netlock : NULL);
+       sblock(so, sb, M_WAITOK);
        socantrcvmore(so);
        sbunlock(sb);
        aso.so_proto = pr;
@@ -1094,11 +1091,13 @@ sosplice(struct socket *so, int fd, off_
 
        /* If no fd is given, unsplice by removing existing link. */
        if (fd < 0) {
+               s = solock(so);
                /* Lock receive buffer. */
-               if ((error = sblock(&so->so_rcv,
-                   (so->so_state & SS_NBIO) ? M_NOWAIT : M_WAITOK, NULL)) != 0)
+               if ((error = sblock(so, &so->so_rcv,
+                   (so->so_state & SS_NBIO) ? M_NOWAIT : M_WAITOK)) != 0) {
+                       sounlock(s);
                        return (error);
-               s = solock(so);
+               }
                if (so->so_sp->ssp_socket)
                        sounsplice(so, so->so_sp->ssp_socket, 1);
                sounlock(s);
@@ -1119,18 +1118,20 @@ sosplice(struct socket *so, int fd, off_
        if (sosp->so_sp == NULL)
                sosp->so_sp = pool_get(&sosplice_pool, PR_WAITOK | PR_ZERO);
 
+       s = solock(so);
        /* Lock both receive and send buffer. */
-       if ((error = sblock(&so->so_rcv,
-           (so->so_state & SS_NBIO) ? M_NOWAIT : M_WAITOK, NULL)) != 0) {
+       if ((error = sblock(so, &so->so_rcv,
+           (so->so_state & SS_NBIO) ? M_NOWAIT : M_WAITOK)) != 0) {
+               sounlock(s);
                FRELE(fp, curproc);
                return (error);
        }
-       if ((error = sblock(&sosp->so_snd, M_WAITOK, NULL)) != 0) {
+       if ((error = sblock(so, &sosp->so_snd, M_WAITOK)) != 0) {
                sbunlock(&so->so_rcv);
+               sounlock(s);
                FRELE(fp, curproc);
                return (error);
        }
-       s = solock(so);
 
        if (so->so_sp->ssp_socket || sosp->so_sp->ssp_soback) {
                error = EBUSY;
Index: kern/uipc_socket2.c
===================================================================
RCS file: /cvs/src/sys/kern/uipc_socket2.c,v
retrieving revision 1.79
diff -u -p -r1.79 uipc_socket2.c
--- kern/uipc_socket2.c 26 Jun 2017 09:32:31 -0000      1.79
+++ kern/uipc_socket2.c 26 Jun 2017 14:01:31 -0000
@@ -54,8 +54,6 @@ u_long        sb_max = SB_MAX;                /* patchable */
 extern struct pool mclpools[];
 extern struct pool mbpool;
 
-int sbsleep(struct sockbuf *, struct rwlock *);
-
 /*
  * Procedures to manipulate state flags of socket
  * and do appropriate wakeups.  Normal sequence from the
@@ -332,24 +330,12 @@ sbwait(struct socket *so, struct sockbuf
 }
 
 int
-sbsleep(struct sockbuf *sb, struct rwlock *lock)
+sblock(struct socket *so, struct sockbuf *sb, int wait)
 {
        int error, prio = (sb->sb_flags & SB_NOINTR) ? PSOCK : PSOCK | PCATCH;
 
-       if (lock != NULL)
-               error = rwsleep(&sb->sb_flags, lock, prio, "netlck", 0);
-       else
-               error = tsleep(&sb->sb_flags, prio, "netlck", 0);
-
-       return (error);
-}
-
-int
-sblock(struct sockbuf *sb, int wait, struct rwlock *lock)
-{
-       int error;
-
        KERNEL_ASSERT_LOCKED();
+       soassertlocked(so);
 
        if ((sb->sb_flags & SB_LOCK) == 0) {
                sb->sb_flags |= SB_LOCK;
@@ -360,7 +346,7 @@ sblock(struct sockbuf *sb, int wait, str
 
        while (sb->sb_flags & SB_LOCK) {
                sb->sb_flags |= SB_WANT;
-               error = sbsleep(sb, lock);
+               error = sosleep(so, &sb->sb_flags, prio, "netlck", 0);
                if (error)
                        return (error);
        }

Reply via email to