On Wed, 1 Dec 2021 00:27:06 +0100
Alexander Bluhm <alexander.bl...@gmx.net> wrote:
> On Tue, Nov 30, 2021 at 05:53:34PM +0300, Vitaliy Makkoveev wrote:
>> Hi,
>> 
>> This question is mostly for bluhm@. Should gettdbbyflow() grab an
>> extra reference on the returned `tdbp' like the other gettdb*() functions
>> do? I'm pointing this out because we are not going to rely on the netlock
>> when dereferencing `tdbp'.
> 
> Yes.  Call tdb_ref(tdbp) within the tdb_sadb_mtx mutex.
> 
> The interesting question is when to unref it.  You use the same
> variable for the tdb parameter and the tdb from gettdbbyflow().
> Tracking when you don't use the new TDB anymore, gets tricky.

Let me update the diff.  gettdbbyflow() now grabs a reference.

The diff also fixes gettdbbyflow(): the comparison of ids was missing.


Index: sys/netinet/ip_ipsp.c
===================================================================
RCS file: /disk/cvs/openbsd/src/sys/netinet/ip_ipsp.c,v
retrieving revision 1.258
diff -u -p -r1.258 ip_ipsp.c
--- sys/netinet/ip_ipsp.c       29 Nov 2021 19:19:00 -0000      1.258
+++ sys/netinet/ip_ipsp.c       1 Dec 2021 12:19:53 -0000
@@ -90,6 +90,8 @@ void          tdb_firstuse(void *);
 void           tdb_soft_timeout(void *);
 void           tdb_soft_firstuse(void *);
 int            tdb_hash(u_int32_t, union sockaddr_union *, u_int8_t);
+int            sockaddr_encap_match(struct sockaddr_encap *,
+                   struct sockaddr_encap *, struct sockaddr_encap *);
 
 int ipsec_in_use = 0;
 u_int64_t ipsec_last_added = 0;
@@ -507,6 +509,78 @@ gettdbbysrc(u_int rdomain, union sockadd
        tdb_ref(tdbp);
        mtx_leave(&tdb_sadb_mtx);
        return tdbp;
+}
+
+/*
+ * Get an SA given the flow, the direction, the security protocol type, and
+ * the desired IDs (mandatory).  NULL if ids is NULL or nothing matches.
+ */
+struct tdb *
+gettdbbyflow(u_int rdomain, int direction, struct sockaddr_encap *senflow,
+    u_int8_t sproto, struct ipsec_ids *ids)
+{
+       u_int32_t hashval;
+       struct tdb *tdbp;
+       union sockaddr_union srcdst;
+
+       if (ids == NULL)        /* ids is mandatory */
+               return NULL;
+
+       memset(&srcdst, 0, sizeof(srcdst));
+       switch (senflow->sen_type) { /* use flow Dst for OUT, else Src */
+       case SENT_IP4:
+               srcdst.sin.sin_len = sizeof(srcdst.sin);
+               srcdst.sin.sin_family = AF_INET;
+               if (direction == IPSP_DIRECTION_OUT)
+                       srcdst.sin.sin_addr = senflow->Sen.Sip4.Dst;
+               else
+                       srcdst.sin.sin_addr = senflow->Sen.Sip4.Src;
+               break;
+       case SENT_IP6:
+               srcdst.sin6.sin6_len = sizeof(srcdst.sin6);
+               srcdst.sin6.sin6_family = AF_INET6;
+               if (direction == IPSP_DIRECTION_OUT)
+                       srcdst.sin6.sin6_addr = senflow->Sen.Sip6.Dst;
+               else
+                       srcdst.sin6.sin6_addr = senflow->Sen.Sip6.Src;
+               break;
+       } /* NOTE(review): other sen_type values leave srcdst zeroed */
+
+       mtx_enter(&tdb_sadb_mtx); /* held across lookup and tdb_ref() */
+       hashval = tdb_hash(0, &srcdst, sproto); /* bucket in tdbdst[] */
+
+       for (tdbp = tdbdst[hashval]; tdbp != NULL; tdbp = tdbp->tdb_dnext)
+               if (tdbp->tdb_sproto == sproto &&
+                   tdbp->tdb_rdomain == rdomain &&
+                   (tdbp->tdb_flags & TDBF_INVALID) == 0 &&
+                   ipsp_ids_match(ids, tdbp->tdb_ids) &&
+                   ((direction == IPSP_DIRECTION_OUT &&
+                   !memcmp(&tdbp->tdb_dst, &srcdst, srcdst.sa.sa_len)) ||
+                   (direction == IPSP_DIRECTION_IN &&
+                   !memcmp(&tdbp->tdb_src, &srcdst, srcdst.sa.sa_len)))) {
+                       if (sockaddr_encap_match(&tdbp->tdb_filter,
+                           &tdbp->tdb_filtermask, senflow))
+                               break; /* flow fits SA filter */
+               }
+
+       tdb_ref(tdbp); /* may be NULL (no match); confirm tdb_ref() allows it */
+       mtx_leave(&tdb_sadb_mtx);
+       return tdbp;
+}
+
+int
+sockaddr_encap_match(struct sockaddr_encap *addr, struct sockaddr_encap *mask,
+    struct sockaddr_encap *dest)
+{
+       size_t  off; /* byte offset; scan starts at sen_type */
+
+       for (off = offsetof(struct sockaddr_encap, sen_type);
+           off < dest->sen_len; off++) { /* bounded by dest's length */
+               if ((*((u_char *)addr + off) & *((u_char *)mask + off)) !=
+                   (*((u_char *)dest + off) & *((u_char *)mask + off)))
+                       break; /* masked bytes differ */
+       }
+       return (off < dest->sen_len)? 0 : 1; /* 1 = no mismatch found */
 }
 
 #ifdef DDB
