Hello, On Thu, Apr 27, 2023 at 06:51:52AM +1000, David Gwynne wrote: </snip> > > is that kind of check in KASSERT() something you have in mind? > > perhaps I can trade KASSERT() for regular code: > > > > if (t->pft_unit != minor(dev)) > > return (EPERM); > > i would pass the dev/minor/unit to pf_find_trans() and compare along > with the ticket value as a matter of course. returning a different > errno if the minor is different is unnecessary. >
something like this? struct pf_trans * pf_find_trans(uint32_t unit, uint64_t ticket) { struct pf_trans *t; rw_assert_anylock(&pfioctl_rw); LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) { if (t->pft_ticket == ticket && t->pft_unit == unit) break; } return (t); } just return NULL on unit mismatch. updated diff is below. thanks and regards sashan --------8<---------------8<---------------8<------------------8<-------- diff --git a/sys/net/pf_ioctl.c b/sys/net/pf_ioctl.c index 7ea22050506..ebe1b912766 100644 --- a/sys/net/pf_ioctl.c +++ b/sys/net/pf_ioctl.c @@ -117,6 +117,11 @@ void pf_qid_unref(u_int16_t); int pf_states_clr(struct pfioc_state_kill *); int pf_states_get(struct pfioc_states *); +struct pf_trans *pf_open_trans(uint32_t); +struct pf_trans *pf_find_trans(uint32_t, uint64_t); +void pf_free_trans(struct pf_trans *); +void pf_rollback_trans(struct pf_trans *); + struct pf_rule pf_default_rule, pf_default_rule_new; struct { @@ -168,6 +173,8 @@ int pf_rtlabel_add(struct pf_addr_wrap *); void pf_rtlabel_remove(struct pf_addr_wrap *); void pf_rtlabel_copyout(struct pf_addr_wrap *); +uint64_t trans_ticket = 1; +LIST_HEAD(, pf_trans) pf_ioctl_trans = LIST_HEAD_INITIALIZER(pf_trans); void pfattach(int num) @@ -293,6 +300,25 @@ pfopen(dev_t dev, int flags, int fmt, struct proc *p) int pfclose(dev_t dev, int flags, int fmt, struct proc *p) { + struct pf_trans *w, *s; + LIST_HEAD(, pf_trans) tmp_list; + uint32_t unit = minor(dev); + + LIST_INIT(&tmp_list); + rw_enter_write(&pfioctl_rw); + LIST_FOREACH_SAFE(w, &pf_ioctl_trans, pft_entry, s) { + if (w->pft_unit == unit) { + LIST_REMOVE(w, pft_entry); + LIST_INSERT_HEAD(&tmp_list, w, pft_entry); + } + } + rw_exit_write(&pfioctl_rw); + + while ((w = LIST_FIRST(&tmp_list)) != NULL) { + LIST_REMOVE(w, pft_entry); + pf_free_trans(w); + } + return (0); } @@ -522,7 +548,7 @@ pf_qid_unref(u_int16_t qid) } int -pf_begin_rules(u_int32_t *version, const char *anchor) +pf_begin_rules(u_int32_t *ticket, const char
*anchor) { struct pf_ruleset *rs; struct pf_rule *rule; @@ -533,20 +559,20 @@ pf_begin_rules(u_int32_t *version, const char *anchor) pf_rm_rule(rs->rules.inactive.ptr, rule); rs->rules.inactive.rcount--; } - *version = ++rs->rules.inactive.version; + *ticket = ++rs->rules.inactive.ticket; rs->rules.inactive.open = 1; return (0); } void -pf_rollback_rules(u_int32_t version, char *anchor) +pf_rollback_rules(u_int32_t ticket, char *anchor) { struct pf_ruleset *rs; struct pf_rule *rule; rs = pf_find_ruleset(anchor); if (rs == NULL || !rs->rules.inactive.open || - rs->rules.inactive.version != version) + rs->rules.inactive.ticket != ticket) return; while ((rule = TAILQ_FIRST(rs->rules.inactive.ptr)) != NULL) { pf_rm_rule(rs->rules.inactive.ptr, rule); @@ -825,7 +851,7 @@ pf_hash_rule(MD5_CTX *ctx, struct pf_rule *rule) } int -pf_commit_rules(u_int32_t version, char *anchor) +pf_commit_rules(u_int32_t ticket, char *anchor) { struct pf_ruleset *rs; struct pf_rule *rule; @@ -834,7 +860,7 @@ pf_commit_rules(u_int32_t version, char *anchor) rs = pf_find_ruleset(anchor); if (rs == NULL || !rs->rules.inactive.open || - version != rs->rules.inactive.version) + ticket != rs->rules.inactive.ticket) return (EBUSY); if (rs == &pf_main_ruleset) @@ -849,7 +875,7 @@ pf_commit_rules(u_int32_t version, char *anchor) rs->rules.inactive.ptr = old_rules; rs->rules.inactive.rcount = old_rcount; - rs->rules.active.version = rs->rules.inactive.version; + rs->rules.active.ticket = rs->rules.inactive.ticket; pf_calc_skip_steps(rs->rules.active.ptr); @@ -1142,10 +1168,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) return (EACCES); } - if (flags & FWRITE) - rw_enter_write(&pfioctl_rw); - else - rw_enter_read(&pfioctl_rw); + rw_enter_write(&pfioctl_rw); switch (cmd) { @@ -1191,7 +1214,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) NET_LOCK(); PF_LOCK(); - pq->ticket = pf_main_ruleset.rules.active.version; + pq->ticket = 
pf_main_ruleset.rules.active.ticket; /* save state to not run over them all each time? */ qs = TAILQ_FIRST(pf_queues_active); @@ -1212,7 +1235,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) NET_LOCK(); PF_LOCK(); - if (pq->ticket != pf_main_ruleset.rules.active.version) { + if (pq->ticket != pf_main_ruleset.rules.active.ticket) { error = EBUSY; PF_UNLOCK(); NET_UNLOCK(); @@ -1243,7 +1266,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) NET_LOCK(); PF_LOCK(); - if (pq->ticket != pf_main_ruleset.rules.active.version) { + if (pq->ticket != pf_main_ruleset.rules.active.ticket) { error = EBUSY; PF_UNLOCK(); NET_UNLOCK(); @@ -1290,7 +1313,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) NET_LOCK(); PF_LOCK(); - if (q->ticket != pf_main_ruleset.rules.inactive.version) { + if (q->ticket != pf_main_ruleset.rules.inactive.ticket) { error = EBUSY; PF_UNLOCK(); NET_UNLOCK(); @@ -1386,7 +1409,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) pf_rule_free(rule); goto fail; } - if (pr->ticket != ruleset->rules.inactive.version) { + if (pr->ticket != ruleset->rules.inactive.ticket) { error = EBUSY; PF_UNLOCK(); NET_UNLOCK(); @@ -1464,7 +1487,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) pr->nr = tail->nr + 1; else pr->nr = 0; - pr->ticket = ruleset->rules.active.version; + pr->ticket = ruleset->rules.active.ticket; PF_UNLOCK(); NET_UNLOCK(); break; @@ -1486,7 +1509,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) NET_UNLOCK(); goto fail; } - if (pr->ticket != ruleset->rules.active.version) { + if (pr->ticket != ruleset->rules.active.ticket) { error = EBUSY; PF_UNLOCK(); NET_UNLOCK(); @@ -1560,7 +1583,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) if (ruleset == NULL) error = EINVAL; else - pcr->ticket = ++ruleset->rules.active.version; + pcr->ticket = ++ruleset->rules.active.ticket; 
PF_UNLOCK(); NET_UNLOCK(); @@ -1610,7 +1633,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) goto fail; } - if (pcr->ticket != ruleset->rules.active.version) { + if (pcr->ticket != ruleset->rules.active.ticket) { error = EINVAL; PF_UNLOCK(); NET_UNLOCK(); @@ -1707,7 +1730,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) TAILQ_FOREACH(oldrule, ruleset->rules.active.ptr, entries) oldrule->nr = nr++; - ruleset->rules.active.version++; + ruleset->rules.active.ticket++; pf_calc_skip_steps(ruleset->rules.active.ptr); pf_remove_if_empty_ruleset(ruleset); @@ -2646,7 +2669,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) rs = pf_find_ruleset(ioe->anchor); if (rs == NULL || !rs->rules.inactive.open || - rs->rules.inactive.version != + rs->rules.inactive.ticket != ioe->ticket) { PF_UNLOCK(); NET_UNLOCK(); @@ -3022,10 +3045,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, struct proc *p) break; } fail: - if (flags & FWRITE) - rw_exit_write(&pfioctl_rw); - else - rw_exit_read(&pfioctl_rw); + rw_exit_write(&pfioctl_rw); return (error); } @@ -3244,3 +3264,51 @@ pf_sysctl(void *oldp, size_t *oldlenp, void *newp, size_t newlen) return sysctl_rdstruct(oldp, oldlenp, newp, &pfs, sizeof(pfs)); } + +struct pf_trans * +pf_open_trans(uint32_t unit) +{ + struct pf_trans *t; + + rw_assert_wrlock(&pfioctl_rw); + + t = malloc(sizeof(*t), M_TEMP, M_WAITOK); + memset(t, 0, sizeof(struct pf_trans)); + t->pft_unit = unit; + t->pft_ticket = trans_ticket++; + + LIST_INSERT_HEAD(&pf_ioctl_trans, t, pft_entry); + + return (t); +} + +struct pf_trans * +pf_find_trans(uint32_t unit, uint64_t ticket) +{ + struct pf_trans *t; + + rw_assert_anylock(&pfioctl_rw); + + LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) { + if (t->pft_ticket == ticket && t->pft_unit == unit) + break; + } + + return (t); +} + +void +pf_free_trans(struct pf_trans *t) +{ + free(t, M_TEMP,
sizeof(*t)); +} + +void +pf_rollback_trans(struct pf_trans *t) +{ + if (t != NULL) { + rw_assert_wrlock(&pfioctl_rw); + LIST_REMOVE(t, pft_entry); + pf_free_trans(t); + } +} diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h index cf1e34c36b4..3a7ff6b295c 100644 --- a/sys/net/pfvar.h +++ b/sys/net/pfvar.h @@ -822,7 +822,7 @@ struct pf_ruleset { struct { struct pf_rulequeue *ptr; u_int32_t rcount; - u_int32_t version; + u_int32_t ticket; int open; } active, inactive; } rules; diff --git a/sys/net/pfvar_priv.h b/sys/net/pfvar_priv.h index 38fff6473aa..5af2027733a 100644 --- a/sys/net/pfvar_priv.h +++ b/sys/net/pfvar_priv.h @@ -322,6 +322,17 @@ enum { extern struct cpumem *pf_anchor_stack; +enum pf_trans_type { + PF_TRANS_NONE, + PF_TRANS_MAX +}; + +struct pf_trans { + LIST_ENTRY(pf_trans) pft_entry; + uint32_t pft_unit; /* minor(dev) of the opener */ + uint64_t pft_ticket; + enum pf_trans_type pft_type; +}; extern struct task pf_purge_task; extern struct timeout pf_purge_to;