Each tdb (SA) bundle will receive a "flow" identifier that will
be reassigned to the newly established SAs upon rekeying.  Later
this will be passed as an IP_IPSECFLOWINFO control message to userland.
Discussed with Markus and Bret Lambert.  OK?


---
 sys/netinet/ip_ipsp.c | 101 +++++++++++++++++++++++++++++++++++++++++++++++++-
 sys/netinet/ip_ipsp.h |   3 ++
 2 files changed, 103 insertions(+), 1 deletion(-)

diff --git sys/netinet/ip_ipsp.c sys/netinet/ip_ipsp.c
index 184c28f..55c3871 100644
--- sys/netinet/ip_ipsp.c
+++ sys/netinet/ip_ipsp.c
@@ -118,24 +118,26 @@ struct xformsw xformsw[] = {
          tcp_signature_tdb_zeroize,    tcp_signature_tdb_input,
          tcp_signature_tdb_output, }
 #endif /* TCP_SIGNATURE */
 };
 
 struct xformsw *xformswNXFORMSW = &xformsw[nitems(xformsw)];
 
 #define        TDB_HASHSIZE_INIT       32
 
 static struct tdb **tdbh = NULL;
+static struct tdb **tdbflow = NULL;     /* dst/flow/sproto table */
 static struct tdb **tdbdst = NULL;
 static struct tdb **tdbsrc = NULL;
 static u_int tdb_hashmask = TDB_HASHSIZE_INIT - 1;
 static int tdb_count;
+static u_int tdb_flow;                  /* flow ID counter, see puttdb() */
 
 /*
  * Our hashing function needs to stir things with a non-zero random multiplier
  * so we cannot be DoS-attacked via choosing of the data to hash.
  */
 int
 tdb_hash(u_int rdomain, u_int32_t spi, union sockaddr_union *dst,
     u_int8_t proto)
 {
        static u_int32_t mult1 = 0, mult2 = 0;
@@ -286,20 +288,48 @@ gettdb(u_int rdomain, u_int32_t spi, union sockaddr_union *dst, u_int8_t proto)
 
        for (tdbp = tdbh[hashval]; tdbp != NULL; tdbp = tdbp->tdb_hnext)
                if ((tdbp->tdb_spi == spi) && (tdbp->tdb_sproto == proto) &&
                    (tdbp->tdb_rdomain == rdomain) &&
                    !memcmp(&tdbp->tdb_dst, dst, SA_LEN(&dst->sa)))
                        break;
 
        return tdbp;
 }
 
+/*
+ * Same as gettdb() but look the TDB up by its flow ID instead of its
+ * SPI, using the tdbflow[] hash table.  Returns NULL if that table has
+ * not been allocated yet or if no TDB matches.
+ */
+struct tdb *
+gettdbbyflow(u_int rdomain, u_int32_t flow, union sockaddr_union *dst,
+    u_int8_t proto)
+{
+       u_int32_t hashval;
+       struct tdb *tdbp;
+
+       if (tdbflow == NULL)
+               return (NULL);
+
+       /* The flow ID takes the place of the SPI in the hash key. */
+       hashval = tdb_hash(rdomain, flow, dst, proto);
+
+       /* The flow table is chained through tdb_fnext, not tdb_hnext. */
+       for (tdbp = tdbflow[hashval]; tdbp != NULL; tdbp = tdbp->tdb_fnext)
+               if ((tdbp->tdb_flow == flow) && (tdbp->tdb_sproto == proto) &&
+                   (tdbp->tdb_rdomain == rdomain) &&
+                   !memcmp(&tdbp->tdb_dst, dst, SA_LEN(&dst->sa)))
+                       break;
+
+       return (tdbp);
+}
+
 /*
  * Same as gettdb() but compare SRC as well, so we
  * use the tdbsrc[] hash table.  Setting spi to 0
  * matches all SPIs.
  */
 struct tdb *
 gettdbbysrcdst(u_int rdomain, u_int32_t spi, union sockaddr_union *src,
     union sockaddr_union *dst, u_int8_t proto)
 {
        u_int32_t hashval;
@@ -577,43 +600,56 @@ tdb_soft_firstuse(void *v)
                pfkeyv2_expire(tdb, SADB_EXT_LIFETIME_SOFT);
        tdb->tdb_flags &= ~TDBF_SOFT_FIRSTUSE;
 }
 
 /*
  * Caller is responsible for splsoftnet().
  */
 void
 tdb_rehash(void)
 {
-       struct tdb **new_tdbh, **new_tdbdst, **new_srcaddr, *tdbp, *tdbnp;
+       struct tdb **new_tdbh, **new_tdbflow, **new_tdbdst, **new_srcaddr;
+       struct tdb *tdbp, *tdbnp;
        u_int i, old_hashmask = tdb_hashmask;
        u_int32_t hashval;
 
        tdb_hashmask = (tdb_hashmask << 1) | 1;
 
        new_tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
            M_WAITOK | M_ZERO);
