OK bluhm@

On Sun, Jan 03, 2016 at 04:32:58PM +0100, Alexandr Nedvedicky wrote:
> Index: kern/uipc_mbuf.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v
> retrieving revision 1.216
> diff -u -p -r1.216 uipc_mbuf.c
> --- kern/uipc_mbuf.c  23 Dec 2015 21:04:55 -0000      1.216
> +++ kern/uipc_mbuf.c  3 Jan 2016 14:40:29 -0000
> @@ -72,6 +72,8 @@
>   * Research Laboratory (NRL).
>   */
>  
> +#include "pf.h"
> +
>  #include <sys/param.h>
>  #include <sys/systm.h>
>  #include <sys/malloc.h>
> @@ -85,6 +87,9 @@
>  #include <sys/socket.h>
>  #include <sys/socketvar.h>
>  #include <net/if.h>
> +#if NPF > 0
> +#include <net/pfvar.h>
> +#endif       /* NPF > 0 */
>  
>  
>  #include <uvm/uvm_extern.h>
> @@ -261,6 +266,10 @@ m_resethdr(struct mbuf *m)
>       /* delete all mbuf tags to reset the state */
>       m_tag_delete_chain(m);
>  
> +#if NPF > 0
> +     pf_pkt_unlink_state_key(m);
> +#endif       /* NPF > 0 */
> +
>       /* like m_inithdr(), but keep any associated data and mbufs */
>       memset(&m->m_pkthdr, 0, sizeof(m->m_pkthdr));
>       m->m_pkthdr.pf.prio = IFQ_DEFPRIO;
> @@ -350,8 +359,12 @@ m_free(struct mbuf *m)
>               if (n)
>                       n->m_flags |= M_ZEROIZE;
>       }
> -     if (m->m_flags & M_PKTHDR)
> +     if (m->m_flags & M_PKTHDR) {
>               m_tag_delete_chain(m);
> +#if NPF > 0
> +             pf_pkt_unlink_state_key(m);
> +#endif       /* NPF > 0 */
> +     }
>       if (m->m_flags & M_EXT)
>               m_extfree(m);
>  
> @@ -1201,6 +1214,10 @@ m_dup_pkthdr(struct mbuf *to, struct mbu
>       to->m_flags = (to->m_flags & (M_EXT | M_EXTWR));
>       to->m_flags |= (from->m_flags & M_COPYFLAGS);
>       to->m_pkthdr = from->m_pkthdr;
> +
> +#if NPF > 0
> +     pf_pkt_state_key_ref(to);
> +#endif /* NPF > 0 */
>  
>       SLIST_INIT(&to->m_pkthdr.ph_tags);
>  
> Index: sys/mbuf.h
> ===================================================================
> RCS file: /cvs/src/sys/sys/mbuf.h,v
> retrieving revision 1.205
> diff -u -p -r1.205 mbuf.h
> --- sys/mbuf.h        21 Nov 2015 11:46:25 -0000      1.205
> +++ sys/mbuf.h        3 Jan 2016 14:40:30 -0000
> @@ -316,6 +316,7 @@ struct mbuf {
>       (to)->m_pkthdr = (from)->m_pkthdr;                              \
>       (from)->m_flags &= ~M_PKTHDR;                                   \
>       SLIST_INIT(&(from)->m_pkthdr.ph_tags);                          \
> +     (from)->m_pkthdr.pf.statekey = NULL;                            \
>  } while (/* CONSTCOND */ 0)
>  
>  /*
> Index: net/pf.c
> ===================================================================
> RCS file: /cvs/src/sys/net/pf.c,v
> retrieving revision 1.962
> diff -u -p -r1.962 pf.c
> --- net/pf.c  23 Dec 2015 21:04:55 -0000      1.962
> +++ net/pf.c  3 Jan 2016 14:40:33 -0000
> @@ -231,6 +231,11 @@ int                       pf_step_out_of_anchor(int *, 
> stru
>  void                  pf_counters_inc(int, struct pf_pdesc *,
>                           struct pf_state *, struct pf_rule *,
>                           struct pf_rule *);
> +void                  pf_state_key_link(struct pf_state_key *,
> +                         struct pf_state_key *);
> +void                  pf_inpcb_unlink_state_key(struct inpcb *);
> +void                  pf_state_key_unlink_reverse(struct pf_state_key *);
> +
>  #if NPFLOG > 0
>  void                  pf_log_matches(struct pf_pdesc *, struct pf_rule *,
>                           struct pf_rule *, struct pf_ruleset *,
> @@ -694,8 +699,11 @@ pf_state_key_attach(struct pf_state_key 
>                       }
>               pool_put(&pf_state_key_pl, sk);
>               s->key[idx] = cur;
> -     } else
> +     } else {
>               s->key[idx] = sk;
> +             /* PF owns the single reference set up by PF_REF_INIT(), */
> +             /* pf_state_key_detach() drops exactly that one. */
> +     }
>  
>       if ((si = pool_get(&pf_state_item_pl, PR_NOWAIT)) == NULL) {
>               pf_state_key_detach(s, idx);
> @@ -732,6 +740,7 @@ void
>  pf_state_key_detach(struct pf_state *s, int idx)
>  {
>       struct pf_state_item    *si;
> +     struct pf_state_key     *sk;
>  
>       if (s->key[idx] == NULL)
>               return;
> @@ -745,15 +754,15 @@ pf_state_key_detach(struct pf_state *s, 
>               pool_put(&pf_state_item_pl, si);
>       }
>  
> -     if (TAILQ_EMPTY(&s->key[idx]->states)) {
> -             RB_REMOVE(pf_state_tree, &pf_statetbl, s->key[idx]);
> -             if (s->key[idx]->reverse)
> -                     s->key[idx]->reverse->reverse = NULL;
> -             if (s->key[idx]->inp)
> -                     s->key[idx]->inp->inp_pf_sk = NULL;
> -             pool_put(&pf_state_key_pl, s->key[idx]);
> -     }
> +     sk = s->key[idx];
>       s->key[idx] = NULL;
> +     if (TAILQ_EMPTY(&sk->states)) {
> +             RB_REMOVE(pf_state_tree, &pf_statetbl, sk);
> +             sk->removed = 1;
> +             pf_state_key_unlink_reverse(sk);
> +             pf_inpcb_unlink_state_key(sk->inp);
> +             pf_state_key_unref(sk);
> +     }
>  }
>  
>  struct pf_state_key *
> @@ -840,6 +849,8 @@ pf_state_key_setup(struct pf_pdesc *pd, 
>       sk1->proto = pd->proto;
>       sk1->af = pd->af;
>       sk1->rdomain = pd->rdomain;
> +     PF_REF_INIT(sk1->refcnt);
> +     sk1->removed = 0;
>       if (rtableid >= 0)
>               wrdom = rtable_l2(rtableid);
>  
> @@ -871,6 +882,8 @@ pf_state_key_setup(struct pf_pdesc *pd, 
>                       sk2->proto = pd->proto;
>               sk2->af = pd->naf;
>               sk2->rdomain = wrdom;
> +             PF_REF_INIT(sk2->refcnt);
> +             sk2->removed = 0;
>       } else
>               sk2 = sk1;
>  
> @@ -986,7 +999,7 @@ struct pf_state *
>  pf_find_state(struct pfi_kif *kif, struct pf_state_key_cmp *key, u_int dir,
>      struct mbuf *m)
>  {
> -     struct pf_state_key     *sk;
> +     struct pf_state_key     *sk, *pkt_sk, *inp_sk;
>       struct pf_state_item    *si;
>  
>       pf_status.fcounters[FCNT_STATE_SEARCH]++;
> @@ -996,31 +1009,47 @@ pf_find_state(struct pfi_kif *kif, struc
>               addlog("\n");
>       }
>  
> -     if (dir == PF_OUT && m->m_pkthdr.pf.statekey &&
> -         m->m_pkthdr.pf.statekey->reverse)
> -             sk = m->m_pkthdr.pf.statekey->reverse;
> -     else if (dir == PF_OUT && m->m_pkthdr.pf.inp &&
> -         m->m_pkthdr.pf.inp->inp_pf_sk)
> -             sk = m->m_pkthdr.pf.inp->inp_pf_sk;
> -     else {
> +     inp_sk = NULL;
> +     pkt_sk = NULL;
> +     sk = NULL;
> +     if (dir == PF_OUT) {
> +             /* first if block deals with outbound forwarded packet */
> +             pkt_sk = m->m_pkthdr.pf.statekey;
> +             /* valid pkt_sk without reverse is linked after RB_FIND() */
> +             if (!pf_state_key_isvalid(pkt_sk)) {
> +                     pf_pkt_unlink_state_key(m);
> +                     pkt_sk = NULL;
> +             } else if (pf_state_key_isvalid(pkt_sk->reverse)) {
> +                     sk = pkt_sk->reverse;
> +             }
> +
> +             if (pkt_sk == NULL) {
> +                     /* here we deal with local outbound packet */
> +                     if (m->m_pkthdr.pf.inp != NULL) {
> +                             inp_sk = m->m_pkthdr.pf.inp->inp_pf_sk;
> +                             if (pf_state_key_isvalid(inp_sk))
> +                                     sk = inp_sk;
> +                             else
> +                                     pf_inpcb_unlink_state_key(
> +                                         m->m_pkthdr.pf.inp);
> +                     }
> +             }
> +     }
> +
> +     if (sk == NULL) {
>               if ((sk = RB_FIND(pf_state_tree, &pf_statetbl,
>                   (struct pf_state_key *)key)) == NULL)
>                       return (NULL);
> -             if (dir == PF_OUT && m->m_pkthdr.pf.statekey &&
> -                 pf_compare_state_keys(m->m_pkthdr.pf.statekey, sk,
> -                 kif, dir) == 0) {
> -                     m->m_pkthdr.pf.statekey->reverse = sk;
> -                     sk->reverse = m->m_pkthdr.pf.statekey;
> -             } else if (dir == PF_OUT && m->m_pkthdr.pf.inp && !sk->inp) {
> -                     m->m_pkthdr.pf.inp->inp_pf_sk = sk;
> -                     sk->inp = m->m_pkthdr.pf.inp;
> -             }
> +             if (dir == PF_OUT && pkt_sk &&
> +                 pf_compare_state_keys(pkt_sk, sk, kif, dir) == 0)
> +                     pf_state_key_link(sk, pkt_sk);
> +             else if (dir == PF_OUT)
> +                     pf_inp_link(m, m->m_pkthdr.pf.inp);
>       }
>  
> -     if (dir == PF_OUT) {
> -             m->m_pkthdr.pf.statekey = NULL;
> -             m->m_pkthdr.pf.inp = NULL;
> -     }
> +     /* remove firewall data from outbound packet */
> +     if (dir == PF_OUT)
> +             pf_pkt_addr_changed(m);
>  
>       /* list is sorted, if-bound states before floating ones */
>       TAILQ_FOREACH(si, &sk->states, entry)
> @@ -6531,11 +6560,13 @@ done:
>       if (action == PF_PASS && qid)
>               pd.m->m_pkthdr.pf.qid = qid;
>       if (pd.dir == PF_IN && s && s->key[PF_SK_STACK])
> -             pd.m->m_pkthdr.pf.statekey = s->key[PF_SK_STACK];
> +             pd.m->m_pkthdr.pf.statekey =
> +                 pf_state_key_ref(s->key[PF_SK_STACK]);
>       if (pd.dir == PF_OUT &&
>           pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
>           s && s->key[PF_SK_STACK] && !s->key[PF_SK_STACK]->inp) {
> -             pd.m->m_pkthdr.pf.inp->inp_pf_sk = s->key[PF_SK_STACK];
> +             pd.m->m_pkthdr.pf.inp->inp_pf_sk =
> +                 pf_state_key_ref(s->key[PF_SK_STACK]);
>               s->key[PF_SK_STACK]->inp = pd.m->m_pkthdr.pf.inp;
>       }
>  
> @@ -6706,7 +6737,7 @@ pf_cksum(struct pf_pdesc *pd, struct mbu
>  void
>  pf_pkt_addr_changed(struct mbuf *m)
>  {
> -     m->m_pkthdr.pf.statekey = NULL;
> +     pf_pkt_unlink_state_key(m);
>       m->m_pkthdr.pf.inp = NULL;
>  }
>  
> @@ -6714,25 +6745,40 @@ struct inpcb *
>  pf_inp_lookup(struct mbuf *m)
>  {
>       struct inpcb *inp = NULL;
> +     struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
>  
> -     if (m->m_pkthdr.pf.statekey) {
> +     if (!pf_state_key_isvalid(sk))
> +             pf_pkt_unlink_state_key(m);
> +     else
>               inp = m->m_pkthdr.pf.statekey->inp;
> -             if (inp && inp->inp_pf_sk)
> -                     KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
> -     }
> +
> +     if (inp && inp->inp_pf_sk)
> +             KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
> +
>       return (inp);
>  }
>  
>  void
>  pf_inp_link(struct mbuf *m, struct inpcb *inp)
>  {
> -     if (m->m_pkthdr.pf.statekey && inp &&
> -         !m->m_pkthdr.pf.statekey->inp && !inp->inp_pf_sk) {
> -             m->m_pkthdr.pf.statekey->inp = inp;
> -             inp->inp_pf_sk = m->m_pkthdr.pf.statekey;
> +     struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
> +
> +     if (!pf_state_key_isvalid(sk)) {
> +             pf_pkt_unlink_state_key(m);
> +             return;
> +     }
> +
> +     /*
> +      * We don't need to grab the PF lock here. In the worst case we link
> +      * inp to a state key another thread is just marking as removed;
> +      * pf_find_state() drops such an invalid inp_pf_sk reference later.
> +      */
> +     if (inp && !sk->inp && !inp->inp_pf_sk) {
> +             sk->inp = inp;
> +             inp->inp_pf_sk = pf_state_key_ref(sk);
>       }
>       /* The statekey has finished finding the inp, it is no longer needed. */
> -     m->m_pkthdr.pf.statekey = NULL;
> +     pf_pkt_unlink_state_key(m);
>  }
>  
>  void
> @@ -6740,10 +6786,21 @@ pf_inp_unlink(struct inpcb *inp)
>  {
>       if (inp->inp_pf_sk) {
>               inp->inp_pf_sk->inp = NULL;
> -             inp->inp_pf_sk = NULL;
> +             pf_inpcb_unlink_state_key(inp);
>       }
>  }
>  
> +void
> +pf_state_key_link(struct pf_state_key *sk, struct pf_state_key *pkt_sk)
> +{
> +     /*
> +      * Assert will not fire as long as we are called by pf_find_state()
> +      */
> +     KASSERT((pkt_sk->reverse == NULL) && (sk->reverse == NULL));
> +     pkt_sk->reverse = pf_state_key_ref(sk);
> +     sk->reverse = pf_state_key_ref(pkt_sk);
> +}
> +
>  #if NPFLOG > 0
>  void
>  pf_log_matches(struct pf_pdesc *pd, struct pf_rule *rm, struct pf_rule *am,
> @@ -6760,3 +6817,66 @@ pf_log_matches(struct pf_pdesc *pd, stru
>                       PFLOG_PACKET(pd, PFRES_MATCH, rm, am, ruleset, ri->r);
>  }
>  #endif       /* NPFLOG > 0 */
> +
> +struct pf_state_key *
> +pf_state_key_ref(struct pf_state_key *sk)
> +{
> +     if (sk != NULL)
> +             PF_REF_TAKE(sk->refcnt);
> +
> +     return (sk);
> +}
> +
> +void
> +pf_state_key_unref(struct pf_state_key *sk)
> +{
> +     if ((sk != NULL) && PF_REF_RELE(sk->refcnt)) {
> +             /* state key must be removed from tree */
> +             KASSERT(!pf_state_key_isvalid(sk));
> +             /* state key must be unlinked from reverse key */
> +             KASSERT(sk->reverse == NULL);
> +             /* state key must be unlinked from socket */
> +             KASSERT((sk->inp == NULL) || (sk->inp->inp_pf_sk == NULL));
> +             sk->inp = NULL;
> +             pool_put(&pf_state_key_pl, sk);
> +     }
> +}
> +
> +int
> +pf_state_key_isvalid(struct pf_state_key *sk)
> +{
> +     return ((sk != NULL) && (sk->removed == 0));
> +}
> +
> +void
> +pf_pkt_unlink_state_key(struct mbuf *m)
> +{
> +     pf_state_key_unref(m->m_pkthdr.pf.statekey);
> +     m->m_pkthdr.pf.statekey = NULL;
> +}
> +
> +void
> +pf_pkt_state_key_ref(struct mbuf *m)
> +{
> +     pf_state_key_ref(m->m_pkthdr.pf.statekey);
> +}
> +
> +void
> +pf_inpcb_unlink_state_key(struct inpcb *inp)
> +{
> +     if (inp != NULL) {
> +             pf_state_key_unref(inp->inp_pf_sk);
> +             inp->inp_pf_sk = NULL;
> +     }
> +}
> +
> +void
> +pf_state_key_unlink_reverse(struct pf_state_key *sk)
> +{
> +     if ((sk != NULL) && (sk->reverse != NULL)) {
> +             pf_state_key_unref(sk->reverse->reverse);
> +             sk->reverse->reverse = NULL;
> +             pf_state_key_unref(sk->reverse);
> +             sk->reverse = NULL;
> +     }
> +}
> Index: net/pfvar.h
> ===================================================================
> RCS file: /cvs/src/sys/net/pfvar.h,v
> retrieving revision 1.428
> diff -u -p -r1.428 pfvar.h
> --- net/pfvar.h       23 Dec 2015 21:04:55 -0000      1.428
> +++ net/pfvar.h       3 Jan 2016 14:40:36 -0000
> @@ -38,6 +38,9 @@
>  #include <sys/tree.h>
>  #include <sys/rwlock.h>
>  #include <sys/syslimits.h>
> +#include <sys/refcnt.h>
> +
> +#include <netinet/in.h>
>  
>  #include <net/radix.h>
>  #include <net/route.h>
> @@ -55,6 +58,11 @@ struct ip6_hdr;
>  #endif
>  #endif
>  
> +typedef struct refcnt        pf_refcnt_t;
> +#define      PF_REF_INIT(_x) refcnt_init(&(_x))
> +#define      PF_REF_TAKE(_x) refcnt_take(&(_x))
> +#define      PF_REF_RELE(_x) refcnt_rele(&(_x))
> +
>  enum { PF_INOUT, PF_IN, PF_OUT, PF_FWD };
>  enum { PF_PASS, PF_DROP, PF_SCRUB, PF_NOSCRUB, PF_NAT, PF_NONAT,
>         PF_BINAT, PF_NOBINAT, PF_RDR, PF_NORDR, PF_SYNPROXY_DROP, PF_DEFER,
> @@ -696,6 +704,8 @@ struct pf_state_key {
>       struct pf_statelisthead  states;
>       struct pf_state_key     *reverse;
>       struct inpcb            *inp;
> +     pf_refcnt_t              refcnt;
> +     u_int8_t                 removed;
>  };
>  #define PF_REVERSED_KEY(key, family)                         \
>       ((key[PF_SK_WIRE]->af != key[PF_SK_STACK]->af) &&       \
> @@ -1909,7 +1919,12 @@ int                     pf_postprocess_addr(struct 
> pf_sta
>  
>  void                  pf_cksum(struct pf_pdesc *, struct mbuf *);
>  
> -#endif /* _KERNEL */
> +struct pf_state_key  *pf_state_key_ref(struct pf_state_key *);
> +void                  pf_state_key_unref(struct pf_state_key *);
> +int                   pf_state_key_isvalid(struct pf_state_key *);
> +void                  pf_pkt_unlink_state_key(struct mbuf *);
> +void                  pf_pkt_state_key_ref(struct mbuf *);
>  
> +#endif /* _KERNEL */
>  
>  #endif /* _NET_PFVAR_H_ */

Reply via email to