Brief background: divert(4) sockets can be used to send packets to a
userspace program.  The program can inspect the packets and decide to
either reinject them back into the kernel or drop them.

According to the divert(4) man page, "The packets' checksums are
recalculated upon reinjection."  This makes sense, because the userspace
program could have modified the packet.

However, in my tests, I found that the checksums are not actually
recalculated upon reinjection.  I ran into this bug when trying to use
divert-packet PF rules together with nat-to and rdr-to, e.g.:

    match out on em5 inet nat-to (em5:0)
    pass out on em5 proto tcp to port 80 divert-packet port 700
    pass in on em5 proto tcp to port 22 divert-packet port 700 rdr-to 192.168.30.8

With the above rules, inbound packets are dropped, and "netstat -p ip -s"
shows the "bad header checksums" counter incrementing by one for every
inbound packet.

The diff below fixes this bug by making divert(4) recalculate the IP(v4)
and TCP/UDP/ICMP/ICMPv6 checksums of reinjected packets on both IPv4 and
IPv6 (done with a ton of help and feedback from blambert@ who reviewed
many versions of this diff, thank you!).

If you are using divert(4), could you please test this diff to ensure
that it does not break your setup?  Note that this diff only applies to
divert-packet PF rules, not divert-to/divert-reply.

I am also looking for feedback from developers to review the approach
and code that was used to fix this bug.

Thank you,
Lawrence


Index: netinet/ip_divert.c
===================================================================
RCS file: /cvs/src/sys/netinet/ip_divert.c,v
retrieving revision 1.10
diff -u -p -r1.10 ip_divert.c
--- netinet/ip_divert.c 21 Oct 2012 13:06:03 -0000      1.10
+++ netinet/ip_divert.c 4 Mar 2013 16:14:33 -0000
@@ -37,6 +37,9 @@
 #include <netinet/ip_var.h>
 #include <netinet/in_pcb.h>
 #include <netinet/ip_divert.h>
+#include <netinet/tcp.h>
+#include <netinet/udp.h>
+#include <netinet/ip_icmp.h>
 
 struct inpcbtable      divbtable;
 struct divstat         divstat;
@@ -83,8 +86,12 @@ divert_output(struct mbuf *m, ...)
        struct sockaddr_in *sin;
        struct socket *so;
        struct ifaddr *ifa;
-       int s, error = 0;
+       int s, error = 0, p_hdrlen = 0;
        va_list ap;
+       struct ip *ip;
+       u_int16_t off, len, csum = 0;
+       u_int8_t nxt;
+       size_t p_off = 0;
 
        va_start(ap, m);
        inp = va_arg(ap, struct inpcb *);
@@ -102,15 +109,79 @@ divert_output(struct mbuf *m, ...)
        sin = mtod(nam, struct sockaddr_in *);
        so = inp->inp_socket;
 