+       new_tdbflow = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
+           M_WAITOK | M_ZERO);
        new_tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
            M_WAITOK | M_ZERO);
        new_srcaddr = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *), M_TDB,
            M_WAITOK | M_ZERO);
 
        for (i = 0; i <= old_hashmask; i++) {
                for (tdbp = tdbh[i]; tdbp != NULL; tdbp = tdbnp) {
                        tdbnp = tdbp->tdb_hnext;
                        hashval = tdb_hash(tdbp->tdb_rdomain,
                            tdbp->tdb_spi, &tdbp->tdb_dst,
                            tdbp->tdb_sproto);
                        tdbp->tdb_hnext = new_tdbh[hashval];
                        new_tdbh[hashval] = tdbp;
                }
 
+               /* Relink the dst/flow/sproto table via tdb_fnext. */
+               for (tdbp = tdbflow[i]; tdbp != NULL; tdbp = tdbnp) {
+                       tdbnp = tdbp->tdb_fnext;
+                       hashval = tdb_hash(tdbp->tdb_rdomain,
+                           tdbp->tdb_flow, &tdbp->tdb_dst,
+                           tdbp->tdb_sproto);
+                       tdbp->tdb_fnext = new_tdbflow[hashval];
+                       new_tdbflow[hashval] = tdbp;
+               }
+
                for (tdbp = tdbdst[i]; tdbp != NULL; tdbp = tdbnp) {
                        tdbnp = tdbp->tdb_dnext;
                        hashval = tdb_hash(tdbp->tdb_rdomain,
                            0, &tdbp->tdb_dst,
                            tdbp->tdb_sproto);
                        tdbp->tdb_dnext = new_tdbdst[hashval];
                        new_tdbdst[hashval] = tdbp;
                }
 
                for (tdbp = tdbsrc[i]; tdbp != NULL; tdbp = tdbnp) {
@@ -622,39 +657,45 @@ tdb_rehash(void)
                            0, &tdbp->tdb_src,
                            tdbp->tdb_sproto);
                        tdbp->tdb_snext = new_srcaddr[hashval];
                        new_srcaddr[hashval] = tdbp;
                }
        }
 
        free(tdbh, M_TDB, 0);
        tdbh = new_tdbh;
 
