The branch main has been updated by ks:

URL: 
https://cgit.FreeBSD.org/src/commit/?id=b9c0321d54e96d0a6591e9c609c7581916d3ddd3

commit b9c0321d54e96d0a6591e9c609c7581916d3ddd3
Author:     Kajetan Staszkiewicz <[email protected]>
AuthorDate: 2024-11-23 21:21:22 +0000
Commit:     Kajetan Staszkiewicz <[email protected]>
CommitDate: 2024-11-28 17:31:55 +0000

    pf: Fix source node locking
    
    Source nodes are created quite early in pf_create_state(), even before
    the state is allocated, locked and inserted into its hash row. They can
    be freed by the source node kill or clear ioctls while pf_create_state()
    is still running.
    
    The function pf_map_addr_sn() is called from two very different paths.
    
    One is for filter rules, where it is called from pf_create_state()
    after pf_insert_src_node(). In this case it is given a source node,
    does not perform its own search, and must return that source node.
    
    The other one is for NAT rules, where it is called from
    pf_get_translation() or its descendants. In this case it is called with
    no known source node and performs its own search for source nodes. The
    source node found is then passed back to pf_create_state() unlocked.
    
    The states counter of a source node is incremented in
    pf_find_src_node(), which lets the counter increase when a packet
    matches a NAT rule but not a pass keep state rule.
    
    The function pf_map_addr_sn() operates on an unlocked source node.
    
    Modify pf_find_src_node() to return the source node locked when one is
    found, so that any subsequent operations can access it safely.
    
    Move the sn->states++ increment to pf_insert_src_node() to ensure that
    it happens only from pf_create_state() and not on the NAT ruleset path,
    and only if the source node has really been inserted or found. This
    simplifies the cleanup.
    
    Add locking in pf_src_connlimit() and pf_map_addr_sn(). Sprinkle mutex
    assertions in pf_map_addr_sn().
    
    Add a function pf_src_node_exists() to check that a known source node is
    still valid. Use it in pf_create_state(), where it's impossible to keep
    holding the lock taken in pf_insert_src_node() because that would cause
    a lock order reversal (LoR: node first, then state) against
    pf_src_connlimit() (state first, then node).
    
    Don't propagate the source node found while parsing the NAT ruleset to
    pf_create_state() because it must be found again and locked or created.
    
    Reviewed by:            kp
    Approved by:            kp (mentor)
    Sponsored by:           InnoGames GmbH
    Differential Revision:  https://reviews.freebsd.org/D47770
---
 sys/net/pfvar.h                   |   8 ++-
 sys/netpfil/pf/pf.c               | 148 ++++++++++++++++++++++----------------
 sys/netpfil/pf/pf_lb.c            |  66 ++++++++++-------
 tests/sys/netpfil/pf/src_track.sh |  15 ++--
 4 files changed, 143 insertions(+), 94 deletions(-)

diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
index e00101ba2b78..51f525c7383b 100644
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -2334,6 +2334,9 @@ extern int                         
pf_udp_mapping_insert(struct pf_udp_mapping
                                    *mapping);
 extern void                     pf_udp_mapping_release(struct pf_udp_mapping
                                    *mapping);
+uint32_t                        pf_hashsrc(struct pf_addr *, sa_family_t);
+extern bool                     pf_src_node_exists(struct pf_ksrc_node **,
+                                   struct pf_srchash *);
 extern struct pf_ksrc_node     *pf_find_src_node(struct pf_addr *,
                                    struct pf_krule *, sa_family_t,
                                    struct pf_srchash **, bool);