Index: sys/netinet/ip_ipsp.h
===================================================================
RCS file: /disk/cvs/openbsd/src/sys/netinet/ip_ipsp.h,v
retrieving revision 1.224
diff -u -p -r1.224 ip_ipsp.h
--- sys/netinet/ip_ipsp.h       30 Nov 2021 13:17:43 -0000      1.224
+++ sys/netinet/ip_ipsp.h       1 Dec 2021 12:19:53 -0000
@@ -565,6 +565,8 @@ struct      tdb *gettdbbysrcdst_dir(u_int, u_
                union sockaddr_union *, u_int8_t, int);
 #define gettdbbysrcdst(a,b,c,d,e) gettdbbysrcdst_dir((a),(b),(c),(d),(e),0)
 #define gettdbbysrcdst_rev(a,b,c,d,e) gettdbbysrcdst_dir((a),(b),(c),(d),(e),1)
+struct tdb *gettdbbyflow(u_int, int, struct sockaddr_encap *, u_int8_t,
+               struct ipsec_ids *);
 void   puttdb(struct tdb *);
 void   puttdb_locked(struct tdb *);
 void   tdb_delete(struct tdb *);
Index: sys/netinet/ip_spd.c
===================================================================
RCS file: /disk/cvs/openbsd/src/sys/netinet/ip_spd.c,v
retrieving revision 1.106
diff -u -p -r1.106 ip_spd.c
--- sys/netinet/ip_spd.c        30 Nov 2021 13:17:43 -0000      1.106
+++ sys/netinet/ip_spd.c        1 Dec 2021 12:19:53 -0000
@@ -149,13 +149,14 @@ ipsp_spd_lookup(struct mbuf *m, int af, 
     struct tdb *tdbp, struct inpcb *inp, u_int32_t ipsecflowinfo)
 {
        struct radix_node_head *rnh;
-       struct radix_node *rn;
+       struct radix_node *rn = NULL;
        union sockaddr_union sdst, ssrc;
        struct sockaddr_encap *ddst, dst;
        struct ipsec_policy *ipo;
        struct ipsec_ids *ids = NULL;
        int signore = 0, dignore = 0;
        u_int rdomain = rtable_l2(m->m_pkthdr.ph_rtableid);
+       struct tdb *tdb, *tdblocal = NULL;
 
        NET_ASSERT_LOCKED();
 
@@ -179,6 +180,8 @@ ipsp_spd_lookup(struct mbuf *m, int af, 
                return NULL;
        }
 
+       if (ipsecflowinfo != 0)
+               ids = ipsp_ids_lookup(ipsecflowinfo);
        memset(&dst, 0, sizeof(dst));
        memset(&sdst, 0, sizeof(union sockaddr_union));
        memset(&ssrc, 0, sizeof(union sockaddr_union));
@@ -301,9 +304,32 @@ ipsp_spd_lookup(struct mbuf *m, int af, 
                return NULL;
        }
 
+       /*
+        * Prepare tdb for searching the correct SPD by rn_lookup().
+        * "tdb_filtermask" of the tdb is used to find the correct SPD when
+        * multiple policies overlap.
+        */
+       tdb = tdbp;
+       if (ipsecflowinfo != 0 && ids != NULL) {
+               KASSERT(tdbp == NULL);
+               KASSERT(direction == IPSP_DIRECTION_OUT);
+               if ((tdblocal = gettdbbyflow(rdomain, direction, &dst,
+                   IPPROTO_ESP, ids)) != NULL)
+                       tdb = tdblocal;
+       }
+
        /* Actual SPD lookup. */
-       if ((rnh = spd_table_get(rdomain)) == NULL ||
-           (rn = rn_match((caddr_t)&dst, rnh)) == NULL) {
+       rnh = spd_table_get(rdomain);
+       if (rnh != NULL) {
+               if (tdb != NULL)
+                       rn = rn_lookup((caddr_t)&tdb->tdb_filter,
+                           (caddr_t)&tdb->tdb_filtermask, rnh);
+               else
+                       rn = rn_match((caddr_t)&dst, rnh);
+       }
+       tdb_unref(tdblocal);
+
+       if (rn == NULL) {
                /*
                 * Return whatever the socket requirements are, there are no
                 * system-wide policies.
@@ -396,9 +422,6 @@ ipsp_spd_lookup(struct mbuf *m, int af, 
                        }
                }
 
-               if (ipsecflowinfo)
-                       ids = ipsp_ids_lookup(ipsecflowinfo);
-
                /* Check that the cached TDB (if present), is appropriate. */
                if (ipo->ipo_tdb != NULL) {
                        if ((ipo->ipo_last_searched <= ipsec_last_added) ||
@@ -513,10 +536,19 @@ ipsp_spd_lookup(struct mbuf *m, int af, 
                                goto nomatchin;
 
                        /* Match source/dest IDs. */
-                       if (ipo->ipo_ids)
-                               if (tdbp->tdb_ids == NULL ||
-                                   !ipsp_ids_match(ipo->ipo_ids, tdbp->tdb_ids))
+                       if (ipo->ipo_ids != NULL) {
+                               if ((tdbp->tdb_flags & TDBF_TUNNELING) == 0 &&
+                                   (tdbp->tdb_flags & TDBF_UDPENCAP) != 0) {
+                                       /*
+                                        * Skip IDs check for transport mode
+                                        * with NAT-T.  Multiple clients (IDs)
+                                        * can use the same policy.
+                                        */
+                               } else if (tdbp->tdb_ids == NULL &&
+                                   !ipsp_ids_match(ipo->ipo_ids,
+                                   tdbp->tdb_ids))
                                        goto nomatchin;
+                       }
 
                        /* Add it to the cache. */
                        if (ipo->ipo_tdb != NULL) {

Reply via email to