+       free(tdbflow, M_TDB, 0);
+       tdbflow = new_tdbflow;
+
        free(tdbdst, M_TDB, 0);
        tdbdst = new_tdbdst;
 
        free(tdbsrc, M_TDB, 0);
        tdbsrc = new_srcaddr;
 }
 
 /*
  * Add TDB in the hash table.
  */
 void
 puttdb(struct tdb *tdbp)
 {
+       struct tdb *tdbpp;
        u_int32_t hashval;
        int s = splsoftnet();
 
        if (tdbh == NULL) {
                tdbh = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
                    M_TDB, M_WAITOK | M_ZERO);
+               tdbflow = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
+                   M_TDB, M_WAITOK | M_ZERO);
                tdbdst = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
                    M_TDB, M_WAITOK | M_ZERO);
                tdbsrc = mallocarray(tdb_hashmask + 1, sizeof(struct tdb *),
                    M_TDB, M_WAITOK | M_ZERO);
        }
 
        hashval = tdb_hash(tdbp->tdb_rdomain, tdbp->tdb_spi,
            &tdbp->tdb_dst, tdbp->tdb_sproto);
 
        /*
@@ -668,20 +709,73 @@ puttdb(struct tdb *tdbp)
        if (tdbh[hashval] != NULL && tdbh[hashval]->tdb_hnext != NULL &&
            tdb_count * 10 > tdb_hashmask + 1) {
                tdb_rehash();
                hashval = tdb_hash(tdbp->tdb_rdomain, tdbp->tdb_spi,
                    &tdbp->tdb_dst, tdbp->tdb_sproto);
        }
 
        tdbp->tdb_hnext = tdbh[hashval];
        tdbh[hashval] = tdbp;
 
+       /*
+        * Give the TDB a flow ID and link it onto the dst/flow/sproto
+        * table.  Flow ID 0 is never handed out.  TDBs still flagged
+        * TDBF_INVALID, or for which no free ID is left, are not linked.
+        */
+       if (ISSET(tdbp->tdb_flags, TDBF_INVALID))
+               goto noflow;    /* not ready yet */
+       /*
+        * Try to figure out whether we're rekeying an existing SA
+        * by getting an exact match based on our credentials.
+        */
+       tdbpp = gettdbbydst(tdbp->tdb_rdomain, &tdbp->tdb_dst,
+           tdbp->tdb_sproto, tdbp->tdb_srcid, tdbp->tdb_dstid,
+           tdbp->tdb_local_cred, NULL, NULL);
+       /*
+        * If that fails try to find an SA set up for the opposite
+        * direction; allocate a new flow ID otherwise.
+        */
+       if (tdbpp == NULL || tdbpp->tdb_flow == 0)
+               tdbpp = gettdbbysrc(tdbp->tdb_rdomain, &tdbp->tdb_dst,
+                   tdbp->tdb_sproto, tdbp->tdb_srcid, tdbp->tdb_dstid,
+                   NULL, NULL);
+       if (tdbpp != NULL && tdbpp->tdb_flow > 0) {
+               tdbp->tdb_flow = tdbpp->tdb_flow;
+       } else {
+               u_int32_t tries;
+
+               /*
+                * Advance the global counter, skipping 0, until we find
+                * an ID unused for this rdomain/dst/sproto.  Bound the
+                * walk by the number of non-zero IDs (2^32 - 1) rather
+                * than by the starting value so that it always ends and
+                * never misreports exhaustion when the counter wraps.
+                */
+               for (tries = 0; tries < 0xffffffffU; tries++) {
+                       if (++tdb_flow == 0)
+                               tdb_flow = 1;
+                       if (gettdbbyflow(tdbp->tdb_rdomain, tdb_flow,
+                           &tdbp->tdb_dst, tdbp->tdb_sproto) == NULL)
+                               break;
+               }
+               if (tries == 0xffffffffU) {
+                       printf("puttdb: too many flows\n");
+                       goto noflow;
+               }
+               tdbp->tdb_flow = tdb_flow;
+       }
+       hashval = tdb_hash(tdbp->tdb_rdomain, tdbp->tdb_flow,
+           &tdbp->tdb_dst, tdbp->tdb_sproto);
+       tdbp->tdb_fnext = tdbflow[hashval];
+       tdbflow[hashval] = tdbp;
+ noflow:
+
        hashval = tdb_hash(tdbp->tdb_rdomain, 0, &tdbp->tdb_dst,
            tdbp->tdb_sproto);
        tdbp->tdb_dnext = tdbdst[hashval];
        tdbdst[hashval] = tdbp;
 
        hashval = tdb_hash(tdbp->tdb_rdomain, 0, &tdbp->tdb_src,
            tdbp->tdb_sproto);
        tdbp->tdb_snext = tdbsrc[hashval];
        tdbsrc[hashval] = tdbp;
 
@@ -717,20 +799,37 @@ tdb_delete(struct tdb *tdbp)
                    tdbpp = tdbpp->tdb_hnext) {
                        if (tdbpp->tdb_hnext == tdbp) {
                                tdbpp->tdb_hnext = tdbp->tdb_hnext;
                                break;
                        }
                }
        }
 
        tdbp->tdb_hnext = NULL;
 
