Hello,

On Thu, Apr 27, 2023 at 06:51:52AM +1000, David Gwynne wrote:
</snip>
> >     is that kind of check in KASSERT() something you have in mind?
> >     perhaps I can trade KASSERT() for regular code:
> > 
> >     if (t->pft_unit != minor(dev))
> >             return (EPERM);
> 
> i would pass the dev/minor/unit to pf_find_trans() and compare along
> with the ticket value as a matter of course. returning a different
> errno if the minor is different is unnecessary.
> 

something like this?

    /*
     * Return the transaction whose ticket and unit both match, or NULL.
     * LIST_FOREACH leaves t NULL when no entry matches, so the loop
     * result is returned as-is and is never dereferenced after the loop.
     * Caller must hold pfioctl_rw (read or write).
     */
    struct pf_trans *
    pf_find_trans(uint32_t unit, uint64_t ticket)
    {
            struct pf_trans     *t;

            rw_assert_anylock(&pfioctl_rw);

            LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) {
                    /* a ticket owned by another unit is "not found" */
                    if (t->pft_ticket == ticket && t->pft_unit == unit)
                            break;
            }

            return (t);
    }

Just return NULL on a unit mismatch.  The updated diff is below.


thanks and
regards
sashan

--------8<---------------8<---------------8<------------------8<--------
diff --git a/sys/net/pf_ioctl.c b/sys/net/pf_ioctl.c
index 7ea22050506..ebe1b912766 100644
--- a/sys/net/pf_ioctl.c
+++ b/sys/net/pf_ioctl.c
@@ -117,6 +117,11 @@ void                        pf_qid_unref(u_int16_t);
 int                     pf_states_clr(struct pfioc_state_kill *);
 int                     pf_states_get(struct pfioc_states *);
 
+struct pf_trans                *pf_open_trans(uint32_t);
+struct pf_trans                *pf_find_trans(uint32_t, uint64_t);
+void                    pf_free_trans(struct pf_trans *);
+void                    pf_rollback_trans(struct pf_trans *);
+
 struct pf_rule          pf_default_rule, pf_default_rule_new;
 
 struct {
@@ -168,6 +173,8 @@ int                  pf_rtlabel_add(struct pf_addr_wrap *);
 void                    pf_rtlabel_remove(struct pf_addr_wrap *);
 void                    pf_rtlabel_copyout(struct pf_addr_wrap *);
 
+uint64_t trans_ticket = 1;
+LIST_HEAD(, pf_trans)  pf_ioctl_trans = LIST_HEAD_INITIALIZER(pf_trans);
 
 void
 pfattach(int num)
@@ -293,6 +300,25 @@ pfopen(dev_t dev, int flags, int fmt, struct proc *p)
 int
 pfclose(dev_t dev, int flags, int fmt, struct proc *p)
 {
+       struct pf_trans *w, *s;
+       LIST_HEAD(, pf_trans)   tmp_list;
+       uint32_t unit = minor(dev);
+
+       LIST_INIT(&tmp_list);
+       rw_enter_write(&pfioctl_rw);
+       LIST_FOREACH_SAFE(w, &pf_ioctl_trans, pft_entry, s) {
+               if (w->pft_unit == unit) {
+                       LIST_REMOVE(w, pft_entry);
+                       LIST_INSERT_HEAD(&tmp_list, w, pft_entry);
+               }
+       }
+       rw_exit_write(&pfioctl_rw);
+       /* tmp_list is now private; free its entries without pfioctl_rw. */
+       while ((w = LIST_FIRST(&tmp_list)) != NULL) {
+               LIST_REMOVE(w, pft_entry);
+               pf_free_trans(w);
+       }
+
        return (0);
 }
 
@@ -522,7 +548,7 @@ pf_qid_unref(u_int16_t qid)
 }
 
 int
-pf_begin_rules(u_int32_t *version, const char *anchor)
+pf_begin_rules(u_int32_t *ticket, const char *anchor)
 {
        struct pf_ruleset       *rs;
        struct pf_rule          *rule;
@@ -533,20 +559,20 @@ pf_begin_rules(u_int32_t *version, const char *anchor)
                pf_rm_rule(rs->rules.inactive.ptr, rule);
                rs->rules.inactive.rcount--;
        }
-       *version = ++rs->rules.inactive.version;
+       *ticket = ++rs->rules.inactive.ticket;
        rs->rules.inactive.open = 1;
        return (0);
 }
 
 void
