The branch main has been updated by tuexen:

URL: https://cgit.FreeBSD.org/src/commit/?id=989453da0589b8dc5c1948fd81f986a37ea385eb

commit 989453da0589b8dc5c1948fd81f986a37ea385eb
Author:     Michael Tuexen <[email protected]>
AuthorDate: 2021-12-27 22:40:31 +0000
Commit:     Michael Tuexen <[email protected]>
CommitDate: 2021-12-27 22:40:31 +0000

    sctp: clean up the SCTP_MAXSEG socket option.
    
    This patch makes the handling of the SCTP_MAXSEG socket option
    compliant with RFC 6458 (SCTP socket API) and fixes an issue
    found by syzkaller.
    
    Reported by:    [email protected]
    MFC after:      3 days
---
 sys/netinet/sctp_constants.h |  2 -
 sys/netinet/sctp_output.c    | 93 ++++++++++++++++++++++++--------------------
 sys/netinet/sctp_output.h    |  2 +-
 sys/netinet/sctp_pcb.c       |  2 +-
 sys/netinet/sctp_usrreq.c    | 37 +++---------------
 sys/netinet/sctputil.c       |  2 +-
 6 files changed, 59 insertions(+), 79 deletions(-)

diff --git a/sys/netinet/sctp_constants.h b/sys/netinet/sctp_constants.h
index 1ff3f3918ef6..66f2cca5ab6d 100644
--- a/sys/netinet/sctp_constants.h
+++ b/sys/netinet/sctp_constants.h
@@ -673,8 +673,6 @@ __FBSDID("$FreeBSD$");
 /* amount peer is obligated to have in rwnd or I will abort */
 #define SCTP_MIN_RWND  1500
 
-#define SCTP_DEFAULT_MAXSEGMENT 65535
-
 #define SCTP_CHUNK_BUFFER_SIZE 512
 #define SCTP_PARAM_BUFFER_SIZE 512
 
diff --git a/sys/netinet/sctp_output.c b/sys/netinet/sctp_output.c
index 65767f9f73a9..f6597bc6cbdc 100644
--- a/sys/netinet/sctp_output.c
+++ b/sys/netinet/sctp_output.c
@@ -6217,43 +6217,48 @@ sctp_prune_prsctp(struct sctp_tcb *stcb,
        }                       /* if enabled in asoc */
 }
 