@@ -2622,10 +2625,9 @@ u_short                   pf_map_addr(u_int8_t, struct 
pf_krule *,
 u_short                         pf_map_addr_sn(u_int8_t, struct pf_krule *,
                            struct pf_addr *, struct pf_addr *,
                            struct pfi_kkif **nkif, struct pf_addr *,
-                           struct pf_ksrc_node **);
+                           struct pf_ksrc_node **, struct pf_srchash **);
 u_short                         pf_get_translation(struct pf_pdesc *,
-                           int, struct pf_ksrc_node **,
-                           struct pf_state_key **, struct pf_state_key **,
+                           int, struct pf_state_key **, struct pf_state_key **,
                            struct pf_addr *, struct pf_addr *,
                            uint16_t, uint16_t, struct pf_kanchor_stackframe *,
                            struct pf_krule **,
diff --git a/sys/netpfil/pf/pf.c b/sys/netpfil/pf/pf.c
index 9436a4247411..9f8fec51e420 100644
--- a/sys/netpfil/pf/pf.c
+++ b/sys/netpfil/pf/pf.c
@@ -332,8 +332,7 @@ static int           pf_test_rule(struct pf_krule **, 
struct pf_kstate **,
                            struct pf_kruleset **, struct inpcb *);
 static int              pf_create_state(struct pf_krule *, struct pf_krule *,
                            struct pf_krule *, struct pf_pdesc *,
-                           struct pf_ksrc_node *, struct pf_state_key *,
-                           struct pf_state_key *,
+                           struct pf_state_key *, struct pf_state_key *,
                            u_int16_t, u_int16_t, int *,
                            struct pf_kstate **, int, u_int16_t, u_int16_t,
                            struct pf_krule_slist *, struct pf_udp_mapping *);
@@ -372,14 +371,15 @@ static void                pf_patch_8(struct mbuf *, 
u_int16_t *, u_int8_t *, u_int8_t,
                            bool, u_int8_t);
 static struct pf_kstate        *pf_find_state(struct pfi_kkif *,
                            const struct pf_state_key_cmp *, u_int);
-static int              pf_src_connlimit(struct pf_kstate *);
+static bool             pf_src_connlimit(struct pf_kstate *);
 static int              pf_match_rcvif(struct mbuf *, struct pf_krule *);
 static void             pf_counters_inc(int, struct pf_pdesc *,
                            struct pf_kstate *, struct pf_krule *,
                            struct pf_krule *);
 static void             pf_overload_task(void *v, int pending);
 static u_short          pf_insert_src_node(struct pf_ksrc_node **,
-                           struct pf_krule *, struct pf_addr *, sa_family_t);
+                           struct pf_srchash **, struct pf_krule *,
+                           struct pf_addr *, sa_family_t);
 static u_int            pf_purge_expired_states(u_int, int);
 static void             pf_purge_unlinked_rules(void);
 static int              pf_mtag_uminit(void *, int, int);
@@ -701,7 +701,7 @@ pf_hashkey(const struct pf_state_key *sk)
        return (h & V_pf_hashmask);
 }
 
-static __inline uint32_t
+__inline uint32_t
 pf_hashsrc(struct pf_addr *addr, sa_family_t af)
 {
        uint32_t h;
@@ -812,17 +812,14 @@ pf_check_threshold(struct pf_threshold *threshold)
        return (threshold->count > threshold->limit);
 }
 
-static int
+static bool
 pf_src_connlimit(struct pf_kstate *state)
 {
        struct pf_overload_entry *pfoe;
-       int bad = 0;
+       bool limited = false;
 
        PF_STATE_LOCK_ASSERT(state);
-       /*
-        * XXXKS: The src node is accessed unlocked!
-        * PF_SRC_NODE_LOCK_ASSERT(state->src_node);
-        */
+       PF_SRC_NODE_LOCK(state->src_node);
 
        state->src_node->conn++;
        state->src.tcp_est = 1;
@@ -832,29 +829,29 @@ pf_src_connlimit(struct pf_kstate *state)
            state->rule->max_src_conn <
            state->src_node->conn) {
                counter_u64_add(V_pf_status.lcounters[LCNT_SRCCONN], 1);
-               bad++;
+               limited = true;
        }
 
        if (state->rule->max_src_conn_rate.limit &&
            pf_check_threshold(&state->src_node->conn_rate)) {
                counter_u64_add(V_pf_status.lcounters[LCNT_SRCCONNRATE], 1);
-               bad++;
+               limited = true;
        }
 
-       if (!bad)
-               return (0);
+       if (!limited)
+               goto done;
 
        /* Kill this state. */
        state->timeout = PFTM_PURGE;
        pf_set_protostate(state, PF_PEER_BOTH, TCPS_CLOSED);
 
        if (state->rule->overload_tbl == NULL)
-               return (1);
+               goto done;
 
        /* Schedule overloading and flushing task. */
        pfoe = malloc(sizeof(*pfoe), M_PFTEMP, M_NOWAIT);
        if (pfoe == NULL)
-               return (1);     /* too bad :( */
+               goto done;  /* too bad :( */
 
        bcopy(&state->src_node->addr, &pfoe->addr, sizeof(pfoe->addr));
        pfoe->af = state->key[PF_SK_WIRE]->af;
