if you're looking at an ip header, it makes sense to do some checks to
make sure that the values and addresses make some sense. the canonical
versions of these checks are in the ipv4 and ipv6 input paths, which
makes sense. when bridge(4) is about to run packets through pf it makes
sure the ip headers are sane first, which i think also makes
sense. veb and tpmr don't do these checks before they run pf, but i
think they should. however, duplicating the code again doesn't appeal to
me.

this factors the ip checks out in the ip_input path, and uses that code
from bridge, veb, and tpmr.

this is mostly shuffling the deck chairs, but ipv6 is moved around a bit
more than ipv4, so some eyes and tests would be appreciated.

in the future i think the ipv6 code should do length checks like the
ipv4 code does too. this diff is big enough as it is though.

ok?

Index: net/if_bridge.c
===================================================================
RCS file: /cvs/src/sys/net/if_bridge.c,v
retrieving revision 1.354
diff -u -p -r1.354 if_bridge.c
--- net/if_bridge.c     5 Mar 2021 06:44:09 -0000       1.354
+++ net/if_bridge.c     31 May 2021 04:21:51 -0000
@@ -1674,61 +1674,13 @@ bridge_ip(struct ifnet *brifp, int dir, 
        switch (etype) {
 
        case ETHERTYPE_IP:
-               if (m->m_pkthdr.len < sizeof(struct ip))
-                       goto dropit;
-
-               /* Copy minimal header, and drop invalids */
-               if (m->m_len < sizeof(struct ip) &&
-                   (m = m_pullup(m, sizeof(struct ip))) == NULL) {
-                       ipstat_inc(ips_toosmall);
+               m = ipv4_check(ifp, m);
+               if (m == NULL)
                        return (NULL);
-               }
-               ip = mtod(m, struct ip *);
-
-               if (ip->ip_v != IPVERSION) {
-                       ipstat_inc(ips_badvers);
-                       goto dropit;
-               }
-
-               hlen = ip->ip_hl << 2;  /* get whole header length */
-               if (hlen < sizeof(struct ip)) {
-                       ipstat_inc(ips_badhlen);
-                       goto dropit;
-               }
-
-               if (hlen > m->m_len) {
-                       if ((m = m_pullup(m, hlen)) == NULL) {
-                               ipstat_inc(ips_badhlen);
-                               return (NULL);
-                       }
-                       ip = mtod(m, struct ip *);
-               }
-
-               if ((m->m_pkthdr.csum_flags & M_IPV4_CSUM_IN_OK) == 0) {
-                       if (m->m_pkthdr.csum_flags & M_IPV4_CSUM_IN_BAD) {
-                               ipstat_inc(ips_badsum);
-                               goto dropit;
-                       }
-
-                       ipstat_inc(ips_inswcsum);
-                       if (in_cksum(m, hlen) != 0) {
-                               ipstat_inc(ips_badsum);
-                               goto dropit;
-                       }
-               }
-
-               if (ntohs(ip->ip_len) < hlen)
-                       goto dropit;
 
-               if (m->m_pkthdr.len < ntohs(ip->ip_len))
-                       goto dropit;
-               if (m->m_pkthdr.len > ntohs(ip->ip_len)) {
-                       if (m->m_len == m->m_pkthdr.len) {
-                               m->m_len = ntohs(ip->ip_len);
-                               m->m_pkthdr.len = ntohs(ip->ip_len);
-                       } else
-                               m_adj(m, ntohs(ip->ip_len) - m->m_pkthdr.len);
-               }
+               /* ipv4_check() may have replaced the mbuf, refetch the header */
+               ip = mtod(m, struct ip *);
+               hlen = ip->ip_hl << 2;
 
 #ifdef IPSEC
                if ((brifp->if_flags & IFF_LINK2) == IFF_LINK2 &&
@@ -1772,23 +1724,10 @@ bridge_ip(struct ifnet *brifp, int dir, 
                break;
 
 #ifdef INET6
-       case ETHERTYPE_IPV6: {
-               struct ip6_hdr *ip6;
-
-               if (m->m_len < sizeof(struct ip6_hdr)) {
-                       if ((m = m_pullup(m, sizeof(struct ip6_hdr)))
-                           == NULL) {
-                               ip6stat_inc(ip6s_toosmall);
-                               return (NULL);
-                       }
-               }
-
-               ip6 = mtod(m, struct ip6_hdr *);
-
-               if ((ip6->ip6_vfc & IPV6_VERSION_MASK) != IPV6_VERSION) {
-                       ip6stat_inc(ip6s_badvers);
-                       goto dropit;
-               }
+       case ETHERTYPE_IPV6:
+               m = ipv6_check(ifp, m);
+               if (m == NULL)
+                       return (NULL);
 
 #ifdef IPSEC
                hlen = sizeof(struct ip6_hdr);
@@ -1819,7 +1758,6 @@ bridge_ip(struct ifnet *brifp, int dir, 
 #endif /* NPF > 0 */
 
                break;
-       }
 #endif /* INET6 */
 
        default:
Index: net/if_tpmr.c
===================================================================
RCS file: /cvs/src/sys/net/if_tpmr.c,v
retrieving revision 1.26
diff -u -p -r1.26 if_tpmr.c
--- net/if_tpmr.c       27 May 2021 03:46:15 -0000      1.26
+++ net/if_tpmr.c       31 May 2021 04:21:51 -0000
@@ -242,23 +242,45 @@ tpmr_8021q_filter(const struct mbuf *m, 
 }
 
 #if NPF > 0
+/*
+ * Address family specific bits used by tpmr_pf(): the af passed to
+ * pf_test(), the header check run on inbound packets before pf sees
+ * them, and the input routine used for packets diverted by pf.
+ */
+struct tpmr_pf_ip_family {
+       sa_family_t        af;
+       struct mbuf     *(*ip_check)(struct ifnet *, struct mbuf *);
+       void             (*ip_input)(struct ifnet *, struct mbuf *);
+};
+
+static const struct tpmr_pf_ip_family tpmr_pf_ipv4 = {
+       .af             = AF_INET,
+       .ip_check       = ipv4_check,
+       .ip_input       = ipv4_input,
+};
+
+#ifdef INET6
+static const struct tpmr_pf_ip_family tpmr_pf_ipv6 = {
+       .af             = AF_INET6,
+       .ip_check       = ipv6_check,
+       .ip_input       = ipv6_input,
+};
+#endif
+
 static struct mbuf *
 tpmr_pf(struct ifnet *ifp0, int dir, struct mbuf *m)
 {
        struct ether_header *eh, copy;
-       sa_family_t af = AF_UNSPEC;
-       void (*ip_input)(struct ifnet *, struct mbuf *) = NULL;
+       const struct tpmr_pf_ip_family *fam;
 
        eh = mtod(m, struct ether_header *);
        switch (ntohs(eh->ether_type)) {
        case ETHERTYPE_IP:
-               af = AF_INET;
-               ip_input = ipv4_input;
+               fam = &tpmr_pf_ipv4;
                break;
 #ifdef INET6
        case ETHERTYPE_IPV6:
-               af = AF_INET6;
-               ip_input = ipv6_input;
+               fam = &tpmr_pf_ipv6;
                break;
 #endif
        default:
@@ -268,7 +290,14 @@ tpmr_pf(struct ifnet *ifp0, int dir, str
        copy = *eh;
        m_adj(m, sizeof(*eh));
 
-       if (pf_test(af, dir, ifp0, &m) != PF_PASS) {
+       /* drop inbound packets with bad ip headers before pf sees them */
+       if (dir == PF_IN) {
+               m = (*fam->ip_check)(ifp0, m);
+               if (m == NULL)
+                       return (NULL);
+       }
+
+       if (pf_test(fam->af, dir, ifp0, &m) != PF_PASS) {
                m_freem(m);
                return (NULL);
        }
@@ -278,7 +307,7 @@ tpmr_pf(struct ifnet *ifp0, int dir, str
        if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
                pf_mbuf_unlink_state_key(m);
                pf_mbuf_unlink_inpcb(m);
-               (*ip_input)(ifp0, m);
+               (*fam->ip_input)(ifp0, m);
                return (NULL);
        }
 
Index: net/if_veb.c
===================================================================
RCS file: /cvs/src/sys/net/if_veb.c,v
retrieving revision 1.18
diff -u -p -r1.18 if_veb.c
--- net/if_veb.c        27 May 2021 03:43:23 -0000      1.18
+++ net/if_veb.c        31 May 2021 04:21:51 -0000
@@ -495,12 +495,36 @@ veb_rule_filter(struct veb_port *p, int 
 }
 
 #if NPF > 0
+/*
+ * Address family specific bits used by veb_pf(): the af passed to
+ * pf_test(), the header check run on inbound packets before pf sees
+ * them, and the input routine used for packets diverted by pf.
+ */
+struct veb_pf_ip_family {
+       sa_family_t        af;
+       struct mbuf     *(*ip_check)(struct ifnet *, struct mbuf *);
+       void             (*ip_input)(struct ifnet *, struct mbuf *);
+};
+
+static const struct veb_pf_ip_family veb_pf_ipv4 = {
+       .af             = AF_INET,
+       .ip_check       = ipv4_check,
+       .ip_input       = ipv4_input,
+};
+
+#ifdef INET6
+static const struct veb_pf_ip_family veb_pf_ipv6 = {
+       .af             = AF_INET6,
+       .ip_check       = ipv6_check,
+       .ip_input       = ipv6_input,
+};
+#endif
+
 static struct mbuf *
 veb_pf(struct ifnet *ifp0, int dir, struct mbuf *m)
 {
        struct ether_header *eh, copy;
-       sa_family_t af = AF_UNSPEC;
-       void (*ip_input)(struct ifnet *, struct mbuf *) = NULL;
+       const struct veb_pf_ip_family *fam;
 
        /*
         * pf runs on vport interfaces when they enter or leave the
@@ -515,13 +539,11 @@ veb_pf(struct ifnet *ifp0, int dir, stru
        eh = mtod(m, struct ether_header *);
        switch (ntohs(eh->ether_type)) {
        case ETHERTYPE_IP:
-               af = AF_INET;
-               ip_input = ipv4_input;
+               fam = &veb_pf_ipv4;
                break;
 #ifdef INET6
        case ETHERTYPE_IPV6:
-               af = AF_INET6;
-               ip_input = ipv6_input;
+               fam = &veb_pf_ipv6;
                break;
 #endif
        default:
@@ -531,7 +553,14 @@ veb_pf(struct ifnet *ifp0, int dir, stru
        copy = *eh;
        m_adj(m, sizeof(*eh));
 
-       if (pf_test(af, dir, ifp0, &m) != PF_PASS) {
+       /* drop inbound packets with bad ip headers before pf sees them */
+       if (dir == PF_IN) {
+               m = (*fam->ip_check)(ifp0, m);
+               if (m == NULL)
+                       return (NULL);
+       }
+
+       if (pf_test(fam->af, dir, ifp0, &m) != PF_PASS) {
                m_freem(m);
                return (NULL);
        }
@@ -541,7 +570,7 @@ veb_pf(struct ifnet *ifp0, int dir, stru
        if (dir == PF_IN && ISSET(m->m_pkthdr.pf.flags, PF_TAG_DIVERTED)) {
                pf_mbuf_unlink_state_key(m);
                pf_mbuf_unlink_inpcb(m);
-               (*ip_input)(ifp0, m);
+               (*fam->ip_input)(ifp0, m);
                return (NULL);
        }
 
Index: netinet/in.h
===================================================================
RCS file: /cvs/src/sys/netinet/in.h,v
retrieving revision 1.140
diff -u -p -r1.140 in.h
--- netinet/in.h        18 Jan 2021 12:22:40 -0000      1.140
+++ netinet/in.h        31 May 2021 04:21:51 -0000
@@ -772,6 +772,9 @@ struct ifaddr;
 struct in_ifaddr;
 
 void      ipv4_input(struct ifnet *, struct mbuf *);
+/* returns NULL if the header checks failed and the packet was dropped */
+struct mbuf *
+          ipv4_check(struct ifnet *, struct mbuf *);
 
 int       in_broadcast(struct in_addr, u_int);
 int       in_canforward(struct in_addr);
Index: netinet/ip_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/ip_input.c,v
retrieving revision 1.360
diff -u -p -r1.360 ip_input.c
--- netinet/ip_input.c  15 May 2021 08:07:20 -0000      1.360
+++ netinet/ip_input.c  31 May 2021 04:21:51 -0000
@@ -244,37 +244,41 @@ ipv4_input(struct ifnet *ifp, struct mbu
        KASSERT(nxt == IPPROTO_DONE);
 }
 
-int
-ip_input_if(struct mbuf **mp, int *offp, int nxt, int af, struct ifnet *ifp)
+/*
+ * Sanity check the IPv4 header at the start of m.  Returns the
+ * (possibly different) mbuf if the header is acceptable, or NULL if
+ * the packet was dropped, in which case the mbuf has been consumed.
+ */
+struct mbuf *
+ipv4_check(struct ifnet *ifp, struct mbuf *m)
 {
-       struct mbuf     *m = *mp;
-       struct rtentry  *rt = NULL;
-       struct ip       *ip;
+       struct ip *ip;
        int hlen, len;
-       in_addr_t pfrdr = 0;
-
-       KASSERT(*offp == 0);
 
-       ipstat_inc(ips_total);
-       if (m->m_len < sizeof (struct ip) &&
-           (m = *mp = m_pullup(m, sizeof (struct ip))) == NULL) {
-               ipstat_inc(ips_toosmall);
-               goto bad;
+       if (m->m_len < sizeof(*ip)) {
+               m = m_pullup(m, sizeof(*ip));
+               if (m == NULL) {
+                       ipstat_inc(ips_toosmall);
+                       return (NULL);
+               }
        }
+
        ip = mtod(m, struct ip *);
        if (ip->ip_v != IPVERSION) {
                ipstat_inc(ips_badvers);
                goto bad;
        }
+
        hlen = ip->ip_hl << 2;
-       if (hlen < sizeof(struct ip)) { /* minimum header length */
+       if (hlen < sizeof(*ip)) {       /* minimum header length */
                ipstat_inc(ips_badhlen);
                goto bad;
        }
        if (hlen > m->m_len) {
-               if ((m = *mp = m_pullup(m, hlen)) == NULL) {
+               m = m_pullup(m, hlen);
+               if (m == NULL) {
                        ipstat_inc(ips_badhlen);
-                       goto bad;
+                       return (NULL);
                }
                ip = mtod(m, struct ip *);
        }
@@ -329,6 +333,32 @@ ip_input_if(struct mbuf **mp, int *offp,
                } else
                        m_adj(m, len - m->m_pkthdr.len);
        }
+
+       return (m);
+bad:
+       m_freem(m);
+       return (NULL);
+}
+
+int
+ip_input_if(struct mbuf **mp, int *offp, int nxt, int af, struct ifnet *ifp)
+{
+       struct mbuf     *m;
+       struct rtentry  *rt = NULL;
+       struct ip       *ip;
+       int hlen;
+       in_addr_t pfrdr = 0;
+
+       KASSERT(*offp == 0);
+
+       ipstat_inc(ips_total);
+       m = *mp = ipv4_check(ifp, *mp);
+       if (m == NULL)
+               goto bad;
+
+       /* ipv4_check() may have replaced the mbuf, so refetch the header */
+       ip = mtod(m, struct ip *);
+       hlen = ip->ip_hl << 2;
 
 #if NCARP > 0
        if (carp_lsdrop(ifp, m, AF_INET, &ip->ip_src.s_addr,
Index: netinet6/in6.h
===================================================================
RCS file: /cvs/src/sys/netinet6/in6.h,v
retrieving revision 1.108
diff -u -p -r1.108 in6.h
--- netinet6/in6.h      10 Mar 2021 10:21:49 -0000      1.108
+++ netinet6/in6.h      31 May 2021 04:21:52 -0000
@@ -415,6 +415,9 @@ struct in6_ifaddr;
 struct ifnet;
 
 void   ipv6_input(struct ifnet *, struct mbuf *);
+/* returns NULL if the header checks failed and the packet was dropped */
+struct mbuf *
+       ipv6_check(struct ifnet *, struct mbuf *);
 
 int    in6_cksum(struct mbuf *, uint8_t, uint32_t, uint32_t);
 void   in6_proto_cksum_out(struct mbuf *, struct ifnet *);
Index: netinet6/ip6_input.c
===================================================================
RCS file: /cvs/src/sys/netinet6/ip6_input.c,v
retrieving revision 1.234
diff -u -p -r1.234 ip6_input.c
--- netinet6/ip6_input.c        17 May 2021 10:09:53 -0000      1.234
+++ netinet6/ip6_input.c        31 May 2021 04:21:52 -0000
@@ -172,28 +172,21 @@ ipv6_input(struct ifnet *ifp, struct mbu
        KASSERT(nxt == IPPROTO_DONE);
 }
 
-int
-ip6_input_if(struct mbuf **mp, int *offp, int nxt, int af, struct ifnet *ifp)
+/*
+ * Sanity check the IPv6 header at the start of m.  Returns the
+ * (possibly different) mbuf if the header is acceptable, or NULL if
+ * the packet was dropped, in which case the mbuf has been consumed.
+ */
+struct mbuf *
+ipv6_check(struct ifnet *ifp, struct mbuf *m)
 {
-       struct mbuf *m = *mp;
        struct ip6_hdr *ip6;
-       struct sockaddr_in6 sin6;
-       struct rtentry *rt = NULL;
-       int ours = 0;
-       u_int16_t src_scope, dst_scope;
-#if NPF > 0
-       struct in6_addr odst;
-#endif
-       int srcrt = 0;
-
-       KASSERT(*offp == 0);
 
-       ip6stat_inc(ip6s_total);
-
-       if (m->m_len < sizeof(struct ip6_hdr)) {
-               if ((m = *mp = m_pullup(m, sizeof(struct ip6_hdr))) == NULL) {
+       if (m->m_len < sizeof(*ip6)) {
+               m = m_pullup(m, sizeof(*ip6));
+               if (m == NULL) {
                        ip6stat_inc(ip6s_toosmall);
-                       goto bad;
+                       return (NULL);
                }
        }
 
@@ -204,13 +197,6 @@ ip6_input_if(struct mbuf **mp, int *offp
                goto bad;
        }
 
-#if NCARP > 0
-       if (carp_lsdrop(ifp, m, AF_INET6, ip6->ip6_src.s6_addr32,
-           ip6->ip6_dst.s6_addr32, (ip6->ip6_nxt == IPPROTO_ICMPV6 ? 0 : 1)))
-               goto bad;
-#endif
-       ip6stat_inc(ip6s_nxthist + ip6->ip6_nxt);
-
        /*
         * Check against address spoofing/corruption.
         */
@@ -225,8 +211,8 @@ ip6_input_if(struct mbuf **mp, int *offp
        if ((IN6_IS_ADDR_LOOPBACK(&ip6->ip6_src) ||
            IN6_IS_ADDR_LOOPBACK(&ip6->ip6_dst)) &&
            (ifp->if_flags & IFF_LOOPBACK) == 0) {
-                   ip6stat_inc(ip6s_badscope);
-                   goto bad;
+               ip6stat_inc(ip6s_badscope);
+               goto bad;
        }
        /* Drop packets if interface ID portion is already filled. */
        if (((IN6_IS_SCOPE_EMBED(&ip6->ip6_src) && ip6->ip6_src.s6_addr16[1]) ||
@@ -275,6 +261,44 @@ ip6_input_if(struct mbuf **mp, int *offp
                ip6stat_inc(ip6s_badscope);
                goto bad;
        }
+
+       return (m);
+bad:
+       m_freem(m);
+       return (NULL);
+}
+
+int
+ip6_input_if(struct mbuf **mp, int *offp, int nxt, int af, struct ifnet *ifp)
+{
+       struct mbuf *m;
+       struct ip6_hdr *ip6;
+       struct sockaddr_in6 sin6;
+       struct rtentry *rt = NULL;
+       int ours = 0;
+       u_int16_t src_scope, dst_scope;
+#if NPF > 0
+       struct in6_addr odst;
+#endif
+       int srcrt = 0;
+
+       KASSERT(*offp == 0);
+
+       ip6stat_inc(ip6s_total);
+
+       m = *mp = ipv6_check(ifp, *mp);
+       if (m == NULL)
+               goto bad;
+
+       /* ipv6_check() may have replaced the mbuf, so refetch the header */
+       ip6 = mtod(m, struct ip6_hdr *);
+
+#if NCARP > 0
+       if (carp_lsdrop(ifp, m, AF_INET6, ip6->ip6_src.s6_addr32,
+           ip6->ip6_dst.s6_addr32, (ip6->ip6_nxt == IPPROTO_ICMPV6 ? 0 : 1)))
+               goto bad;
+#endif
+       ip6stat_inc(ip6s_nxthist + ip6->ip6_nxt);
 
        /*
         * If the packet has been received on a loopback interface it

Reply via email to