+       hashval = tdb_hash(tdbp->tdb_rdomain, tdbp->tdb_flow,
+           &tdbp->tdb_dst, tdbp->tdb_sproto);
+       /* TDBs never given a flow ID were not linked; nothing to undo. */
+       if (tdbflow[hashval] == tdbp) {
+               tdbflow[hashval] = tdbp->tdb_fnext;
+       } else {
+               for (tdbpp = tdbflow[hashval]; tdbpp != NULL;
+                   tdbpp = tdbpp->tdb_fnext) {
+                       if (tdbpp->tdb_fnext == tdbp) {
+                               tdbpp->tdb_fnext = tdbp->tdb_fnext;
+                               break;
+                       }
+               }
+       }
+
+       tdbp->tdb_fnext = NULL;
+
        hashval = tdb_hash(tdbp->tdb_rdomain, 0, &tdbp->tdb_dst,
            tdbp->tdb_sproto);
 
        if (tdbdst[hashval] == tdbp) {
                tdbdst[hashval] = tdbp->tdb_dnext;
        } else {
                for (tdbpp = tdbdst[hashval]; tdbpp != NULL;
                    tdbpp = tdbpp->tdb_dnext) {
                        if (tdbpp->tdb_dnext == tdbp) {
                                tdbpp->tdb_dnext = tdbp->tdb_dnext;
diff --git sys/netinet/ip_ipsp.h sys/netinet/ip_ipsp.h
index 47a5670..e15ed67 100644
--- sys/netinet/ip_ipsp.h
+++ sys/netinet/ip_ipsp.h
@@ -264,20 +264,21 @@ struct ipsec_policy {
 struct tdb {                           /* tunnel descriptor block */
        /*
-        * Each TDB is on three hash tables: one keyed on dst/spi/sproto,
-        * one keyed on dst/sproto, and one keyed on src/sproto. The first
-        * is used for finding a specific TDB, the second for finding TDBs
-        * for outgoing policy matching, and the third for incoming
-        * policy matching. The following three fields maintain the hash
-        * queues in those three tables.
+        * Each TDB is on up to four hash tables: one keyed on dst/spi/sproto,
+        * one keyed on dst/flow/sproto (only if the TDB has a flow ID), one
+        * keyed on dst/sproto, and one keyed on src/sproto. The first is used
+        * for finding a specific TDB, the second for finding TDBs by flow ID,
+        * the third for outgoing policy matching, and the fourth for incoming
+        * policy matching. The following four fields maintain those queues.
         */
        struct tdb      *tdb_hnext;     /* dst/spi/sproto table */
+       struct tdb      *tdb_fnext;     /* dst/flow/sproto table */
        struct tdb      *tdb_dnext;     /* dst/sproto table */
        struct tdb      *tdb_snext;     /* src/sproto table */
        struct tdb      *tdb_inext;
        struct tdb      *tdb_onext;
 
        struct xformsw          *tdb_xform;             /* Transform to use */
        struct enc_xform        *tdb_encalgxform;       /* Enc algorithm */
        struct auth_hash        *tdb_authalgxform;      /* Auth algorithm */
        struct comp_algo        *tdb_compalgxform;      /* Compression algo */
 
@@ -325,20 +326,21 @@ struct tdb {                              /* tunnel 
descriptor block */
        u_int64_t       tdb_exp_first_use;      /* Expire if tdb_first_use +
                                                 * tdb_exp_first_use <= curtime
                                                 */
 
        u_int64_t       tdb_last_used;  /* When was this SA last used */
        u_int64_t       tdb_last_marked;/* Last SKIPCRYPTO status change */
 
        u_int64_t       tdb_cryptoid;   /* Crypto session ID */
 
        u_int32_t       tdb_spi;        /* SPI */
+       u_int32_t       tdb_flow;       /* IP_IPSECFLOWINFO ID; 0 = none */
        u_int16_t       tdb_amxkeylen;  /* Raw authentication key length */
        u_int16_t       tdb_emxkeylen;  /* Raw encryption key length */
        u_int16_t       tdb_ivlen;      /* IV length */
        u_int8_t        tdb_sproto;     /* IPsec protocol */
        u_int8_t        tdb_wnd;        /* Replay window */
        u_int8_t        tdb_satype;     /* SA type (RFC2367, PF_KEY) */
        u_int8_t        tdb_updates;    /* pfsync update counter */
 
        union sockaddr_union    tdb_dst;        /* Destination address */
        union sockaddr_union    tdb_src;        /* Source address */
@@ -497,20 +499,21 @@ do {                                                      
                \
 uint8_t        get_sa_require(struct inpcb *);
 #ifdef ENCDEBUG
 const char *ipsp_address(union sockaddr_union, char *, socklen_t);
 #endif /* ENCDEBUG */
 
 /* TDB management routines */
 void   tdb_add_inp(struct tdb *, struct inpcb *, int);
 uint32_t reserve_spi(u_int, u_int32_t, u_int32_t, union sockaddr_union *,
                union sockaddr_union *, u_int8_t, int *);
 struct tdb *gettdb(u_int, u_int32_t, union sockaddr_union *, u_int8_t);
+struct tdb *gettdbbyflow(u_int, u_int32_t, union sockaddr_union *, u_int8_t);
 struct tdb *gettdbbydst(u_int, union sockaddr_union *, u_int8_t,
                struct ipsec_ref *, struct ipsec_ref *, struct ipsec_ref *,
                struct sockaddr_encap *, struct sockaddr_encap *);
 struct tdb *gettdbbysrc(u_int, union sockaddr_union *, u_int8_t,
                struct ipsec_ref *, struct ipsec_ref *,
                struct sockaddr_encap *, struct sockaddr_encap *);
 struct tdb *gettdbbysrcdst(u_int, u_int32_t, union sockaddr_union *,
                union sockaddr_union *, u_int8_t);
 void   puttdb(struct tdb *);
 void   tdb_delete(struct tdb *);
-- 
2.3.4

Reply via email to