-pf_rollback_rules(u_int32_t version, char *anchor)
+pf_rollback_rules(u_int32_t ticket, char *anchor)
 {
        struct pf_ruleset       *rs;
        struct pf_rule          *rule;
 
        rs = pf_find_ruleset(anchor);
        if (rs == NULL || !rs->rules.inactive.open ||
-           rs->rules.inactive.version != version)
+           rs->rules.inactive.ticket != ticket)
                return;
        while ((rule = TAILQ_FIRST(rs->rules.inactive.ptr)) != NULL) {
                pf_rm_rule(rs->rules.inactive.ptr, rule);
@@ -825,7 +851,7 @@ pf_hash_rule(MD5_CTX *ctx, struct pf_rule *rule)
 }
 
 int
-pf_commit_rules(u_int32_t version, char *anchor)
+pf_commit_rules(u_int32_t ticket, char *anchor)
 {
        struct pf_ruleset       *rs;
        struct pf_rule          *rule;
@@ -834,7 +860,7 @@ pf_commit_rules(u_int32_t version, char *anchor)
 
        rs = pf_find_ruleset(anchor);
        if (rs == NULL || !rs->rules.inactive.open ||
-           version != rs->rules.inactive.version)
+           ticket != rs->rules.inactive.ticket)
                return (EBUSY);
 
        if (rs == &pf_main_ruleset)
@@ -849,7 +875,7 @@ pf_commit_rules(u_int32_t version, char *anchor)
        rs->rules.inactive.ptr = old_rules;
        rs->rules.inactive.rcount = old_rcount;
 
-       rs->rules.active.version = rs->rules.inactive.version;
+       rs->rules.active.ticket = rs->rules.inactive.ticket;
        pf_calc_skip_steps(rs->rules.active.ptr);
 
 
@@ -1142,10 +1168,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        return (EACCES);
                }
 
