Hi,

This removes code duplication by merging the IPv4 and IPv6 input
functions of AH, ESP and IPCOMP.

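The common code is moved into ipsec_protoff(), which finds the
offset of the next protocol field in the preceding header, and into
ipsec_input_disabled(), which passes the packet to raw IP input when
the protocol is disabled or the packet was diverted by pf.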

ok?

bluhm

Index: netinet/in_proto.c
===================================================================
RCS file: /cvs/src/sys/netinet/in_proto.c,v
retrieving revision 1.95
diff -u -p -r1.95 in_proto.c
--- netinet/in_proto.c  25 May 2021 22:45:09 -0000      1.95
+++ netinet/in_proto.c  24 Oct 2021 17:41:28 -0000
@@ -301,7 +301,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_AH,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ah4_input,
+  .pr_input    = ah46_input,
   .pr_ctlinput = ah4_ctlinput,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
@@ -314,7 +314,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_ESP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = esp4_input,
+  .pr_input    = esp46_input,
   .pr_ctlinput = esp4_ctlinput,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
@@ -327,7 +327,7 @@ const struct protosw inetsw[] = {
   .pr_domain   = &inetdomain,
   .pr_protocol = IPPROTO_IPCOMP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ipcomp4_input,
+  .pr_input    = ipcomp46_input,
   .pr_ctloutput        = rip_ctloutput,
   .pr_usrreq   = rip_usrreq,
   .pr_attach   = rip_attach,
Index: netinet/ip_ipsp.h
===================================================================
RCS file: /cvs/src/sys/netinet/ip_ipsp.h,v
retrieving revision 1.214
diff -u -p -r1.214 ip_ipsp.h
--- netinet/ip_ipsp.h   24 Oct 2021 17:08:27 -0000      1.214
+++ netinet/ip_ipsp.h   24 Oct 2021 17:41:29 -0000
@@ -577,14 +577,10 @@ int       ah_output_cb(struct tdb *, struct td
            int);
 int    ah_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 
-int    ah4_input(struct mbuf **, int *, int, int);
+int    ah46_input(struct mbuf **, int *, int, int);
 void   ah4_ctlinput(int, struct sockaddr *, u_int, void *);
 void   udpencap_ctlinput(int, struct sockaddr *, u_int, void *);
 
-#ifdef INET6
-int    ah6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
-
 /* XF_ESP */
 int    esp_attach(void);
 int    esp_init(struct tdb *, const struct xformsw *, struct ipsecinit *);
@@ -595,13 +591,9 @@ int        esp_input_cb(struct tdb *, uint8_t *
 int    esp_output(struct mbuf *, struct tdb *, int, int);
 int    esp_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 
-int    esp4_input(struct mbuf **, int *, int, int);
+int    esp46_input(struct mbuf **, int *, int, int);
 void   esp4_ctlinput(int, struct sockaddr *, u_int, void *);
 
-#ifdef INET6
-int    esp6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
-
 /* XF_IPCOMP */
 int    ipcomp_attach(void);
 int    ipcomp_init(struct tdb *, const struct xformsw *, struct ipsecinit *);
@@ -612,10 +604,7 @@ int        ipcomp_output(struct mbuf *, struct 
 int    ipcomp_output_cb(struct tdb *, struct tdb_crypto *, struct mbuf *, int,
            int);
 int    ipcomp_sysctl(int *, u_int, void *, size_t *, void *, size_t);
-int    ipcomp4_input(struct mbuf **, int *, int, int);
-#ifdef INET6
-int    ipcomp6_input(struct mbuf **, int *, int, int);
-#endif /* INET6 */
+int    ipcomp46_input(struct mbuf **, int *, int, int);
 
 /* XF_TCPSIGNATURE */
 int    tcp_signature_tdb_attach(void);
@@ -648,6 +637,8 @@ void        ipsec_init(void);
 int    ipsec_sysctl(int *, u_int, void *, size_t *, void *, size_t);
 int    ipsec_common_input(struct mbuf **, int, int, int, int, int);
 int    ipsec_common_input_cb(struct mbuf **, struct tdb *, int, int);
+int    ipsec_input_disabled(struct mbuf **, int *, int, int);
+int    ipsec_protoff(struct mbuf *, int, int);
 int    ipsec_delete_policy(struct ipsec_policy *);
 ssize_t        ipsec_hdrsz(struct tdb *);
 void   ipsec_adjust_mtu(struct mbuf *, u_int32_t);
Index: netinet/ipsec_input.c
===================================================================
RCS file: /cvs/src/sys/netinet/ipsec_input.c,v
retrieving revision 1.188
diff -u -p -r1.188 ipsec_input.c
--- netinet/ipsec_input.c       24 Oct 2021 17:08:27 -0000      1.188
+++ netinet/ipsec_input.c       24 Oct 2021 17:41:29 -0000
@@ -793,19 +793,47 @@ ipsec_sysctl_ipsecstat(void *oldp, size_
            sizeof(ipsecstat)));
 }
 
-/* IPv4 AH wrapper. */
+/*
+ * Pass the packet to raw IPv4 or IPv6 input.  Used when the IPsec
+ * protocol is disabled or the packet has been diverted by pf.
+ */
 int
-ah4_input(struct mbuf **mp, int *offp, int proto, int af)
+ipsec_input_disabled(struct mbuf **mp, int *offp, int proto, int af)
 {
+       switch (af) {
+       case AF_INET:
+               return rip_input(mp, offp, proto, af);
+#ifdef INET6
+       case AF_INET6:
+               return rip6_input(mp, offp, proto, af);
+#endif
+       default:
+               unhandled_af(af);
+       }
+}
+
+/* IPv4 and IPv6 AH wrapper. */
+int
+ah46_input(struct mbuf **mp, int *offp, int proto, int af)
+{
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !ah_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               ahstat_inc(ahs_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
+
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }

@@ -819,35 +847,53 @@ ah4_ctlinput(int cmd, struct sockaddr *s
        ipsec_common_ctlinput(rdomain, cmd, sa, v, IPPROTO_AH);
 }
 
-/* IPv4 ESP wrapper. */
+/* IPv4 and IPv6 ESP wrapper. */
 int
-esp4_input(struct mbuf **mp, int *offp, int proto, int af)
+esp46_input(struct mbuf **mp, int *offp, int proto, int af)
 {
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !esp_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
+
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               espstat_inc(esps_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }
 
-/* IPv4 IPCOMP wrapper */
+/* IPv4 and IPv6 IPCOMP wrapper. */
 int
-ipcomp4_input(struct mbuf **mp, int *offp, int proto, int af)
+ipcomp46_input(struct mbuf **mp, int *offp, int proto, int af)
 {
+       int protoff;
+
        if (
 #if NPF > 0
            ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
 #endif
            !ipcomp_enable)
-               return rip_input(mp, offp, proto, af);
+               return ipsec_input_disabled(mp, offp, proto, af);
 
-       ipsec_common_input(mp, *offp, offsetof(struct ip, ip_p), AF_INET,
-           proto, 0);
+       protoff = ipsec_protoff(*mp, *offp, af);
+       if (protoff < 0) {
+               DPRINTF("bad packet header chain");
+               ipcompstat_inc(ipcomps_hdrops);
+               m_freemp(mp);
+               return IPPROTO_DONE;
+       }
+
+       ipsec_common_input(mp, *offp, protoff, af, proto, 0);
        return IPPROTO_DONE;
 }

@@ -969,179 +1015,67 @@ esp4_ctlinput(int cmd, struct sockaddr *
        ipsec_common_ctlinput(rdomain, cmd, sa, v, IPPROTO_ESP);
 }
 
-#ifdef INET6
-/* IPv6 AH wrapper. */
+/*
+ * Return the offset of the next protocol field in the header that
+ * precedes the IPsec header at offset off, or -1 if the IPv6 header
+ * chain is malformed.
+ */
 int
-ah6_input(struct mbuf **mp, int *offp, int proto, int af)
+ipsec_protoff(struct mbuf *m, int off, int af)
 {
-       int l = 0;
-       int protoff, nxt;
+#ifdef INET6
        struct ip6_ext ip6e;
+       int protoff, nxt, l;
+#endif /* INET6 */
 
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !ah_enable)
-               return rip6_input(mp, offp, proto, af);
-
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               ahstat_inc(ahs_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
-#ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("ah6_input: l went zero or negative");
-#endif
-
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
-
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       ahstat_inc(ahs_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
-       }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
-}
-
-/* IPv6 ESP wrapper. */
-int
-esp6_input(struct mbuf **mp, int *offp, int proto, int af)
-{
-       int l = 0;
-       int protoff, nxt;
-       struct ip6_ext ip6e;
-
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !esp_enable)
-               return rip6_input(mp, offp, proto, af);
-
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               espstat_inc(esps_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
-#ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("esp6_input: l went zero or negative");
+       switch (af) {
+       case AF_INET:
+               return offsetof(struct ip, ip_p);
+#ifdef INET6
+       case AF_INET6:
+               break;
 #endif
-
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
-
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       espstat_inc(esps_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
+       default:
+               unhandled_af(af);
        }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
 
-}
+#ifdef INET6
+       if (off < sizeof(struct ip6_hdr))
+               return -1;
 
-/* IPv6 IPcomp wrapper */
-int
-ipcomp6_input(struct mbuf **mp, int *offp, int proto, int af)
-{
-       int l = 0;
-       int protoff, nxt;
-       struct ip6_ext ip6e;
+       if (off == sizeof(struct ip6_hdr))
+               return offsetof(struct ip6_hdr, ip6_nxt);
 
-       if (
-#if NPF > 0
-           ((*mp)->m_pkthdr.pf.flags & PF_TAG_DIVERTED) ||
-#endif
-           !ipcomp_enable)
-               return rip6_input(mp, offp, proto, af);
-
-       if (*offp < sizeof(struct ip6_hdr)) {
-               DPRINTF("bad offset");
-               ipcompstat_inc(ipcomps_hdrops);
-               m_freemp(mp);
-               return IPPROTO_DONE;
-       } else if (*offp == sizeof(struct ip6_hdr)) {
-               protoff = offsetof(struct ip6_hdr, ip6_nxt);
-       } else {
-               /* Chase down the header chain... */
-               protoff = sizeof(struct ip6_hdr);
-               nxt = (mtod(*mp, struct ip6_hdr *))->ip6_nxt;
-
-               do {
-                       protoff += l;
-                       m_copydata(*mp, protoff, sizeof(ip6e),
-                           (caddr_t) &ip6e);
-                       if (nxt == IPPROTO_AH)
-                               l = (ip6e.ip6e_len + 2) << 2;
-                       else
-                               l = (ip6e.ip6e_len + 1) << 3;
+       /* Chase down the header chain... */
+       protoff = sizeof(struct ip6_hdr);
+       nxt = (mtod(m, struct ip6_hdr *))->ip6_nxt;
+       l = 0;
+
+       do {
+               protoff += l;
+               m_copydata(m, protoff, sizeof(ip6e),
+                   (caddr_t) &ip6e);
+
+               if (nxt == IPPROTO_AH)
+                       l = (ip6e.ip6e_len + 2) << 2;
+               else
+                       l = (ip6e.ip6e_len + 1) << 3;
 #ifdef DIAGNOSTIC
-                       if (l <= 0)
-                               panic("l went zero or negative");
+               if (l <= 0)
+                       panic("%s: l went zero or negative", __func__);
 #endif
 
-                       nxt = ip6e.ip6e_nxt;
-               } while (protoff + l < *offp);
+               nxt = ip6e.ip6e_nxt;
+       } while (protoff + l < off);
 
-               /* Malformed packet check */
-               if (protoff + l != *offp) {
-                       DPRINTF("bad packet header chain");
-                       ipcompstat_inc(ipcomps_hdrops);
-                       m_freemp(mp);
-                       return IPPROTO_DONE;
-               }
+       /* Malformed packet check */
+       if (protoff + l != off)
+               return -1;
 
-               protoff += offsetof(struct ip6_ext, ip6e_nxt);
-       }
-       ipsec_common_input(mp, *offp, protoff, AF_INET6, proto, 0);
-       return IPPROTO_DONE;
+       protoff += offsetof(struct ip6_ext, ip6e_nxt);
+       return protoff;
+#endif /* INET6 */
 }
-#endif /* INET6 */
 
 int
 ipsec_forward_check(struct mbuf *m, int hlen, int af)
Index: netinet6/in6_proto.c
===================================================================
RCS file: /cvs/src/sys/netinet6/in6_proto.c,v
retrieving revision 1.105
diff -u -p -r1.105 in6_proto.c
--- netinet6/in6_proto.c        25 May 2021 22:45:10 -0000      1.105
+++ netinet6/in6_proto.c        24 Oct 2021 17:41:30 -0000
@@ -214,7 +214,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_AH,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ah6_input,
+  .pr_input    = ah46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,
@@ -226,7 +226,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_ESP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = esp6_input,
+  .pr_input    = esp46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,
@@ -238,7 +238,7 @@ const struct protosw inet6sw[] = {
   .pr_domain   = &inet6domain,
   .pr_protocol = IPPROTO_IPCOMP,
   .pr_flags    = PR_ATOMIC|PR_ADDR,
-  .pr_input    = ipcomp6_input,
+  .pr_input    = ipcomp46_input,
   .pr_ctloutput        = rip6_ctloutput,
   .pr_usrreq   = rip6_usrreq,
   .pr_attach   = rip6_attach,

Reply via email to