+       /* Do basic sanity checks. */
+       if (m->m_pkthdr.len < sizeof(struct ip))
+               goto fail;
+       if ((m = m_pullup(m, sizeof(struct ip))) == NULL) {
+               /* m_pullup() has freed the mbuf, so just return. */
+               divstat.divs_errors++;
+               return (ENOBUFS);
+       }
+       ip = mtod(m, struct ip *);
+       if (ip->ip_v != IPVERSION)
+               goto fail;
+       off = ip->ip_hl << 2;
+       len = ntohs(ip->ip_len);
+       if (off < sizeof(struct ip) || len < off || m->m_pkthdr.len < len)
+               goto fail;
+
+       /*
+        * Recalculate IP and protocol checksums since the userspace application
+        * may have modified the packet prior to reinjection.
+        */
+       ip->ip_sum = 0;
+       ip->ip_sum = in_cksum(m, off);
+       nxt = ip->ip_p;
+       switch (ip->ip_p) {
+       case IPPROTO_TCP:
+               p_hdrlen = sizeof(struct tcphdr);
+               p_off = offsetof(struct tcphdr, th_sum);
+               break;
+       case IPPROTO_UDP:
+               p_hdrlen = sizeof(struct udphdr);
+               p_off = offsetof(struct udphdr, uh_sum);
+               break;
+       case IPPROTO_ICMP:
+               /* not sizeof(struct icmp): it has optional trailing fields */
+               p_hdrlen = ICMP_MINLEN;
+               p_off = offsetof(struct icmp, icmp_cksum);
+               nxt = 0;
+               break;
+       default:
+               /* nothing */
+               break;
+       }
+
+       /*
+        * Protocol checksums cover the whole datagram, so they cannot be
+        * recalculated on a fragment; leave those of fragments untouched.
+        */
+       if (ip->ip_off & htons(IP_MF | IP_OFFMASK))
+               p_hdrlen = 0;
+
+       if (p_hdrlen) {
+               if (len < off + p_hdrlen)
+                       goto fail;
+
+               if ((error = m_copyback(m, off + p_off, sizeof(csum), &csum,
+                   M_NOWAIT)))
+                       goto fail;
+               csum = in4_cksum(m, nxt, off, len - off);
+               if (ip->ip_p == IPPROTO_UDP && csum == 0)
+                       csum = 0xffff;
+               if ((error = m_copyback(m, off + p_off, sizeof(csum), &csum,
+                   M_NOWAIT)))
+                       goto fail;
+       }
+
        m->m_pkthdr.pf.flags |= PF_TAG_DIVERTED_PACKET;
 
        if (sin->sin_addr.s_addr != INADDR_ANY) {
                ipaddr.sin_addr = sin->sin_addr;
                ifa = ifa_ifwithaddr(sintosa(&ipaddr), m->m_pkthdr.rdomain);
                if (ifa == NULL) {
-                       divstat.divs_errors++;
-                       m_freem(m);
-                       return (EADDRNOTAVAIL);
+                       error = EADDRNOTAVAIL;
+                       goto fail;
                }
                m->m_pkthdr.rcvif = ifa->ifa_ifp;
 
@@ -131,6 +202,11 @@ divert_output(struct mbuf *m, ...)
 
        divstat.divs_opackets++;
        return (error);
+
+fail:
+       m_freem(m);
+       divstat.divs_errors++;
+       return (error ? error : EINVAL);
 }
 
 int
Index: netinet6/ip6_divert.c
===================================================================
RCS file: /cvs/src/sys/netinet6/ip6_divert.c,v
retrieving revision 1.8
diff -u -p -r1.8 ip6_divert.c
--- netinet6/ip6_divert.c       6 Nov 2012 12:32:42 -0000       1.8
+++ netinet6/ip6_divert.c       4 Mar 2013 16:15:14 -0000
@@ -40,6 +40,9 @@
 #include <netinet6/in6_var.h>
 #include <netinet6/in6_var.h>
 #include <netinet6/ip6_divert.h>
+#include <netinet/tcp.h>
+#include <netinet/udp.h>
+#include <netinet/icmp6.h>
 
 struct inpcbtable      divb6table;
 struct div6stat        div6stat;
@@ -88,8 +91,16 @@ divert6_output(struct mbuf *m, ...)
        struct sockaddr_in6 *sin6;
        struct socket *so;
        struct ifaddr *ifa;
-       int s, error = 0;
+       int s, error = 0, p_hdrlen = 0, rthdr_cnt = 0;
        va_list ap;
+       struct ip6_hdr *ip6;
+       u_int8_t nxt;
+       u_int16_t csum = 0;
+       u_int32_t off, end;
+       size_t p_off = 0;
+       struct ip6_frag frag;
+       struct ip6_ext ext;
+       struct ip6_rthdr rthdr;
 
        va_start(ap, m);
        inp = va_arg(ap, struct inpcb *);
@@ -107,15 +118,105 @@ divert6_output(struct mbuf *m, ...)
        sin6 = mtod(nam, struct sockaddr_in6 *);
        so = inp->inp_socket;
 
