OK bluhm@
On Sun, Jan 03, 2016 at 04:32:58PM +0100, Alexandr Nedvedicky wrote:
> Index: kern/uipc_mbuf.c
> ===================================================================
> RCS file: /cvs/src/sys/kern/uipc_mbuf.c,v
> retrieving revision 1.216
> diff -u -p -r1.216 uipc_mbuf.c
> --- kern/uipc_mbuf.c 23 Dec 2015 21:04:55 -0000 1.216
> +++ kern/uipc_mbuf.c 3 Jan 2016 14:40:29 -0000
> @@ -72,6 +72,8 @@
> * Research Laboratory (NRL).
> */
>
> +#include "pf.h"
> +
> #include <sys/param.h>
> #include <sys/systm.h>
> #include <sys/malloc.h>
> @@ -85,6 +87,9 @@
> #include <sys/socket.h>
> #include <sys/socketvar.h>
> #include <net/if.h>
> +#if NPF > 0
> +#include <net/pfvar.h>
> +#endif /* NPF > 0 */
>
>
> #include <uvm/uvm_extern.h>
> @@ -261,6 +266,10 @@ m_resethdr(struct mbuf *m)
> /* delete all mbuf tags to reset the state */
> m_tag_delete_chain(m);
>
> +#if NPF > 0
> + pf_pkt_unlink_state_key(m);
> +#endif /* NPF > 0 */
> +
> /* like m_inithdr(), but keep any associated data and mbufs */
> memset(&m->m_pkthdr, 0, sizeof(m->m_pkthdr));
> m->m_pkthdr.pf.prio = IFQ_DEFPRIO;
> @@ -350,8 +359,12 @@ m_free(struct mbuf *m)
> if (n)
> n->m_flags |= M_ZEROIZE;
> }
> - if (m->m_flags & M_PKTHDR)
> + if (m->m_flags & M_PKTHDR) {
> m_tag_delete_chain(m);
> +#if NPF > 0
> + pf_pkt_unlink_state_key(m);
> +#endif /* NPF > 0 */
> + }
> if (m->m_flags & M_EXT)
> m_extfree(m);
>
> @@ -1201,6 +1214,10 @@ m_dup_pkthdr(struct mbuf *to, struct mbu
> to->m_flags = (to->m_flags & (M_EXT | M_EXTWR));
> to->m_flags |= (from->m_flags & M_COPYFLAGS);
> to->m_pkthdr = from->m_pkthdr;
> +
> +#if NPF > 0
> + pf_pkt_state_key_ref(to);
> +#endif /* NPF > 0 */
>
> SLIST_INIT(&to->m_pkthdr.ph_tags);
>
> Index: sys/mbuf.h
> ===================================================================
> RCS file: /cvs/src/sys/sys/mbuf.h,v
> retrieving revision 1.205
> diff -u -p -r1.205 mbuf.h
> --- sys/mbuf.h 21 Nov 2015 11:46:25 -0000 1.205
> +++ sys/mbuf.h 3 Jan 2016 14:40:30 -0000
> @@ -316,6 +316,7 @@ struct mbuf {
> (to)->m_pkthdr = (from)->m_pkthdr; \
> (from)->m_flags &= ~M_PKTHDR; \
> SLIST_INIT(&(from)->m_pkthdr.ph_tags); \
> + (from)->m_pkthdr.pf.statekey = NULL; \
> } while (/* CONSTCOND */ 0)
>
> /*
> Index: net/pf.c
> ===================================================================
> RCS file: /cvs/src/sys/net/pf.c,v
> retrieving revision 1.962
> diff -u -p -r1.962 pf.c
> --- net/pf.c 23 Dec 2015 21:04:55 -0000 1.962
> +++ net/pf.c 3 Jan 2016 14:40:33 -0000
> @@ -231,6 +231,11 @@ int pf_step_out_of_anchor(int *,
> stru
> void pf_counters_inc(int, struct pf_pdesc *,
> struct pf_state *, struct pf_rule *,
> struct pf_rule *);
> +void pf_state_key_link(struct pf_state_key *,
> + struct pf_state_key *);
> +void pf_inpcb_unlink_state_key(struct inpcb *);
> +void pf_state_key_unlink_reverse(struct pf_state_key *);
> +
> #if NPFLOG > 0
> void pf_log_matches(struct pf_pdesc *, struct pf_rule *,
> struct pf_rule *, struct pf_ruleset *,
> @@ -694,8 +699,11 @@ pf_state_key_attach(struct pf_state_key
> }
> pool_put(&pf_state_key_pl, sk);
> s->key[idx] = cur;
> - } else
> + } else {
> s->key[idx] = sk;
> + /* need to grab reference for PF */
> + pf_state_key_ref(sk);
> + }
>
> if ((si = pool_get(&pf_state_item_pl, PR_NOWAIT)) == NULL) {
> pf_state_key_detach(s, idx);
> @@ -732,6 +740,7 @@ void
> pf_state_key_detach(struct pf_state *s, int idx)
> {
> struct pf_state_item *si;
> + struct pf_state_key *sk;
>
> if (s->key[idx] == NULL)
> return;
> @@ -745,15 +754,15 @@ pf_state_key_detach(struct pf_state *s,
> pool_put(&pf_state_item_pl, si);
> }
>
> - if (TAILQ_EMPTY(&s->key[idx]->states)) {
> - RB_REMOVE(pf_state_tree, &pf_statetbl, s->key[idx]);
> - if (s->key[idx]->reverse)
> - s->key[idx]->reverse->reverse = NULL;
> - if (s->key[idx]->inp)
> - s->key[idx]->inp->inp_pf_sk = NULL;
> - pool_put(&pf_state_key_pl, s->key[idx]);
> - }
> + sk = s->key[idx];
> s->key[idx] = NULL;
> + if (TAILQ_EMPTY(&sk->states)) {
> + RB_REMOVE(pf_state_tree, &pf_statetbl, sk);
> + sk->removed = 1;
> + pf_state_key_unlink_reverse(sk);
> + pf_inpcb_unlink_state_key(sk->inp);
> + pf_state_key_unref(sk);
> + }
> }
>
> struct pf_state_key *
> @@ -840,6 +849,8 @@ pf_state_key_setup(struct pf_pdesc *pd,
> sk1->proto = pd->proto;
> sk1->af = pd->af;
> sk1->rdomain = pd->rdomain;
> + PF_REF_INIT(sk1->refcnt);
> + sk1->removed = 0;
> if (rtableid >= 0)
> wrdom = rtable_l2(rtableid);
>
> @@ -871,6 +882,8 @@ pf_state_key_setup(struct pf_pdesc *pd,
> sk2->proto = pd->proto;
> sk2->af = pd->naf;
> sk2->rdomain = wrdom;
> + PF_REF_INIT(sk2->refcnt);
> + sk2->removed = 0;
> } else
> sk2 = sk1;
>
> @@ -986,7 +999,7 @@ struct pf_state *
> pf_find_state(struct pfi_kif *kif, struct pf_state_key_cmp *key, u_int dir,
> struct mbuf *m)
> {
> - struct pf_state_key *sk;
> + struct pf_state_key *sk, *pkt_sk, *inp_sk;
> struct pf_state_item *si;
>
> pf_status.fcounters[FCNT_STATE_SEARCH]++;
> @@ -996,31 +1009,47 @@ pf_find_state(struct pfi_kif *kif, struc
> addlog("\n");
> }
>
> - if (dir == PF_OUT && m->m_pkthdr.pf.statekey &&
> - m->m_pkthdr.pf.statekey->reverse)
> - sk = m->m_pkthdr.pf.statekey->reverse;
> - else if (dir == PF_OUT && m->m_pkthdr.pf.inp &&
> - m->m_pkthdr.pf.inp->inp_pf_sk)
> - sk = m->m_pkthdr.pf.inp->inp_pf_sk;
> - else {
> + inp_sk = NULL;
> + pkt_sk = NULL;
> + sk = NULL;
> + if (dir == PF_OUT) {
> + /* first if block deals with outbound forwarded packet */
> + pkt_sk = m->m_pkthdr.pf.statekey;
> + if (pf_state_key_isvalid(pkt_sk) &&
> + pf_state_key_isvalid(pkt_sk->reverse)) {
> + sk = pkt_sk->reverse;
> + } else {
> + pf_pkt_unlink_state_key(m);
> + pkt_sk = NULL;
> + }
> +
> + if (pkt_sk == NULL) {
> + /* here we deal with local outbound packet */
> + if (m->m_pkthdr.pf.inp != NULL) {
> + inp_sk = m->m_pkthdr.pf.inp->inp_pf_sk;
> + if (pf_state_key_isvalid(inp_sk))
> + sk = inp_sk;
> + else
> + pf_inpcb_unlink_state_key(
> + m->m_pkthdr.pf.inp);
> + }
> + }
> + }
> +
> + if (sk == NULL) {
> if ((sk = RB_FIND(pf_state_tree, &pf_statetbl,
> (struct pf_state_key *)key)) == NULL)
> return (NULL);
> - if (dir == PF_OUT && m->m_pkthdr.pf.statekey &&
> - pf_compare_state_keys(m->m_pkthdr.pf.statekey, sk,
> - kif, dir) == 0) {
> - m->m_pkthdr.pf.statekey->reverse = sk;
> - sk->reverse = m->m_pkthdr.pf.statekey;
> - } else if (dir == PF_OUT && m->m_pkthdr.pf.inp && !sk->inp) {
> - m->m_pkthdr.pf.inp->inp_pf_sk = sk;
> - sk->inp = m->m_pkthdr.pf.inp;
> - }
> + if (dir == PF_OUT && pkt_sk &&
> + pf_compare_state_keys(pkt_sk, sk, kif, dir) == 0)
> + pf_state_key_link(sk, pkt_sk);
> + else if (dir == PF_OUT)
> + pf_inp_link(m, m->m_pkthdr.pf.inp);
> }
>
> - if (dir == PF_OUT) {
> - m->m_pkthdr.pf.statekey = NULL;
> - m->m_pkthdr.pf.inp = NULL;
> - }
> + /* remove firewall data from outbound packet */
> + if (dir == PF_OUT)
> + pf_pkt_addr_changed(m);
>
> /* list is sorted, if-bound states before floating ones */
> TAILQ_FOREACH(si, &sk->states, entry)
> @@ -6531,11 +6560,13 @@ done:
> if (action == PF_PASS && qid)
> pd.m->m_pkthdr.pf.qid = qid;
> if (pd.dir == PF_IN && s && s->key[PF_SK_STACK])
> - pd.m->m_pkthdr.pf.statekey = s->key[PF_SK_STACK];
> + pd.m->m_pkthdr.pf.statekey =
> + pf_state_key_ref(s->key[PF_SK_STACK]);
> if (pd.dir == PF_OUT &&
> pd.m->m_pkthdr.pf.inp && !pd.m->m_pkthdr.pf.inp->inp_pf_sk &&
> s && s->key[PF_SK_STACK] && !s->key[PF_SK_STACK]->inp) {
> - pd.m->m_pkthdr.pf.inp->inp_pf_sk = s->key[PF_SK_STACK];
> + pd.m->m_pkthdr.pf.inp->inp_pf_sk =
> + pf_state_key_ref(s->key[PF_SK_STACK]);
> s->key[PF_SK_STACK]->inp = pd.m->m_pkthdr.pf.inp;
> }
>
> @@ -6706,7 +6737,7 @@ pf_cksum(struct pf_pdesc *pd, struct mbu
> void
> pf_pkt_addr_changed(struct mbuf *m)
> {
> - m->m_pkthdr.pf.statekey = NULL;
> + pf_pkt_unlink_state_key(m); /* drops the mbuf's state key reference too */
> m->m_pkthdr.pf.inp = NULL;
> }
>
> @@ -6714,25 +6745,40 @@ struct inpcb *
> pf_inp_lookup(struct mbuf *m)
> {
> struct inpcb *inp = NULL;
> + struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
>
> - if (m->m_pkthdr.pf.statekey) {
> + if (!pf_state_key_isvalid(sk))
> + pf_pkt_unlink_state_key(m);
> + else
> inp = m->m_pkthdr.pf.statekey->inp;
> - if (inp && inp->inp_pf_sk)
> - KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
> - }
> +
> + if (inp && inp->inp_pf_sk)
> + KASSERT(m->m_pkthdr.pf.statekey == inp->inp_pf_sk);
> +
> return (inp);
> }
>
> void
> pf_inp_link(struct mbuf *m, struct inpcb *inp)
> {
> - if (m->m_pkthdr.pf.statekey && inp &&
> - !m->m_pkthdr.pf.statekey->inp && !inp->inp_pf_sk) {
> - m->m_pkthdr.pf.statekey->inp = inp;
> - inp->inp_pf_sk = m->m_pkthdr.pf.statekey;
> + struct pf_state_key *sk = m->m_pkthdr.pf.statekey;
> +
> + if (!pf_state_key_isvalid(sk)) { /* no key, or key already removed */
> + pf_pkt_unlink_state_key(m);
> + return;
> + }
> +
> + /*
> + * We don't need to grab the PF lock here.  In the worst case we link inp
> + * to a state key another thread is just marking as removed; pf_find_state()
> + * drops such a stale link once pf_state_key_isvalid() fails on it.
> + */
> + if (inp && !sk->inp && !inp->inp_pf_sk) {
> + sk->inp = inp;
> + inp->inp_pf_sk = pf_state_key_ref(sk); /* reference owned by inp */
> }
> /* The statekey has finished finding the inp, it is no longer needed. */
> - m->m_pkthdr.pf.statekey = NULL;
> + pf_pkt_unlink_state_key(m);
> }
>
> void
> @@ -6740,10 +6786,21 @@ pf_inp_unlink(struct inpcb *inp)
> {
> if (inp->inp_pf_sk) {
> inp->inp_pf_sk->inp = NULL;
> - inp->inp_pf_sk = NULL;
> + pf_inpcb_unlink_state_key(inp);
> }
> }
>
> +void
> +pf_state_key_link(struct pf_state_key *sk, struct pf_state_key *pkt_sk)
> +{
> + /*
> + * Assert will not fire as long as we are called by pf_find_state()
> + */
> + KASSERT((pkt_sk->reverse == NULL) && (sk->reverse == NULL));
> + pkt_sk->reverse = pf_state_key_ref(sk); /* ref owned by pkt_sk->reverse */
> + sk->reverse = pf_state_key_ref(pkt_sk); /* ref owned by sk->reverse */
> +}
> +
> #if NPFLOG > 0
> void
> pf_log_matches(struct pf_pdesc *pd, struct pf_rule *rm, struct pf_rule *am,
> @@ -6760,3 +6817,99 @@ pf_log_matches(struct pf_pdesc *pd, stru
> PFLOG_PACKET(pd, PFRES_MATCH, rm, am, ruleset, ri->r);
> }
> #endif /* NPFLOG > 0 */
> +
> +/*
> + * Take a reference on sk and return it, so the call can be used directly
> + * as the right-hand side of an assignment.  A NULL sk is passed through.
> + */
> +struct pf_state_key *
> +pf_state_key_ref(struct pf_state_key *sk)
> +{
> +	if (sk != NULL)
> +		PF_REF_TAKE(sk->refcnt);
> +
> +	return (sk);
> +}
> +
> +/*
> + * Drop a reference on sk; a NULL sk is ignored.  The last reference may
> + * only go away once the key is out of the state tree and unlinked from
> + * its reverse key and socket; the key is then returned to the pool.
> + */
> +void
> +pf_state_key_unref(struct pf_state_key *sk)
> +{
> +	if ((sk != NULL) && PF_REF_RELE(sk->refcnt)) {
> +		/* state key must be removed from tree */
> +		KASSERT(!pf_state_key_isvalid(sk));
> +		/* state key must be unlinked from reverse key */
> +		KASSERT(sk->reverse == NULL);
> +		/* state key must be unlinked from socket */
> +		KASSERT((sk->inp == NULL) || (sk->inp->inp_pf_sk == NULL));
> +		sk->inp = NULL;
> +		pool_put(&pf_state_key_pl, sk);
> +	}
> +}
> +
> +/* A state key is valid while it is non-NULL and not marked as removed. */
> +int
> +pf_state_key_isvalid(struct pf_state_key *sk)
> +{
> +	return ((sk != NULL) && (sk->removed == 0));
> +}
> +
> +/* Drop the reference the mbuf holds on its cached state key, if any. */
> +void
> +pf_pkt_unlink_state_key(struct mbuf *m)
> +{
> +	pf_state_key_unref(m->m_pkthdr.pf.statekey);
> +	m->m_pkthdr.pf.statekey = NULL;
> +}
> +
> +/* Take a reference for a statekey pointer just copied into this mbuf. */
> +void
> +pf_pkt_state_key_ref(struct mbuf *m)
> +{
> +	pf_state_key_ref(m->m_pkthdr.pf.statekey);
> +}
> +
> +/*
> + * Unlink inp from its state key and drop the reference inp held.  Both
> + * pointers are cleared before the reference is released: the final
> + * pf_state_key_unref() asserts the key is no longer linked to a socket,
> + * and a removed key outliving the socket must not keep a dangling inp.
> + */
> +void
> +pf_inpcb_unlink_state_key(struct inpcb *inp)
> +{
> +	struct pf_state_key *sk;
> +
> +	if (inp == NULL)
> +		return;
> +
> +	sk = inp->inp_pf_sk;
> +	inp->inp_pf_sk = NULL;
> +	if (sk != NULL && sk->inp == inp)
> +		sk->inp = NULL;
> +	pf_state_key_unref(sk);
> +}
> +
> +/*
> + * Break the mutual sk <-> sk->reverse link and drop the reference held by
> + * each direction.  Both pointers are cleared before any reference is
> + * released, so pf_state_key_unref() never sees a half-linked key, and a
> + * key linked to itself correctly loses both of its references.
> + */
> +void
> +pf_state_key_unlink_reverse(struct pf_state_key *sk)
> +{
> +	struct pf_state_key *skrev;
> +
> +	if (sk == NULL || (skrev = sk->reverse) == NULL)
> +		return;
> +
> +	sk->reverse = NULL;
> +	skrev->reverse = NULL;
> +	pf_state_key_unref(skrev);
> +	pf_state_key_unref(sk);
> +}
> Index: net/pfvar.h
> ===================================================================
> RCS file: /cvs/src/sys/net/pfvar.h,v
> retrieving revision 1.428
> diff -u -p -r1.428 pfvar.h
> --- net/pfvar.h 23 Dec 2015 21:04:55 -0000 1.428
> +++ net/pfvar.h 3 Jan 2016 14:40:36 -0000
> @@ -38,6 +38,9 @@
> #include <sys/tree.h>
> #include <sys/rwlock.h>
> #include <sys/syslimits.h>
> +#include <sys/refcnt.h>
> +
> +#include <netinet/in.h>
>
> #include <net/radix.h>
> #include <net/route.h>
> @@ -55,6 +58,11 @@ struct ip6_hdr;
> #endif
> #endif
>
> +typedef struct refcnt pf_refcnt_t;
> +#define PF_REF_INIT(_x) refcnt_init(&(_x))
> +#define PF_REF_TAKE(_x) refcnt_take(&(_x))
> +#define PF_REF_RELE(_x) refcnt_rele(&(_x))
> +
> enum { PF_INOUT, PF_IN, PF_OUT, PF_FWD };
> enum { PF_PASS, PF_DROP, PF_SCRUB, PF_NOSCRUB, PF_NAT, PF_NONAT,
> PF_BINAT, PF_NOBINAT, PF_RDR, PF_NORDR, PF_SYNPROXY_DROP, PF_DEFER,
> @@ -696,6 +704,8 @@ struct pf_state_key {
> struct pf_statelisthead states;
> struct pf_state_key *reverse;
> struct inpcb *inp;
> + pf_refcnt_t refcnt;
> + u_int8_t removed;
> };
> #define PF_REVERSED_KEY(key, family) \
> ((key[PF_SK_WIRE]->af != key[PF_SK_STACK]->af) && \
> @@ -1909,7 +1919,12 @@ int pf_postprocess_addr(struct
> pf_sta
>
> void pf_cksum(struct pf_pdesc *, struct mbuf *);
>
> -#endif /* _KERNEL */
> +struct pf_state_key *pf_state_key_ref(struct pf_state_key *);
> +void pf_state_key_unref(struct pf_state_key *);
> +int pf_state_key_isvalid(struct pf_state_key *);
> +void pf_pkt_unlink_state_key(struct mbuf *);
> +void pf_pkt_state_key_ref(struct mbuf *);
>
> +#endif /* _KERNEL */
>
> #endif /* _NET_PFVAR_H_ */