@@ -865,7 +862,9 @@ pf_src_connlimit(struct pf_kstate *state)
        PF_OVERLOADQ_UNLOCK();
        taskqueue_enqueue(taskqueue_swi, &V_pf_overloadtask);
 
-       return (1);
+done:
+       PF_SRC_NODE_UNLOCK(state->src_node);
+       return (limited);
 }
 
 static void
@@ -962,8 +961,7 @@ pf_overload_task(void *v, int pending)
 }
 
 /*
- * Can return locked on failure, so that we can consistently
- * allocate and insert a new one.
+ * If found, the node is returned locked. If not, the row stays locked only if returnlocked.
  */
 struct pf_ksrc_node *
 pf_find_src_node(struct pf_addr *src, struct pf_krule *rule, sa_family_t af,
@@ -981,15 +979,34 @@ pf_find_src_node(struct pf_addr *src, struct pf_krule 
*rule, sa_family_t af,
                    (af == AF_INET6 && bcmp(&n->addr, src, sizeof(*src)) == 0)))
                        break;
 
-       if (n != NULL) {
-               n->states++;
-               PF_HASHROW_UNLOCK(*sh);
-       } else if (returnlocked == false)
+       if (n == NULL && !returnlocked)
                PF_HASHROW_UNLOCK(*sh);
 
        return (n);
 }
 
+bool
+pf_src_node_exists(struct pf_ksrc_node **sn, struct pf_srchash *sh)
+{
+       struct pf_ksrc_node     *cur;
+
+       if ((*sn) == NULL)
+               return (false);
+
+       KASSERT(sh != NULL, ("%s: sh is NULL", __func__));
+
+       counter_u64_add(V_pf_status.scounters[SCNT_SRC_NODE_SEARCH], 1);
+       PF_HASHROW_LOCK(sh);
+       LIST_FOREACH(cur, &(sh->nodes), entry) {
+               if (cur == (*sn) &&
+                   cur->expire != 1) /* Ignore nodes being killed */
+                       return (true);
+       }
+       PF_HASHROW_UNLOCK(sh);
+       (*sn) = NULL;
+       return (false);
+}
+
 static void
 pf_free_src_node(struct pf_ksrc_node *sn)
 {
@@ -1002,33 +1019,33 @@ pf_free_src_node(struct pf_ksrc_node *sn)
 }
 
 static u_short
-pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_krule *rule,
-    struct pf_addr *src, sa_family_t af)
+pf_insert_src_node(struct pf_ksrc_node **sn, struct pf_srchash **sh,
+    struct pf_krule *rule, struct pf_addr *src, sa_family_t af)
 {
        u_short                  reason = 0;
-       struct pf_srchash       *sh = NULL;
 
        KASSERT((rule->rule_flag & PFRULE_SRCTRACK ||
            rule->rpool.opts & PF_POOL_STICKYADDR),
            ("%s for non-tracking rule %p", __func__, rule));
 
+       /*
+        * Request the sh to always be locked, as we might insert a new sn.
+        */
        if (*sn == NULL)
-               *sn = pf_find_src_node(src, rule, af, &sh, true);
+               *sn = pf_find_src_node(src, rule, af, sh, true);
 
        if (*sn == NULL) {
-               PF_HASHROW_ASSERT(sh);
+               PF_HASHROW_ASSERT(*sh);
 
                if (rule->max_src_nodes &&
                    counter_u64_fetch(rule->src_nodes) >= rule->max_src_nodes) {
                        counter_u64_add(V_pf_status.lcounters[LCNT_SRCNODES], 
1);
-                       PF_HASHROW_UNLOCK(sh);
                        reason = PFRES_SRCLIMIT;
                        goto done;
                }
 
                (*sn) = uma_zalloc(V_pf_sources_z, M_NOWAIT | M_ZERO);
                if ((*sn) == NULL) {
-                       PF_HASHROW_UNLOCK(sh);
                        reason = PFRES_MEMORY;
                        goto done;
                }
@@ -1039,7 +1056,6 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct 
pf_krule *rule,
 
                        if ((*sn)->bytes[i] == NULL || (*sn)->packets[i] == 
NULL) {
                                pf_free_src_node(*sn);
-                               PF_HASHROW_UNLOCK(sh);
                                reason = PFRES_MEMORY;
                                goto done;
                        }
@@ -1050,18 +1066,16 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct 
pf_krule *rule,
                    rule->max_src_conn_rate.seconds);
 
                MPASS((*sn)->lock == NULL);
-               (*sn)->lock = &sh->lock;
+               (*sn)->lock = &(*sh)->lock;
 
                (*sn)->af = af;
                (*sn)->rule = rule;
                PF_ACPY(&(*sn)->addr, src, af);
-               LIST_INSERT_HEAD(&sh->nodes, *sn, entry);
+               LIST_INSERT_HEAD(&(*sh)->nodes, *sn, entry);
                (*sn)->creation = time_uptime;
                (*sn)->ruletype = rule->action;
-               (*sn)->states = 1;
                if ((*sn)->rule != NULL)
                        counter_u64_add((*sn)->rule->src_nodes, 1);
-               PF_HASHROW_UNLOCK(sh);
                counter_u64_add(V_pf_status.scounters[SCNT_SRC_NODE_INSERT], 1);
        } else {
                if (rule->max_src_states &&
@@ -1073,6 +1087,12 @@ pf_insert_src_node(struct pf_ksrc_node **sn, struct 
pf_krule *rule,
                }
        }
 done:
+       if (reason == 0)
+               (*sn)->states++;
+       else
+               (*sn) = NULL;
+
+       PF_HASHROW_UNLOCK(*sh);
        return (reason);
 }
 