-int
-sctp_get_frag_point(struct sctp_tcb *stcb,
-    struct sctp_association *asoc)
+uint32_t
+sctp_get_frag_point(struct sctp_tcb *stcb)
 {
-       int siz, ovh;
+       struct sctp_association *asoc;
+       uint32_t frag_point, overhead;
 
-       /*
-        * For endpoints that have both v6 and v4 addresses we must reserve
-        * room for the ipv6 header, for those that are only dealing with V4
-        * we use a larger frag point.
-        */
+       asoc = &stcb->asoc;
+       /* Consider IP header and SCTP common header. */
        if (stcb->sctp_ep->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-               ovh = SCTP_MIN_OVERHEAD;
+               overhead = SCTP_MIN_OVERHEAD;
        } else {
-               ovh = SCTP_MIN_V4_OVERHEAD;
+               overhead = SCTP_MIN_V4_OVERHEAD;
        }
-       ovh += SCTP_DATA_CHUNK_OVERHEAD(stcb);
-       if (stcb->asoc.sctp_frag_point > asoc->smallest_mtu)
-               siz = asoc->smallest_mtu - ovh;
-       else
-               siz = (stcb->asoc.sctp_frag_point - ovh);
-       /*
-        * if (siz > (MCLBYTES-sizeof(struct sctp_data_chunk))) {
-        */
-       /* A data chunk MUST fit in a cluster */
-       /* siz = (MCLBYTES - sizeof(struct sctp_data_chunk)); */
-       /* } */
-
-       /* adjust for an AUTH chunk if DATA requires auth */
-       if (sctp_auth_is_required_chunk(SCTP_DATA, stcb->asoc.peer_auth_chunks))
-               siz -= sctp_get_auth_chunk_len(stcb->asoc.peer_hmac_id);
+       /* Consider the DATA or I-DATA chunk header (per idata_supported) and AUTH, if needed. */
+       if (asoc->idata_supported) {
+               overhead += sizeof(struct sctp_idata_chunk);
+               if (sctp_auth_is_required_chunk(SCTP_IDATA, asoc->peer_auth_chunks)) {
+                       overhead += sctp_get_auth_chunk_len(asoc->peer_hmac_id);
+               }
+       } else {
+               overhead += sizeof(struct sctp_data_chunk);
+               if (sctp_auth_is_required_chunk(SCTP_DATA, asoc->peer_auth_chunks)) {
+                       overhead += sctp_get_auth_chunk_len(asoc->peer_hmac_id);
+               }
+       }
+       KASSERT(overhead % 4 == 0,
+           ("overhead (%u) not a multiple of 4", overhead));
+       /* Consider padding: give up the MTU bytes beyond a multiple of 4. */
+       if (asoc->smallest_mtu % 4) {
+               overhead += (asoc->smallest_mtu % 4);
+       }
+       KASSERT(asoc->smallest_mtu > overhead,
+           ("Association MTU (%u) too small for overhead (%u)",
+           asoc->smallest_mtu, overhead));
 
-       if (siz % 4) {
-               /* make it an even word boundary please */
-               siz -= (siz % 4);
+       frag_point = asoc->smallest_mtu - overhead;
+       /* Honor MAXSEG socket option; 0 means no user-imposed limit. */
+       if ((asoc->sctp_frag_point > 0) &&
+           (asoc->sctp_frag_point < frag_point)) {
+               frag_point = asoc->sctp_frag_point;
        }
-       return (siz);
+       return (frag_point);
 }
 
 static void
@@ -6571,7 +6576,8 @@ sctp_med_chunk_output(struct sctp_inpcb *inp,
     int *num_out,
     int *reason_code,
     int control_only, int from_where,
-    struct timeval *now, int *now_filled, int frag_point, int so_locked);
+    struct timeval *now, int *now_filled,
+    uint32_t frag_point, int so_locked);
 
 static void
 sctp_sendall_iterator(struct sctp_inpcb *inp, struct sctp_tcb *stcb, void *ptr,
@@ -6740,13 +6746,13 @@ sctp_sendall_iterator(struct sctp_inpcb *inp, struct sctp_tcb *stcb, void *ptr,
        if (do_chunk_output)
                sctp_chunk_output(inp, stcb, SCTP_OUTPUT_FROM_USR_SEND, SCTP_SO_NOT_LOCKED);
        else if (added_control) {
-               int num_out, reason, now_filled = 0;
                struct timeval now;
-               int frag_point;
+               int num_out, reason, now_filled = 0;
 
-               frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
                (void)sctp_med_chunk_output(inp, stcb, &stcb->asoc, &num_out,
-                   &reason, 1, 1, &now, &now_filled, frag_point, SCTP_SO_NOT_LOCKED);
+                   &reason, 1, 1, &now, &now_filled,
+                   sctp_get_frag_point(stcb),
+                   SCTP_SO_NOT_LOCKED);
        }
 no_chunk_output:
        if (ret) {
@@ -7674,8 +7680,9 @@ out_of:
 }
 
 static void
-sctp_fill_outqueue(struct sctp_tcb *stcb, struct sctp_nets *net, int frag_point,
-    int eeor_mode, int *quit_now, int so_locked)
+sctp_fill_outqueue(struct sctp_tcb *stcb, struct sctp_nets *net,
+    uint32_t frag_point, int eeor_mode, int *quit_now,
+    int so_locked)
 {
        struct sctp_association *asoc;
        struct sctp_stream_out *strq;
@@ -7794,7 +7801,8 @@ sctp_med_chunk_output(struct sctp_inpcb *inp,
     int *num_out,
     int *reason_code,
     int control_only, int from_where,
-    struct timeval *now, int *now_filled, int frag_point, int so_locked)
+    struct timeval *now, int *now_filled,
+    uint32_t frag_point, int so_locked)
 {
        /**
         * Ok this is the generic chunk service queue. we must do the
@@ -9975,7 +9983,7 @@ sctp_chunk_output(struct sctp_inpcb *inp,
        struct timeval now;
        int now_filled = 0;
        int nagle_on;
-       int frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
+       uint32_t frag_point = sctp_get_frag_point(stcb);
        int un_sent = 0;
        int fr_done;
        unsigned int tot_frs = 0;
@@ -13663,16 +13671,17 @@ skip_out_eof:
                }
                sctp_chunk_output(inp, stcb, SCTP_OUTPUT_FROM_USR_SEND, SCTP_SO_LOCKED);
        } else if (some_on_control) {
-               int num_out, reason, frag_point;
+               int num_out, reason;
 
                /* Here we do control only */
                if (hold_tcblock == 0) {
                        hold_tcblock = 1;
                        SCTP_TCB_LOCK(stcb);
                }
-               frag_point = sctp_get_frag_point(stcb, &stcb->asoc);
                (void)sctp_med_chunk_output(inp, stcb, &stcb->asoc, &num_out,
-                   &reason, 1, 1, &now, &now_filled, frag_point, SCTP_SO_LOCKED);
+                   &reason, 1, 1, &now, &now_filled,
+                   sctp_get_frag_point(stcb),
+                   SCTP_SO_LOCKED);
        }
        NET_EPOCH_EXIT(et);
        SCTPDBG(SCTP_DEBUG_OUTPUT1, "USR Send complete qo:%d prw:%d unsent:%d tf:%d cooq:%d toqs:%d err:%d\n",
diff --git a/sys/netinet/sctp_output.h b/sys/netinet/sctp_output.h
index 7d2cdc4071d8..e6ee80c41f1a 100644
--- a/sys/netinet/sctp_output.h
+++ b/sys/netinet/sctp_output.h
@@ -117,7 +117,7 @@ void sctp_send_asconf(struct sctp_tcb *, struct sctp_nets *, int addr_locked);
 
 void sctp_send_asconf_ack(struct sctp_tcb *);
 
-int sctp_get_frag_point(struct sctp_tcb *, struct sctp_association *);
+uint32_t sctp_get_frag_point(struct sctp_tcb *);
 
 void sctp_toss_old_cookies(struct sctp_tcb *, struct sctp_association *);
 
diff --git a/sys/netinet/sctp_pcb.c b/sys/netinet/sctp_pcb.c
index b4a742c11629..7ad651ec377f 100644
--- a/sys/netinet/sctp_pcb.c
+++ b/sys/netinet/sctp_pcb.c
@@ -2422,7 +2422,7 @@ sctp_inpcb_alloc(struct socket *so, uint32_t vrf_id)
 #endif
        inp->sctp_associd_counter = 1;
        inp->partial_delivery_point = SCTP_SB_LIMIT_RCV(so) >> SCTP_PARTIAL_DELIVERY_SHIFT;
-       inp->sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
+       inp->sctp_frag_point = 0;
        inp->max_cwnd = 0;
        inp->sctp_cmt_on_off = SCTP_BASE_SYSCTL(sctp_cmt_on_off);
        inp->ecn_supported = (uint8_t)SCTP_BASE_SYSCTL(sctp_ecn_enable);
diff --git a/sys/netinet/sctp_usrreq.c b/sys/netinet/sctp_usrreq.c
index f218950feef9..bb84d3b7083f 100644
--- a/sys/netinet/sctp_usrreq.c
+++ b/sys/netinet/sctp_usrreq.c
@@ -2032,13 +2032,12 @@ flags_out:
        case SCTP_MAXSEG:
                {
                        struct sctp_assoc_value *av;
-                       int ovh;
 
                        SCTP_CHECK_AND_CAST(av, optval, struct sctp_assoc_value, *optsize);
                        SCTP_FIND_STCB(inp, stcb, av->assoc_id);
 
                        if (stcb) {
-                               av->assoc_value = sctp_get_frag_point(stcb, &stcb->asoc);
+                               av->assoc_value = stcb->asoc.sctp_frag_point;
                                SCTP_TCB_UNLOCK(stcb);
                        } else {
                                if ((inp->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) ||
@@ -2046,15 +2045,7 @@ flags_out:
                                    ((inp->sctp_flags & SCTP_PCB_FLAGS_UDPTYPE) &&
                                    (av->assoc_id == SCTP_FUTURE_ASSOC))) {
                                        SCTP_INP_RLOCK(inp);
-                                       if (inp->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-                                               ovh = SCTP_MED_OVERHEAD;
-                                       } else {
-                                               ovh = SCTP_MED_V4_OVERHEAD;
-                                       }
-                                       if (inp->sctp_frag_point >= SCTP_DEFAULT_MAXSEGMENT)
-                                               av->assoc_value = 0;
-                                       else
-                                               av->assoc_value = inp->sctp_frag_point - ovh;
+                                       av->assoc_value = inp->sctp_frag_point;
                                        SCTP_INP_RUNLOCK(inp);
                                } else {
                                        SCTP_LTRACE_ERR_RET(inp, NULL, NULL, SCTP_FROM_SCTP_USRREQ, EINVAL);
@@ -2623,7 +2614,7 @@ flags_out:
                            stcb->asoc.cnt_on_all_streams);
                        sstat->sstat_instrms = stcb->asoc.streamincnt;
                        sstat->sstat_outstrms = stcb->asoc.streamoutcnt;
-                       sstat->sstat_fragmentation_point = sctp_get_frag_point(stcb, &stcb->asoc);
+                       sstat->sstat_fragmentation_point = sctp_get_frag_point(stcb);
                        net = stcb->asoc.primary_destination;
                        if (net != NULL) {
                                memcpy(&sstat->sstat_primary.spinfo_address,
@@ -4977,22 +4968,12 @@ sctp_setopt(struct socket *so, int optname, void *optval, size_t optsize,
        case SCTP_MAXSEG:
                {
                        struct sctp_assoc_value *av;
-                       int ovh;
 
                        SCTP_CHECK_AND_CAST(av, optval, struct sctp_assoc_value, optsize);
                        SCTP_FIND_STCB(inp, stcb, av->assoc_id);
 
-                       if (inp->sctp_flags & SCTP_PCB_FLAGS_BOUND_V6) {
-                               ovh = SCTP_MED_OVERHEAD;
-                       } else {
-                               ovh = SCTP_MED_V4_OVERHEAD;
-                       }
                        if (stcb) {
-                               if (av->assoc_value) {
-                                       stcb->asoc.sctp_frag_point = (av->assoc_value + ovh);
-                               } else {
-                                       stcb->asoc.sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
-                               }
+                               stcb->asoc.sctp_frag_point = av->assoc_value;
                                SCTP_TCB_UNLOCK(stcb);
                        } else {
                                if ((inp->sctp_flags & SCTP_PCB_FLAGS_TCPTYPE) ||
@@ -5000,15 +4981,7 @@ sctp_setopt(struct socket *so, int optname, void *optval, size_t optsize,
                                    ((inp->sctp_flags & SCTP_PCB_FLAGS_UDPTYPE) &&
                                    (av->assoc_id == SCTP_FUTURE_ASSOC))) {
                                        SCTP_INP_WLOCK(inp);
-                                       /*
-                                        * FIXME MT: I think this is not in
-                                        * tune with the API ID
-                                        */
-                                       if (av->assoc_value) {
-                                               inp->sctp_frag_point = (av->assoc_value + ovh);
-                                       } else {
-                                               inp->sctp_frag_point = SCTP_DEFAULT_MAXSEGMENT;
-                                       }
+                                       inp->sctp_frag_point = av->assoc_value;
                                        SCTP_INP_WUNLOCK(inp);
                                } else {
                                        SCTP_LTRACE_ERR_RET(inp, NULL, NULL, SCTP_FROM_SCTP_USRREQ, EINVAL);
diff --git a/sys/netinet/sctputil.c b/sys/netinet/sctputil.c
index 6c58ad47f274..df3768ca2a35 100644
--- a/sys/netinet/sctputil.c
+++ b/sys/netinet/sctputil.c
@@ -1248,7 +1248,7 @@ sctp_init_asoc(struct sctp_inpcb *inp, struct sctp_tcb *stcb,
        asoc->my_rwnd = max(SCTP_SB_LIMIT_RCV(inp->sctp_socket), SCTP_MINIMAL_RWND);
        asoc->peers_rwnd = SCTP_SB_LIMIT_RCV(inp->sctp_socket);
 
-       asoc->smallest_mtu = inp->sctp_frag_point;
+       asoc->smallest_mtu = 0;
        asoc->minrto = inp->sctp_ep.sctp_minrto;
        asoc->maxrto = inp->sctp_ep.sctp_maxrto;
 

Reply via email to