+       /* Do basic sanity checks. */
+       off = sizeof(struct ip6_hdr);
+       if (m->m_pkthdr.len < off)
+               goto fail;
+       if ((m = m_pullup(m, sizeof(struct ip6_hdr))) == NULL) {
+               /* m_pullup() has freed the mbuf, so just return. */
+               div6stat.divs_errors++;
+               return (ENOBUFS);
+       }
+       ip6 = mtod(m, struct ip6_hdr *);
+       if ((ip6->ip6_vfc & IPV6_VERSION_MASK) != IPV6_VERSION)
+               goto fail;
+       end = off + ntohs(ip6->ip6_plen);
+       if (m->m_pkthdr.len < end)
+               goto fail;
+
+       /*
+        * Walk the IPv6 header chain to access and recalculate the
+        * protocol checksum, since the userspace application may have
+        * modified the packet prior to reinjection.  The header chain
+        * walking logic is borrowed from pf_walk_header6().
+        */
+       nxt = ip6->ip6_nxt;
+       for (;;) {
+               switch (nxt) {
+               case IPPROTO_FRAGMENT:
+                       if (end < off + sizeof(frag))
+                               goto fail;
+                       m_copydata(m, off, sizeof(frag), (caddr_t)&frag);
+                       /*
+                        * The protocol checksum covers the reassembled payload,
+                        * so it cannot be recalculated on a real fragment.
+                        */
+                       if (frag.ip6f_offlg & (IP6F_OFF_MASK | IP6F_MORE_FRAG))
+                               goto done;
+                       off += sizeof(frag);
+                       nxt = frag.ip6f_nxt;
+                       break;
+               case IPPROTO_ROUTING:
+                       if (rthdr_cnt++)
+                               goto fail;
+                       if (end < off + sizeof(rthdr))
+                               goto fail;
+                       m_copydata(m, off, sizeof(rthdr), (caddr_t)&rthdr);
+                       if (rthdr.ip6r_type == IPV6_RTHDR_TYPE_0)
+                               goto fail;
+                       /* FALLTHROUGH */
+               case IPPROTO_AH:
+               case IPPROTO_HOPOPTS:
+               case IPPROTO_DSTOPTS:
+                       if (end < off + sizeof(ext))
+                               goto fail;
+                       m_copydata(m, off, sizeof(ext), (caddr_t)&ext);
+                       if (nxt == IPPROTO_AH)
+                               off += (ext.ip6e_len + 2) * 4;
+                       else
+                               off += (ext.ip6e_len + 1) * 8;
+                       nxt = ext.ip6e_nxt;
+                       break;
+               case IPPROTO_TCP:
+                       p_hdrlen = sizeof(struct tcphdr);
+                       p_off = offsetof(struct tcphdr, th_sum);
+                       goto done;
+               case IPPROTO_UDP:
+                       p_hdrlen = sizeof(struct udphdr);
+                       p_off = offsetof(struct udphdr, uh_sum);
+                       goto done;
+               case IPPROTO_ICMPV6:
+                       p_hdrlen = sizeof(struct icmp6_hdr);
+                       p_off = offsetof(struct icmp6_hdr, icmp6_cksum);
+                       goto done;
+               default:
+                       goto done;
+               }
+       }
+done:
+       if (p_hdrlen) {
+               if (end < off + p_hdrlen)
+                       goto fail;
+
+               if ((error = m_copyback(m, off + p_off, sizeof(csum), &csum,
+                   M_NOWAIT)))
+                       goto fail;
+               csum = in6_cksum(m, nxt, off, end - off);
+               if (nxt == IPPROTO_UDP && csum == 0)
+                       csum = 0xffff;
+               if ((error = m_copyback(m, off + p_off, sizeof(csum), &csum,
+                   M_NOWAIT)))
+                       goto fail;
+       }
+
        m->m_pkthdr.pf.flags |= PF_TAG_DIVERTED_PACKET;
 
        if (!IN6_IS_ADDR_UNSPECIFIED(&sin6->sin6_addr)) {
                ip6addr.sin6_addr = sin6->sin6_addr;
                ifa = ifa_ifwithaddr(sin6tosa(&ip6addr), m->m_pkthdr.rdomain);
                if (ifa == NULL) {
-                       div6stat.divs_errors++;
-                       m_freem(m);
-                       return (EADDRNOTAVAIL);
+                       error = EADDRNOTAVAIL;
+                       goto fail;
                }
                m->m_pkthdr.rcvif = ifa->ifa_ifp;
 
@@ -133,6 +234,11 @@ divert6_output(struct mbuf *m, ...)
 
        div6stat.divs_opackets++;
        return (error);
+
+fail:
+       div6stat.divs_errors++;
+       m_freem(m);
+       return (error ? error : EINVAL);
 }
 
 int

Reply via email to