Hi,

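Convert the ip_input(), ip_ours() and ip_deliver() functions to the
pr_input parameter passing (struct mbuf **, int *offp, int nxt, int af)
and protocol return style (return the next protocol or IPPROTO_DONE).
Reset *mp to NULL in a few places, so that any use of the mbuf after
it has been freed fails early instead of going unnoticed.  Rename
ipv4_input() to ip_input().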

The goal is to prepare the code so that both mpi@'s and bluhm@'s
diffs apply.

ok?

bluhm

Index: netinet/ip_input.c
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/ip_input.c,v
retrieving revision 1.306
diff -u -p -r1.306 ip_input.c
--- netinet/ip_input.c  28 May 2017 12:22:54 -0000      1.306
+++ netinet/ip_input.c  29 May 2017 21:38:51 -0000
@@ -126,7 +126,7 @@ int ip_sysctl_ipstat(void *, size_t *, v
 
 static struct mbuf_queue       ipsend_mq;
 
-void   ip_ours(struct mbuf *);
+int    ip_ours(struct mbuf **, int *, int, int);
 int    ip_dooptions(struct mbuf *, struct ifnet *);
 int    in_ouraddr(struct mbuf *, struct ifnet *, struct rtentry **);
 
@@ -211,6 +211,7 @@ void
 ipintr(void)
 {
        struct mbuf *m;
+       int off;
 
        /*
         * Get next datagram off input queue and get IP header
@@ -221,7 +222,8 @@ ipintr(void)
                if ((m->m_flags & M_PKTHDR) == 0)
                        panic("ipintr no HDR");
 #endif
-               ipv4_input(m);
+               off = 0;
+               ip_input(&m, &off, IPPROTO_IPV4, AF_UNSPEC);
        }
 }
 
@@ -230,39 +232,42 @@ ipintr(void)
  *
  * Checksum and byte swap header.  Process options. Forward or deliver.
  */
-void
-ipv4_input(struct mbuf *m)
+int
+ip_input(struct mbuf **mp, int *offp, int nxt, int af)
 {
+       struct mbuf     *m = *mp;
        struct ifnet    *ifp;
        struct rtentry  *rt = NULL;
        struct ip       *ip;
        int hlen, len;
        in_addr_t pfrdr = 0;
 
+       KASSERT(*offp == 0);
+
        ifp = if_get(m->m_pkthdr.ph_ifidx);
        if (ifp == NULL)
-               goto bad;
+               goto done;
 
        ipstat_inc(ips_total);
        if (m->m_len < sizeof (struct ip) &&
-           (m = m_pullup(m, sizeof (struct ip))) == NULL) {
+           (m = *mp = m_pullup(m, sizeof (struct ip))) == NULL) {
                ipstat_inc(ips_toosmall);
-               goto out;
+               goto done;
        }
        ip = mtod(m, struct ip *);
        if (ip->ip_v != IPVERSION) {
                ipstat_inc(ips_badvers);
-               goto bad;
+               goto done;
        }
        hlen = ip->ip_hl << 2;
        if (hlen < sizeof(struct ip)) { /* minimum header length */
                ipstat_inc(ips_badhlen);
-               goto bad;
+               goto done;
        }
        if (hlen > m->m_len) {
-               if ((m = m_pullup(m, hlen)) == NULL) {
+               if ((m = *mp = m_pullup(m, hlen)) == NULL) {
                        ipstat_inc(ips_badhlen);
-                       goto out;
+                       goto done;
                }
                ip = mtod(m, struct ip *);
        }
@@ -272,20 +277,20 @@ ipv4_input(struct mbuf *m)
            (ntohl(ip->ip_src.s_addr) >> IN_CLASSA_NSHIFT) == IN_LOOPBACKNET) {
                if ((ifp->if_flags & IFF_LOOPBACK) == 0) {
                        ipstat_inc(ips_badaddr);
-                       goto bad;
+                       goto done;
                }
        }
 
        if ((m->m_pkthdr.csum_flags & M_IPV4_CSUM_IN_OK) == 0) {
                if (m->m_pkthdr.csum_flags & M_IPV4_CSUM_IN_BAD) {
                        ipstat_inc(ips_badsum);
-                       goto bad;
+                       goto done;
                }
 
                ipstat_inc(ips_inswcsum);
                if (in_cksum(m, hlen) != 0) {
                        ipstat_inc(ips_badsum);
-                       goto bad;
+                       goto done;
                }
        }
 
@@ -297,7 +302,7 @@ ipv4_input(struct mbuf *m)
         */
        if (len < hlen) {
                ipstat_inc(ips_badlen);
-               goto bad;
+               goto done;
        }
 
        /*
@@ -308,7 +313,7 @@ ipv4_input(struct mbuf *m)
         */
        if (m->m_pkthdr.len < len) {
                ipstat_inc(ips_tooshort);
-               goto bad;
+               goto done;
        }
        if (m->m_pkthdr.len > len) {
                if (m->m_len == m->m_pkthdr.len) {
@@ -321,7 +326,7 @@ ipv4_input(struct mbuf *m)
 #if NCARP > 0
        if (ifp->if_type == IFT_CARP && ip->ip_p != IPPROTO_ICMP &&
            carp_lsdrop(m, AF_INET, &ip->ip_src.s_addr, &ip->ip_dst.s_addr))
-               goto bad;
+               goto done;
 #endif
 
 #if NPF > 0
@@ -329,10 +334,11 @@ ipv4_input(struct mbuf *m)
         * Packet filter
         */
        pfrdr = ip->ip_dst.s_addr;
-       if (pf_test(AF_INET, PF_IN, ifp, &m) != PF_PASS)
-               goto bad;
+       if (pf_test(AF_INET, PF_IN, ifp, mp) != PF_PASS)
+               goto done;
+       m = *mp;
        if (m == NULL)
-               goto out;
+               goto done;
 
        ip = mtod(m, struct ip *);
        hlen = ip->ip_hl << 2;
@@ -346,17 +352,18 @@ ipv4_input(struct mbuf *m)
         * to be sent and the original packet to be freed).
         */
        if (hlen > sizeof (struct ip) && ip_dooptions(m, ifp)) {
-               goto out;
+               m = *mp = NULL;
+               goto done;
        }
 
        if (ip->ip_dst.s_addr == INADDR_BROADCAST ||
            ip->ip_dst.s_addr == INADDR_ANY) {
-               ip_ours(m);
+               nxt = ip_ours(mp, offp, nxt, af);
                goto out;
        }
 
        if (in_ouraddr(m, ifp, &rt)) {
-               ip_ours(m);
+               nxt = ip_ours(mp, offp, nxt, af);
                goto out;
        }
 
@@ -373,9 +380,9 @@ ipv4_input(struct mbuf *m)
                        int rv;
 
                        if (m->m_flags & M_EXT) {
-                               if ((m = m_pullup(m, hlen)) == NULL) {
+                               if ((m = *mp = m_pullup(m, hlen)) == NULL) {
                                        ipstat_inc(ips_toosmall);
-                                       goto out;
+                                       goto done;
                                }
                                ip = mtod(m, struct ip *);
                        }
@@ -396,7 +403,7 @@ ipv4_input(struct mbuf *m)
                        KERNEL_UNLOCK();
                        if (rv != 0) {
                                ipstat_inc(ips_cantforward);
-                               goto bad;
+                               goto done;
                        }
 
                        /*
@@ -405,7 +412,7 @@ ipv4_input(struct mbuf *m)
                         * host belongs to their destination groups.
                         */
                        if (ip->ip_p == IPPROTO_IGMP) {
-                               ip_ours(m);
+                               nxt = ip_ours(mp, offp, nxt, af);
                                goto out;
                        }
                        ipstat_inc(ips_forward);
@@ -419,23 +426,23 @@ ipv4_input(struct mbuf *m)
                        ipstat_inc(ips_notmember);
                        if (!IN_LOCAL_GROUP(ip->ip_dst.s_addr))
                                ipstat_inc(ips_cantforward);
-                       goto bad;
+                       goto done;
                }
-               ip_ours(m);
+               nxt = ip_ours(mp, offp, nxt, af);
                goto out;
        }
 
 #if NCARP > 0
        if (ifp->if_type == IFT_CARP && ip->ip_p == IPPROTO_ICMP &&
            carp_lsdrop(m, AF_INET, &ip->ip_src.s_addr, &ip->ip_dst.s_addr))
-               goto bad;
+               goto done;
 #endif
        /*
         * Not for us; forward if possible and desirable.
         */
        if (ipforwarding == 0) {
                ipstat_inc(ips_cantforward);
-               goto bad;
+               goto done;
        }
 #ifdef IPSEC
        if (ipsec_in_use) {
@@ -446,7 +453,7 @@ ipv4_input(struct mbuf *m)
                KERNEL_UNLOCK();
                if (rv != 0) {
                        ipstat_inc(ips_cantforward);
-                       goto bad;
+                       goto done;
                }
                /*
                 * Fall through, forward packet. Outbound IPsec policy
@@ -456,13 +463,17 @@ ipv4_input(struct mbuf *m)
 #endif /* IPSEC */
 
        ip_forward(m, ifp, rt, pfrdr);
+       *mp = NULL;
        if_put(ifp);
-       return;
-bad:
-       m_freem(m);
-out:
+       return IPPROTO_DONE;
+ done:
+       nxt = IPPROTO_DONE;
+       m_freem(*mp);
+       *mp = NULL;
+ out:
        rtfree(rt);
        if_put(ifp);
+       return nxt;
 }
 
 /*
@@ -470,9 +481,10 @@ out:
  *
  * If fragmented try to reassemble.  Pass to next level.
  */
-void
-ip_ours(struct mbuf *m)
+int
+ip_ours(struct mbuf **mp, int *offp, int nxt, int af)
 {
+       struct mbuf *m = *mp;
        struct ip *ip = mtod(m, struct ip *);
        struct ipq *fp;
        struct ipqent *ipqe;
@@ -489,9 +501,9 @@ ip_ours(struct mbuf *m)
         */
        if (ip->ip_off &~ htons(IP_DF | IP_RF)) {
                if (m->m_flags & M_EXT) {               /* XXX */
-                       if ((m = m_pullup(m, hlen)) == NULL) {
+                       if ((m = *mp = m_pullup(m, hlen)) == NULL) {
                                ipstat_inc(ips_toosmall);
-                               return;
+                               goto done;
                        }
                        ip = mtod(m, struct ip *);
                }
@@ -524,7 +536,7 @@ found:
                        if (ntohs(ip->ip_len) == 0 ||
                            (ntohs(ip->ip_len) & 0x7) != 0) {
                                ipstat_inc(ips_badfrags);
-                               goto bad;
+                               goto done;
                        }
                }
                ip->ip_off = htons(ntohs(ip->ip_off) << 3);
@@ -539,22 +551,21 @@ found:
                        if (ip_frags + 1 > ip_maxqueue) {
                                ip_flush();
                                ipstat_inc(ips_rcvmemdrop);
-                               goto bad;
+                               goto done;
                        }
 
                        ipqe = pool_get(&ipqent_pool, PR_NOWAIT);
                        if (ipqe == NULL) {
                                ipstat_inc(ips_rcvmemdrop);
-                               goto bad;
+                               goto done;
                        }
                        ip_frags++;
                        ipqe->ipqe_mff = mff;
                        ipqe->ipqe_m = m;
                        ipqe->ipqe_ip = ip;
-                       m = ip_reass(ipqe, fp);
-                       if (m == NULL) {
-                               return;
-                       }
+                       m = *mp = ip_reass(ipqe, fp);
+                       if (m == NULL)
+                               goto done;
                        ipstat_inc(ips_reassembled);
                        ip = mtod(m, struct ip *);
                        hlen = ip->ip_hl << 2;
@@ -564,13 +575,15 @@ found:
                                ip_freef(fp);
        }
 
-       ip_deliver(&m, &hlen, ip->ip_p, AF_INET);
-       return;
-bad:
-       m_freem(m);
+       *offp = hlen;
+       return ip_deliver(mp, offp, ip->ip_p, AF_INET);
+ done:
+       m_freem(*mp);
+       *mp = NULL;
+       return IPPROTO_DONE;
 }
 
-void
+int
 ip_deliver(struct mbuf **mp, int *offp, int nxt, int af)
 {
        KERNEL_ASSERT_LOCKED();
@@ -582,7 +595,7 @@ ip_deliver(struct mbuf **mp, int *offp, 
        if (ipsec_in_use) {
                if (ipsec_local_check(*mp, *offp, nxt, af) != 0) {
                        ipstat_inc(ips_cantforward);
-                       goto bad;
+                       goto done;
                }
        }
        /* Otherwise, just fall through and deliver the packet */
@@ -594,11 +607,13 @@ ip_deliver(struct mbuf **mp, int *offp, 
        ipstat_inc(ips_delivered);
        nxt = (*inetsw[ip_protox[nxt]].pr_input)(mp, offp, nxt, af);
        KASSERT(nxt == IPPROTO_DONE);
-       return;
+       return nxt;
 #ifdef IPSEC
- bad:
+ done:
 #endif
        m_freem(*mp);
+       *mp = NULL;
+       return IPPROTO_DONE;
 }
 
 int
@@ -867,7 +882,7 @@ dropfrag:
        m_freem(m);
        pool_put(&ipqent_pool, ipqe);
        ip_frags--;
-       return (0);
+       return (NULL);
 }
 
 /*
Index: netinet/ip_var.h
===================================================================
RCS file: /data/mirror/openbsd/cvs/src/sys/netinet/ip_var.h,v
retrieving revision 1.76
diff -u -p -r1.76 ip_var.h
--- netinet/ip_var.h    28 May 2017 09:25:51 -0000      1.76
+++ netinet/ip_var.h    29 May 2017 21:17:35 -0000
@@ -248,8 +248,8 @@ int  ip_sysctl(int *, u_int, void *, siz
 void    ip_savecontrol(struct inpcb *, struct mbuf **, struct ip *,
            struct mbuf *);
 void    ipintr(void);
-void    ipv4_input(struct mbuf *);
-void    ip_deliver(struct mbuf **, int *, int, int);
+int     ip_input(struct mbuf **, int *, int, int);
+int     ip_deliver(struct mbuf **, int *, int, int);
 void    ip_forward(struct mbuf *, struct ifnet *, struct rtentry *, int);
 int     rip_ctloutput(int, struct socket *, int, int, struct mbuf *);
 void    rip_init(void);

Reply via email to