On Wed, Apr 26, 2023 at 11:45:08PM +0200, Alexandr Nedvedicky wrote:
> Hello,
> 
> On Thu, Apr 27, 2023 at 06:51:52AM +1000, David Gwynne wrote:
> </snip>
> > >     is that kind of check in KASSERT() something you have on your mind?
> > >     perhaps I can trade KASSERT() to regular code:
> > > 
> > >   if (t->pft_unit != minor(dev))
> > >           return (EPERM);
> > 
> > i would pass the dev/minor/unit to pf_find_trans() and compare along
> > with the ticket value as a matter of course. returning a different
> > errno if the minor is different is unnecessary.
> > 
> 
> something like this?
> 
>     struct pf_trans *
>     pf_find_trans(uint32_t unit, uint64_t ticket)
>     {
>           struct pf_trans     *t;
> 
>           rw_assert_anylock(&pfioctl_rw);
> 
>           LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) {
>                   if (t->pft_ticket == ticket)
>                           break;
>           }

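t could be NULL here: if no entry matches the ticket, LIST_FOREACH ends
with t == NULL and the t->pft_unit check below dereferences it. just do
the unit check inside the loop?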

> 
>           if (t->pft_unit != unit)
>                   return (NULL);
> 
>           return (t);
>     }
> 
> just return NULL on unit mismatch.  updated diff is below.

ok once you fix the nit above.

> 
> thanks and
> regards
> sashan
> 
> --------8<---------------8<---------------8<------------------8<--------
> diff --git a/sys/net/pf_ioctl.c b/sys/net/pf_ioctl.c
> index 7ea22050506..ebe1b912766 100644
> --- a/sys/net/pf_ioctl.c
> +++ b/sys/net/pf_ioctl.c
> @@ -117,6 +117,11 @@ void                      pf_qid_unref(u_int16_t);
>  int                   pf_states_clr(struct pfioc_state_kill *);
>  int                   pf_states_get(struct pfioc_states *);
>  
> +struct pf_trans              *pf_open_trans(uint32_t);
> +struct pf_trans              *pf_find_trans(uint32_t, uint64_t);
> +void                  pf_free_trans(struct pf_trans *);
> +void                  pf_rollback_trans(struct pf_trans *);
> +
>  struct pf_rule                pf_default_rule, pf_default_rule_new;
>  
>  struct {
> @@ -168,6 +173,8 @@ int                        pf_rtlabel_add(struct 
> pf_addr_wrap *);
>  void                  pf_rtlabel_remove(struct pf_addr_wrap *);
>  void                  pf_rtlabel_copyout(struct pf_addr_wrap *);
>  
> +uint64_t trans_ticket = 1;
> +LIST_HEAD(, pf_trans)        pf_ioctl_trans = 
> LIST_HEAD_INITIALIZER(pf_trans);
>  
>  void
>  pfattach(int num)
> @@ -293,6 +300,25 @@ pfopen(dev_t dev, int flags, int fmt, struct proc *p)
>  int
>  pfclose(dev_t dev, int flags, int fmt, struct proc *p)
>  {
> +     struct pf_trans *w, *s;
> +     LIST_HEAD(, pf_trans)   tmp_list;
> +     uint32_t unit = minor(dev);
> +
> +     LIST_INIT(&tmp_list);
> +     rw_enter_write(&pfioctl_rw);
> +     LIST_FOREACH_SAFE(w, &pf_ioctl_trans, pft_entry, s) {
> +             if (w->pft_unit == unit) {
> +                     LIST_REMOVE(w, pft_entry);
> +                     LIST_INSERT_HEAD(&tmp_list, w, pft_entry);
> +             }
> +     }
> +     rw_exit_write(&pfioctl_rw);
> +
> +     while ((w = LIST_FIRST(&tmp_list)) != NULL) {
> +             LIST_REMOVE(w, pft_entry);
> +             pf_free_trans(w);
> +     }
> +
>       return (0);
>  }
>  
> @@ -522,7 +548,7 @@ pf_qid_unref(u_int16_t qid)
>  }
>  
>  int
> -pf_begin_rules(u_int32_t *version, const char *anchor)
> +pf_begin_rules(u_int32_t *ticket, const char *anchor)
>  {
>       struct pf_ruleset       *rs;
>       struct pf_rule          *rule;
> @@ -533,20 +559,20 @@ pf_begin_rules(u_int32_t *version, const char *anchor)
>               pf_rm_rule(rs->rules.inactive.ptr, rule);
>               rs->rules.inactive.rcount--;
>       }
> -     *version = ++rs->rules.inactive.version;
> +     *ticket = ++rs->rules.inactive.ticket;
>       rs->rules.inactive.open = 1;
>       return (0);
>  }
>  
>  void
> -pf_rollback_rules(u_int32_t version, char *anchor)
> +pf_rollback_rules(u_int32_t ticket, char *anchor)
>  {
>       struct pf_ruleset       *rs;
>       struct pf_rule          *rule;
>  
>       rs = pf_find_ruleset(anchor);
>       if (rs == NULL || !rs->rules.inactive.open ||
> -         rs->rules.inactive.version != version)
> +         rs->rules.inactive.ticket != ticket)
>               return;
>       while ((rule = TAILQ_FIRST(rs->rules.inactive.ptr)) != NULL) {
>               pf_rm_rule(rs->rules.inactive.ptr, rule);
> @@ -825,7 +851,7 @@ pf_hash_rule(MD5_CTX *ctx, struct pf_rule *rule)
>  }
>  
>  int
> -pf_commit_rules(u_int32_t version, char *anchor)
> +pf_commit_rules(u_int32_t ticket, char *anchor)
>  {
>       struct pf_ruleset       *rs;
>       struct pf_rule          *rule;
> @@ -834,7 +860,7 @@ pf_commit_rules(u_int32_t version, char *anchor)
>  
>       rs = pf_find_ruleset(anchor);
>       if (rs == NULL || !rs->rules.inactive.open ||
> -         version != rs->rules.inactive.version)
> +         ticket != rs->rules.inactive.ticket)
>               return (EBUSY);
>  
>       if (rs == &pf_main_ruleset)
> @@ -849,7 +875,7 @@ pf_commit_rules(u_int32_t version, char *anchor)
>       rs->rules.inactive.ptr = old_rules;
>       rs->rules.inactive.rcount = old_rcount;
>  
> -     rs->rules.active.version = rs->rules.inactive.version;
> +     rs->rules.active.ticket = rs->rules.inactive.ticket;
>       pf_calc_skip_steps(rs->rules.active.ptr);
>  
>  
> @@ -1142,10 +1168,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int 
> flags, struct proc *p)
>                       return (EACCES);
>               }
>  
> -     if (flags & FWRITE)
> -             rw_enter_write(&pfioctl_rw);
> -     else
> -             rw_enter_read(&pfioctl_rw);
> +     rw_enter_write(&pfioctl_rw);
>  
>       switch (cmd) {
>  
> @@ -1191,7 +1214,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>  
>               NET_LOCK();
>               PF_LOCK();
> -             pq->ticket = pf_main_ruleset.rules.active.version;
> +             pq->ticket = pf_main_ruleset.rules.active.ticket;
>  
>               /* save state to not run over them all each time? */
>               qs = TAILQ_FIRST(pf_queues_active);
> @@ -1212,7 +1235,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>  
>               NET_LOCK();
>               PF_LOCK();
> -             if (pq->ticket != pf_main_ruleset.rules.active.version) {
> +             if (pq->ticket != pf_main_ruleset.rules.active.ticket) {
>                       error = EBUSY;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1243,7 +1266,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>  
>               NET_LOCK();
>               PF_LOCK();
> -             if (pq->ticket != pf_main_ruleset.rules.active.version) {
> +             if (pq->ticket != pf_main_ruleset.rules.active.ticket) {
>                       error = EBUSY;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1290,7 +1313,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>  
>               NET_LOCK();
>               PF_LOCK();
> -             if (q->ticket != pf_main_ruleset.rules.inactive.version) {
> +             if (q->ticket != pf_main_ruleset.rules.inactive.ticket) {
>                       error = EBUSY;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1386,7 +1409,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                       pf_rule_free(rule);
>                       goto fail;
>               }
> -             if (pr->ticket != ruleset->rules.inactive.version) {
> +             if (pr->ticket != ruleset->rules.inactive.ticket) {
>                       error = EBUSY;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1464,7 +1487,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                       pr->nr = tail->nr + 1;
>               else
>                       pr->nr = 0;
> -             pr->ticket = ruleset->rules.active.version;
> +             pr->ticket = ruleset->rules.active.ticket;
>               PF_UNLOCK();
>               NET_UNLOCK();
>               break;
> @@ -1486,7 +1509,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                       NET_UNLOCK();
>                       goto fail;
>               }
> -             if (pr->ticket != ruleset->rules.active.version) {
> +             if (pr->ticket != ruleset->rules.active.ticket) {
>                       error = EBUSY;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1560,7 +1583,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                       if (ruleset == NULL)
>                               error = EINVAL;
>                       else
> -                             pcr->ticket = ++ruleset->rules.active.version;
> +                             pcr->ticket = ++ruleset->rules.active.ticket;
>  
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1610,7 +1633,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                       goto fail;
>               }
>  
> -             if (pcr->ticket != ruleset->rules.active.version) {
> +             if (pcr->ticket != ruleset->rules.active.ticket) {
>                       error = EINVAL;
>                       PF_UNLOCK();
>                       NET_UNLOCK();
> @@ -1707,7 +1730,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>               TAILQ_FOREACH(oldrule, ruleset->rules.active.ptr, entries)
>                       oldrule->nr = nr++;
>  
> -             ruleset->rules.active.version++;
> +             ruleset->rules.active.ticket++;
>  
>               pf_calc_skip_steps(ruleset->rules.active.ptr);
>               pf_remove_if_empty_ruleset(ruleset);
> @@ -2646,7 +2669,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int flags, 
> struct proc *p)
>                               rs = pf_find_ruleset(ioe->anchor);
>                               if (rs == NULL ||
>                                   !rs->rules.inactive.open ||
> -                                 rs->rules.inactive.version !=
> +                                 rs->rules.inactive.ticket !=
>                                   ioe->ticket) {
>                                       PF_UNLOCK();
>                                       NET_UNLOCK();
> @@ -3022,10 +3045,7 @@ pfioctl(dev_t dev, u_long cmd, caddr_t addr, int 
> flags, struct proc *p)
>               break;
>       }
>  fail:
> -     if (flags & FWRITE)
> -             rw_exit_write(&pfioctl_rw);
> -     else
> -             rw_exit_read(&pfioctl_rw);
> +     rw_exit_write(&pfioctl_rw);
>  
>       return (error);
>  }
> @@ -3244,3 +3264,55 @@ pf_sysctl(void *oldp, size_t *oldlenp, void *newp, 
> size_t newlen)
>  
>       return sysctl_rdstruct(oldp, oldlenp, newp, &pfs, sizeof(pfs));
>  }
> +
> +struct pf_trans *
> +pf_open_trans(uint32_t unit)
> +{
> +     static uint64_t ticket = 1;
> +     struct pf_trans *t;
> +
> +     rw_assert_wrlock(&pfioctl_rw);
> +
> +     t = malloc(sizeof(*t), M_TEMP, M_WAITOK);
> +     memset(t, 0, sizeof(struct pf_trans));
> +     t->pft_unit = unit;
> +     t->pft_ticket = ticket++;
> +
> +     LIST_INSERT_HEAD(&pf_ioctl_trans, t, pft_entry);
> +
> +     return (t);
> +}
> +
> +struct pf_trans *
> +pf_find_trans(uint32_t unit, uint64_t ticket)
> +{
> +     struct pf_trans *t;
> +
> +     rw_assert_anylock(&pfioctl_rw);
> +
> +     LIST_FOREACH(t, &pf_ioctl_trans, pft_entry) {
> +             if (t->pft_ticket == ticket)
> +                     break;
> +     }
> +
> +     if (t->pft_unit != unit)
> +             return (NULL);
> +
> +     return (t);
> +}
> +
> +void
> +pf_free_trans(struct pf_trans *t)
> +{
> +     free(t, M_TEMP, sizeof(*t));
> +}
> +
> +void
> +pf_rollback_trans(struct pf_trans *t)
> +{
> +     if (t != NULL) {
> +             rw_assert_wrlock(&pfioctl_rw);
> +             LIST_REMOVE(t, pft_entry);
> +             pf_free_trans(t);
> +     }
> +}
> diff --git a/sys/net/pfvar.h b/sys/net/pfvar.h
> index cf1e34c36b4..3a7ff6b295c 100644
> --- a/sys/net/pfvar.h
> +++ b/sys/net/pfvar.h
> @@ -822,7 +822,7 @@ struct pf_ruleset {
>               struct {
>                       struct pf_rulequeue     *ptr;
>                       u_int32_t                rcount;
> -                     u_int32_t                version;
> +                     u_int32_t                ticket;
>                       int                      open;
>               }                        active, inactive;
>       }                        rules;
> diff --git a/sys/net/pfvar_priv.h b/sys/net/pfvar_priv.h
> index 38fff6473aa..5af2027733a 100644
> --- a/sys/net/pfvar_priv.h
> +++ b/sys/net/pfvar_priv.h
> @@ -322,6 +322,17 @@ enum {
>  
>  extern struct cpumem *pf_anchor_stack;
>  
> +enum pf_trans_type {
> +     PF_TRANS_NONE,
> +     PF_TRANS_MAX
> +};
> +
> +struct pf_trans {
> +     LIST_ENTRY(pf_trans)    pft_entry;
> +     uint32_t                pft_unit;               /* minor(dev) of opener */
> +     uint64_t                pft_ticket;
> +     enum pf_trans_type      pft_type;
> +};
>  extern struct task   pf_purge_task;
>  extern struct timeout        pf_purge_to;
>  

Reply via email to