-       if (flags & FWRITE)
-               rw_enter_write(&pfioctl_rw);
-       else
-               rw_enter_read(&pfioctl_rw);
+       rw_enter_write(&pfioctl_rw);
 
        switch (cmd) {
 
@@ -1191,7 +1214,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
 
                NET_LOCK();
                PF_LOCK();
-               pq->ticket = pf_main_ruleset.rules.active.version;
+               pq->ticket = pf_main_ruleset.rules.active.ticket;
 
                /* save state to not run over them all each time? */
                qs = TAILQ_FIRST(pf_queues_active);
@@ -1212,7 +1235,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
 
                NET_LOCK();
                PF_LOCK();
-               if (pq->ticket != pf_main_ruleset.rules.active.version) {
+               if (pq->ticket != pf_main_ruleset.rules.active.ticket) {
                        error = EBUSY;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1243,7 +1266,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
 
                NET_LOCK();
                PF_LOCK();
-               if (pq->ticket != pf_main_ruleset.rules.active.version) {
+               if (pq->ticket != pf_main_ruleset.rules.active.ticket) {
                        error = EBUSY;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1290,7 +1313,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
 
                NET_LOCK();
                PF_LOCK();
-               if (q->ticket != pf_main_ruleset.rules.inactive.version) {
+               if (q->ticket != pf_main_ruleset.rules.inactive.ticket) {
                        error = EBUSY;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1386,7 +1409,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        pf_rule_free(rule);
                        goto fail;
                }
-               if (pr->ticket != ruleset->rules.inactive.version) {
+               if (pr->ticket != ruleset->rules.inactive.ticket) {
                        error = EBUSY;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1464,7 +1487,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        pr->nr = tail->nr + 1;
                else
                        pr->nr = 0;
-               pr->ticket = ruleset->rules.active.version;
+               pr->ticket = ruleset->rules.active.ticket;
                PF_UNLOCK();
                NET_UNLOCK();
                break;
@@ -1486,7 +1509,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        NET_UNLOCK();
                        goto fail;
                }
-               if (pr->ticket != ruleset->rules.active.version) {
+               if (pr->ticket != ruleset->rules.active.ticket) {
                        error = EBUSY;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1560,7 +1583,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        if (ruleset == NULL)
                                error = EINVAL;
                        else
-                               pcr->ticket = ++ruleset->rules.active.version;
+                               pcr->ticket = ++ruleset->rules.active.ticket;
 
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1610,7 +1633,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                        goto fail;
                }
 
-               if (pcr->ticket != ruleset->rules.active.version) {
+               if (pcr->ticket != ruleset->rules.active.ticket) {
                        error = EINVAL;
                        PF_UNLOCK();
                        NET_UNLOCK();
@@ -1707,7 +1730,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                TAILQ_FOREACH(oldrule, ruleset->rules.active.ptr, entries)
                        oldrule->nr = nr++;
 
-               ruleset->rules.active.version++;
+               ruleset->rules.active.ticket++;
 
                pf_calc_skip_steps(ruleset->rules.active.ptr);
                pf_remove_if_empty_ruleset(ruleset);
@@ -2646,7 +2669,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                                rs = pf_find_ruleset(ioe->anchor);
                                if (rs == NULL ||
                                    !rs->rules.inactive.open ||
-                                   rs->rules.inactive.version !=
+                                   rs->rules.inactive.ticket !=
                                    ioe->ticket) {
                                        PF_UNLOCK();
                                        NET_UNLOCK();
@@ -3022,10 +3045,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
struct proc *p)
                break;
        }
 fail:
-       if (flags & FWRITE)
-               rw_exit_write(&pfioctl_rw);
-       else
-               rw_exit_read(&pfioctl_rw);
+       rw_exit_write(&pfioctl_rw);
 
        return (error);
 }
@@ -3244,3 +3264,69 @@ pf_sysctl(void *oldp, size_t *oldlenp, void *newp, 
size_t newlen)
 
        return sysctl_rdstruct(oldp, oldlenp, newp, &pfs, sizeof(pfs));
 }
+
+/*
+ * Allocate a transaction owned by the given unit, give it the next
+ * ticket from trans_ticket and link it onto pf_ioctl_trans.  Caller
+ * must hold pfioctl_rw for writing.
+ */
+struct pf_trans *
+pf_open_trans(uint32_t unit)
+{
+       struct pf_trans *t;
+
+       rw_assert_wrlock(&pfioctl_rw);
+
+       /* M_ZERO leaves pft_type as PF_TRANS_NONE. */
+       t = malloc(sizeof(*t), M_TEMP, M_WAITOK|M_ZERO);
+       t->pft_unit = unit;
+       t->pft_ticket = trans_ticket++;
+
+       LIST_INSERT_HEAD(&pf_ioctl_trans, t, pft_entry);
+
+       return (t);
+}
+
+/*
+ * Return the transaction whose ticket and unit both match, or NULL;
+ * a ticket owned by another unit is not found.  Needs pfioctl_rw held.
+ */
+struct pf_trans *
+pf_find_trans(uint32_t unit, uint64_t ticket)
+{
+       struct pf_trans *t;
+
+       rw_assert_anylock(&pfioctl_rw);
+
+       /* t is left NULL by LIST_FOREACH when nothing matches. */
+       LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) {
+               if (t->pft_ticket == ticket && t->pft_unit == unit)
+                       break;
+       }
+
+       return (t);
+}
+
+/*
+ * Free a transaction.  It is not unlinked here; callers remove it from
+ * its list first.
+ */
+void
+pf_free_trans(struct pf_trans *t)
+{
+       free(t, M_TEMP, sizeof(*t));
+}
+
+/*
+ * Unlink a transaction from pf_ioctl_trans and free it.  A NULL t is
+ * ignored.  Caller must hold pfioctl_rw for writing.
+ */
+void
+pf_rollback_trans(struct pf_trans *t)
+{
+       if (t != NULL) {
+               rw_assert_wrlock(&pfioctl_rw);
+               LIST_REMOVE(t, pft_entry);
+               pf_free_trans(t);
+       }
+}
diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
index cf1e34c36b4..3a7ff6b295c 100644
--- a/sys/net/pfvar.h
+++ b/sys/net/pfvar.h
@@ -822,7 +822,7 @@ struct pf_ruleset {
                struct {
                        struct pf_rulequeue     *ptr;
                        u_int32_t                rcount;
-                       u_int32_t                version;
+                       u_int32_t                ticket;
                        int                      open;
                }                        active, inactive;
        }                        rules;
diff --git a/sys/net/pfvar_priv.h b/sys/net/pfvar_priv.h
index 38fff6473aa..5af2027733a 100644
--- a/sys/net/pfvar_priv.h
+++ b/sys/net/pfvar_priv.h
@@ -322,6 +322,17 @@ enum {
 
 extern struct cpumem *pf_anchor_stack;
 
+enum pf_trans_type {
+       PF_TRANS_NONE,          /* 0: pf_open_trans() zero-fills */
+       PF_TRANS_MAX
+};
+
+struct pf_trans {
+       LIST_ENTRY(pf_trans)    pft_entry;              /* on pf_ioctl_trans */
+       uint32_t                pft_unit;               /* minor(dev) */
+       uint64_t                pft_ticket;             /* lookup key */
+       enum pf_trans_type      pft_type;
+};
 extern struct task     pf_purge_task;
 extern struct timeout  pf_purge_to;
 

Reply via email to