@@ -4880,7 +4900,6 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm,
        struct pf_kruleset      *ruleset = NULL;
        struct pf_krule_slist    match_rules;
        struct pf_krule_item    *ri;
-       struct pf_ksrc_node     *nsn = NULL;
        struct tcphdr           *th = &pd->hdr.tcp;
        struct pf_state_key     *sk = NULL, *nk = NULL;
        u_short                  reason, transerror;
@@ -4960,8 +4979,8 @@ pf_test_rule(struct pf_krule **rm, struct pf_kstate **sm,
        r = TAILQ_FIRST(pf_main_ruleset.rules[PF_RULESET_FILTER].active.ptr);
 
        /* check packet for BINAT/NAT/RDR */
-       transerror = pf_get_translation(pd, pd->off, &nsn, &sk,
-           &nk, saddr, daddr, sport, dport, anchor_stack, &nr, &udp_mapping);
+       transerror = pf_get_translation(pd, pd->off, &sk, &nk, saddr, daddr,
+           sport, dport, anchor_stack, &nr, &udp_mapping);
        switch (transerror) {
        default:
                /* A translation error occurred. */
@@ -5290,7 +5309,7 @@ nextrule:
           (!state_icmp && (r->keep_state || nr != NULL ||
            (pd->flags & PFDESC_TCP_NORM)))) {
                int action;
-               action = pf_create_state(r, nr, a, pd, nsn, nk, sk,
+               action = pf_create_state(r, nr, a, pd, nk, sk,
                    sport, dport, &rewrite, sm, tag, bproto_sum, bip_sum,
                    &match_rules, udp_mapping);
                if (action != PF_PASS) {
@@ -5345,14 +5364,16 @@ cleanup:
 
 static int
 pf_create_state(struct pf_krule *r, struct pf_krule *nr, struct pf_krule *a,
-    struct pf_pdesc *pd, struct pf_ksrc_node *nsn, struct pf_state_key *nk,
-    struct pf_state_key *sk, u_int16_t sport,
-    u_int16_t dport, int *rewrite, struct pf_kstate **sm,
+    struct pf_pdesc *pd, struct pf_state_key *nk, struct pf_state_key *sk,
+    u_int16_t sport, u_int16_t dport, int *rewrite, struct pf_kstate **sm,
     int tag, u_int16_t bproto_sum, u_int16_t bip_sum,
     struct pf_krule_slist *match_rules, struct pf_udp_mapping *udp_mapping)
 {
        struct pf_kstate        *s = NULL;
        struct pf_ksrc_node     *sn = NULL;
+       struct pf_srchash       *snh = NULL;
+       struct pf_ksrc_node     *nsn = NULL;
+       struct pf_srchash       *nsnh = NULL;
        struct tcphdr           *th = &pd->hdr.tcp;
        u_int16_t                mss = V_tcp_mssdflt;
        u_short                  reason, sn_reason;
@@ -5368,13 +5389,13 @@ pf_create_state(struct pf_krule *r, struct pf_krule 
*nr, struct pf_krule *a,
        /* src node for filter rule */
        if ((r->rule_flag & PFRULE_SRCTRACK ||
            r->rpool.opts & PF_POOL_STICKYADDR) &&
-           (sn_reason = pf_insert_src_node(&sn, r, pd->src, pd->af)) != 0) {
+           (sn_reason = pf_insert_src_node(&sn, &snh, r, pd->src, pd->af)) != 
0) {
                REASON_SET(&reason, sn_reason);
                goto csfailed;
        }
        /* src node for translation rule */
        if (nr != NULL && (nr->rpool.opts & PF_POOL_STICKYADDR) &&
-           (sn_reason = pf_insert_src_node(&nsn, nr, &sk->addr[pd->sidx],
+           (sn_reason = pf_insert_src_node(&nsn, &nsnh, nr, 
&sk->addr[pd->sidx],
            pd->af)) != 0 ) {
                REASON_SET(&reason, sn_reason);
                goto csfailed;
@@ -5468,20 +5489,13 @@ pf_create_state(struct pf_krule *r, struct pf_krule 
*nr, struct pf_krule *a,
        if (r->rt) {
                /* pf_map_addr increases the reason counters */
                if ((reason = pf_map_addr_sn(pd->af, r, pd->src, &s->rt_addr,
-                   &s->rt_kif, NULL, &sn)) != 0)
+                   &s->rt_kif, NULL, &sn, &snh)) != 0)
                        goto csfailed;
                s->rt = r->rt;
        }
 
        s->creation = s->expire = pf_get_uptime();
 
-       if (sn != NULL)
-               s->src_node = sn;
-       if (nsn != NULL) {
-               /* XXX We only modify one side for now. */
-               PF_ACPY(&nsn->raddr, &nk->addr[1], pd->af);
-               s->nat_src_node = nsn;
-       }
        if (pd->proto == IPPROTO_TCP) {
                if (s->state_flags & PFSTATE_SCRUB_TCP &&
                    pf_normalize_tcp_init(pd, th, &s->src, &s->dst)) {
@@ -5528,6 +5542,20 @@ pf_create_state(struct pf_krule *r, struct pf_krule *nr, 
struct pf_krule *a,
        } else
                *sm = s;
 
+       /*
+        * Lock order is important: first state, then source node.
+        */
+       if (pf_src_node_exists(&sn, snh)) {
+               s->src_node = sn;
+               PF_HASHROW_UNLOCK(snh);
+       }
+       if (pf_src_node_exists(&nsn, nsnh)) {
+               /* XXX We only modify one side for now. */
+               PF_ACPY(&nsn->raddr, &nk->addr[1], pd->af);
+               s->nat_src_node = nsn;
+               PF_HASHROW_UNLOCK(nsnh);
+       }
+
        if (tag > 0)
                s->tag = tag;
        if (pd->proto == IPPROTO_TCP && (th->th_flags & (TH_SYN|TH_ACK)) ==
@@ -5578,26 +5606,24 @@ csfailed:
        uma_zfree(V_pf_state_key_z, sk);
        uma_zfree(V_pf_state_key_z, nk);
 
-       if (sn != NULL) {
-               PF_SRC_NODE_LOCK(sn);
+       if (pf_src_node_exists(&sn, snh)) {
                if (--sn->states == 0 && sn->expire == 0) {
                        pf_unlink_src_node(sn);
-                       uma_zfree(V_pf_sources_z, sn);
+                       pf_free_src_node(sn);
                        counter_u64_add(
                            V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
                }
-               PF_SRC_NODE_UNLOCK(sn);
+               PF_HASHROW_UNLOCK(snh);
        }
 
-       if (nsn != sn && nsn != NULL) {
-               PF_SRC_NODE_LOCK(nsn);
+       if (sn != nsn && pf_src_node_exists(&nsn, nsnh)) {
                if (--nsn->states == 0 && nsn->expire == 0) {
                        pf_unlink_src_node(nsn);
-                       uma_zfree(V_pf_sources_z, nsn);
+                       pf_free_src_node(nsn);
                        counter_u64_add(
                            V_pf_status.scounters[SCNT_SRC_NODE_REMOVALS], 1);
                }
-               PF_SRC_NODE_UNLOCK(nsn);
+               PF_HASHROW_UNLOCK(nsnh);
        }
 
 drop:
diff --git a/sys/netpfil/pf/pf_lb.c b/sys/netpfil/pf/pf_lb.c
index 5777cf19b067..e180f87d2998 100644
--- a/sys/netpfil/pf/pf_lb.c
+++ b/sys/netpfil/pf/pf_lb.c
@@ -69,7 +69,7 @@ static struct pf_krule        *pf_match_translation(struct 
pf_pdesc *,
                            struct pf_kanchor_stackframe *);
 static int pf_get_sport(sa_family_t, uint8_t, struct pf_krule *,
     struct pf_addr *, uint16_t, struct pf_addr *, uint16_t, struct pf_addr *,
-    uint16_t *, uint16_t, uint16_t, struct pf_ksrc_node **,
+    uint16_t *, uint16_t, uint16_t, struct pf_ksrc_node **, struct 
pf_srchash**,
     struct pf_udp_mapping **);
 static bool             pf_islinklocal(const sa_family_t, const struct pf_addr 
*);
 
@@ -225,12 +225,11 @@ static int
 pf_get_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
     struct pf_addr *saddr, uint16_t sport, struct pf_addr *daddr,
     uint16_t dport, struct pf_addr *naddr, uint16_t *nport, uint16_t low,
-    uint16_t high, struct pf_ksrc_node **sn,
+    uint16_t high, struct pf_ksrc_node **sn, struct pf_srchash **sh,
     struct pf_udp_mapping **udp_mapping)
 {
        struct pf_state_key_cmp key;
        struct pf_addr          init_addr;
-       struct pf_srchash       *sh = NULL;
 
        bzero(&init_addr, sizeof(init_addr));
 
@@ -255,7 +254,9 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct 
pf_krule *r,
                        /* Try to find a src_node as per pf_map_addr(). */
                        if (*sn == NULL && r->rpool.opts & PF_POOL_STICKYADDR &&
                            (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
-                               *sn = pf_find_src_node(saddr, r, af, &sh, 0);
+                               *sn = pf_find_src_node(saddr, r, af, sh, false);
+                       if (*sn != NULL)
+                               PF_SRC_NODE_UNLOCK(*sn);
                        return (0);
                } else {
                        *udp_mapping = pf_udp_mapping_create(af, saddr, sport, 
&init_addr, 0);
@@ -264,7 +265,7 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct 
pf_krule *r,
                }
        }
 
-       if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn))
+       if (pf_map_addr_sn(af, r, saddr, naddr, NULL, &init_addr, sn, sh))
                goto failed;
 
        if (proto == IPPROTO_ICMP) {
@@ -385,7 +386,8 @@ pf_get_sport(sa_family_t af, u_int8_t proto, struct 
pf_krule *r,
                         * pick a different source address since we're out
                         * of free port choices for the current one.
                         */
-                       if (pf_map_addr_sn(af, r, saddr, naddr, NULL, 
&init_addr, sn))
+                       (*sn) = NULL;
+                       if (pf_map_addr_sn(af, r, saddr, naddr, NULL, 
&init_addr, sn, sh))
                                return (1);
                        break;
                case PF_POOL_NONE:
@@ -414,7 +416,8 @@ static int
 pf_get_mape_sport(sa_family_t af, u_int8_t proto, struct pf_krule *r,
     struct pf_addr *saddr, uint16_t sport, struct pf_addr *daddr,
     uint16_t dport, struct pf_addr *naddr, uint16_t *nport,
-    struct pf_ksrc_node **sn, struct pf_udp_mapping **udp_mapping)
+    struct pf_ksrc_node **sn, struct pf_srchash **sh,
+    struct pf_udp_mapping **udp_mapping)
 {
        uint16_t psmask, low, highmask;
        uint16_t i, ahigh, cut;
@@ -434,13 +437,13 @@ pf_get_mape_sport(sa_family_t af, u_int8_t proto, struct 
pf_krule *r,
        for (i = cut; i <= ahigh; i++) {
                low = (i << ashift) | psmask;
                if (!pf_get_sport(af, proto, r, saddr, sport, daddr, dport,
-                   naddr, nport, low, low | highmask, sn, udp_mapping))
+                   naddr, nport, low, low | highmask, sn, sh, udp_mapping))
                        return (0);
        }
        for (i = cut - 1; i > 0; i--) {
                low = (i << ashift) | psmask;
                if (!pf_get_sport(af, proto, r, saddr, sport, daddr, dport,
-                   naddr, nport, low, low | highmask, sn, udp_mapping))
+                   naddr, nport, low, low | highmask, sn, sh, udp_mapping))
                        return (0);
        }
        return (1);
@@ -623,23 +626,31 @@ done_pool_mtx:
 u_short
 pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct pf_addr *saddr,
     struct pf_addr *naddr, struct pfi_kkif **nkif, struct pf_addr *init_addr,
-    struct pf_ksrc_node **sn)
+    struct pf_ksrc_node **sn, struct pf_srchash **sh)
 {
        u_short                  reason = 0;
        struct pf_kpool         *rpool = &r->rpool;
-       struct pf_srchash       *sh = NULL;
 
-       /* Try to find a src_node if none was given and this
-          is a sticky-address rule. */
-       if (*sn == NULL && r->rpool.opts & PF_POOL_STICKYADDR &&
-           (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
-               *sn = pf_find_src_node(saddr, r, af, &sh, false);
+       /*
+        * Try to find a src_node if none was given and this is
+        * a sticky-address rule. Request the sh to be unlocked if
+        * sn was not found, as here we never insert a new sn.
+        */
+       if (*sn == NULL) {
+               if (r->rpool.opts & PF_POOL_STICKYADDR &&
+                   (r->rpool.opts & PF_POOL_TYPEMASK) != PF_POOL_NONE)
+                       *sn = pf_find_src_node(saddr, r, af, sh, false);
+       } else {
+               pf_src_node_exists(sn, *sh);
+       }
 
        /* If a src_node was found or explicitly given and it has a non-zero
           route address, use this address. A zeroed address is found if the
           src node was created just a moment ago in pf_create_state and it
           needs to be filled in with routing decision calculated here. */
        if (*sn != NULL && !PF_AZERO(&(*sn)->raddr, af)) {
+               PF_SRC_NODE_LOCK_ASSERT(*sn);
+
                /* If the supplied address is the same as the current one we've
                 * been asked before, so tell the caller that there's no other
                 * address to be had. */
@@ -673,6 +684,8 @@ pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct 
pf_addr *saddr,
        }
 
        if (*sn != NULL) {
+               PF_SRC_NODE_LOCK_ASSERT(*sn);
+
                PF_ACPY(&(*sn)->raddr, naddr, af);
                if (nkif)
                        (*sn)->rkif = *nkif;
@@ -688,6 +701,9 @@ pf_map_addr_sn(sa_family_t af, struct pf_krule *r, struct 
pf_addr *saddr,
        }
 
 done:
+       if ((*sn) != NULL)
+               PF_SRC_NODE_UNLOCK(*sn);
+
        if (reason) {
                counter_u64_add(V_pf_status.counters[reason], 1);
        }
@@ -697,14 +713,15 @@ done:
 
 u_short
 pf_get_translation(struct pf_pdesc *pd, int off,
-    struct pf_ksrc_node **sn, struct pf_state_key **skp,
-    struct pf_state_key **nkp, struct pf_addr *saddr, struct pf_addr *daddr,
-    uint16_t sport, uint16_t dport, struct pf_kanchor_stackframe *anchor_stack,
-    struct pf_krule **rp,
+    struct pf_state_key **skp, struct pf_state_key **nkp, struct pf_addr 
*saddr,
+    struct pf_addr *daddr, uint16_t sport, uint16_t dport,
+    struct pf_kanchor_stackframe *anchor_stack, struct pf_krule **rp,
     struct pf_udp_mapping **udp_mapping)
 {
        struct pf_krule *r = NULL;
        struct pf_addr  *naddr;
+       struct pf_ksrc_node     *sn = NULL;
+       struct pf_srchash       *sh = NULL;
        uint16_t        *nportp;
        uint16_t         low, high;
        u_short          reason;
@@ -765,7 +782,8 @@ pf_get_translation(struct pf_pdesc *pd, int off,
                }
                if (r->rpool.mape.offset > 0) {
                        if (pf_get_mape_sport(pd->af, pd->proto, r, saddr,
-                           sport, daddr, dport, naddr, nportp, sn, 
udp_mapping)) {
+                           sport, daddr, dport, naddr, nportp, &sn, &sh,
+                           udp_mapping)) {
                                DPFPRINTF(PF_DEBUG_MISC,
                                    ("pf: MAP-E port allocation (%u/%u/%u)"
                                    " failed\n",
@@ -776,7 +794,8 @@ pf_get_translation(struct pf_pdesc *pd, int off,
                                goto notrans;
                        }
                } else if (pf_get_sport(pd->af, pd->proto, r, saddr, sport,
-                   daddr, dport, naddr, nportp, low, high, sn, udp_mapping)) {
+                   daddr, dport, naddr, nportp, low, high, &sn, &sh,
+                   udp_mapping)) {
                        DPFPRINTF(PF_DEBUG_MISC,
                            ("pf: NAT proxy port allocation (%u-%u) failed\n",
                            r->rpool.proxy_port[0], r->rpool.proxy_port[1]));
@@ -863,7 +882,7 @@ pf_get_translation(struct pf_pdesc *pd, int off,
                int tries;
                uint16_t cut, low, high, nport;
 
-               reason = pf_map_addr_sn(pd->af, r, saddr, naddr, NULL, NULL, 
sn);
+               reason = pf_map_addr_sn(pd->af, r, saddr, naddr, NULL, NULL, 
&sn, &sh);
                if (reason != 0)
                        goto notrans;
                if ((r->rpool.opts & PF_POOL_TYPEMASK) == PF_POOL_BITMASK)
@@ -970,7 +989,6 @@ notrans:
        uma_zfree(V_pf_state_key_z, *nkp);
        uma_zfree(V_pf_state_key_z, *skp);
        *skp = *nkp = NULL;
-       *sn = NULL;
 
        return (reason);
 }
diff --git a/tests/sys/netpfil/pf/src_track.sh 
b/tests/sys/netpfil/pf/src_track.sh
index 5349e61ec76b..620f1353f9fe 100755
--- a/tests/sys/netpfil/pf/src_track.sh
+++ b/tests/sys/netpfil/pf/src_track.sh
@@ -217,28 +217,31 @@ max_src_states_rule_body()
        # 2 connections from host ::1 matching rule_A will be allowed, 1 will 
fail to create a state.
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4211 
--fromaddr 2001:db8:44::1
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4212 
--fromaddr 2001:db8:44::1
-       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4213 
--fromaddr 2001:db8:44::1
+       ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4213 
--fromaddr 2001:db8:44::1
+       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4214 
--fromaddr 2001:db8:44::1
 
        # 2 connections from host ::1 matching rule_B will be allowed, 1 will 
fail to create a state.
        # Limits from rule_A don't interfere with rule_B.
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4221 
--fromaddr 2001:db8:44::1
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4222 
--fromaddr 2001:db8:44::1
-       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4223 
--fromaddr 2001:db8:44::1
+       ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4223 
--fromaddr 2001:db8:44::1
+       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4224 
--fromaddr 2001:db8:44::1
 
        # 2 connections from host ::2 matching rule_B will be allowed, 1 will 
fail to create a state.
        # Limits for host ::1 will not interfere with host ::2.
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4224 
--fromaddr 2001:db8:44::2
        ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4225 
--fromaddr 2001:db8:44::2
-       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4226 
--fromaddr 2001:db8:44::2
+       ping_server_check_reply exit:0 --ping-type=tcp3way --send-sport=4226 
--fromaddr 2001:db8:44::2
+       ping_server_check_reply exit:1 --ping-type=tcp3way --send-sport=4227 
--fromaddr 2001:db8:44::2
 
        # We will check the resulting source nodes, though.
        # Order of source nodes in output is not guaranteed, find each one 
separately.
        nodes=$(mktemp) || exit 1
        jexec router pfctl -qvsS | normalize_pfctl_s > $nodes
        for node_regexp in \
-               '2001:db8:44::1 -> :: \( states 2, connections 2, rate 
[0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 3$' \
-               '2001:db8:44::1 -> :: \( states 2, connections 2, rate 
[0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 4$' \
-               '2001:db8:44::2 -> :: \( states 2, connections 2, rate 
[0-9/\.]+s \) age [0-9:]+, 6 pkts, [0-9]+ bytes, filter rule 4$' \
+               '2001:db8:44::1 -> :: \( states 3, connections 3, rate 
[0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 3$' \
+               '2001:db8:44::1 -> :: \( states 3, connections 3, rate 
[0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 4$' \
+               '2001:db8:44::2 -> :: \( states 3, connections 3, rate 
[0-9/\.]+s \) age [0-9:]+, 9 pkts, [0-9]+ bytes, filter rule 4$' \
        ; do
                grep -qE "$node_regexp" $nodes || atf_fail "Source nodes not 
matching expected output"
        done

Reply via email to